From 4a02725274ebb24444a7319e1b1175cb62d71e34 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Wed, 6 May 2026 12:55:11 -0700 Subject: [PATCH 01/10] adding basics --- pyrit/backend/models/scenarios.py | 113 ++++- pyrit/backend/routes/scenarios.py | 159 ++++++- pyrit/backend/services/__init__.py | 6 + .../backend/services/scenario_run_service.py | 422 +++++++++++++++++ .../unit/backend/test_scenario_run_routes.py | 315 +++++++++++++ .../unit/backend/test_scenario_run_service.py | 425 ++++++++++++++++++ 6 files changed, 1437 insertions(+), 3 deletions(-) create mode 100644 pyrit/backend/services/scenario_run_service.py create mode 100644 tests/unit/backend/test_scenario_run_routes.py create mode 100644 tests/unit/backend/test_scenario_run_service.py diff --git a/pyrit/backend/models/scenarios.py b/pyrit/backend/models/scenarios.py index a47e431805..5f16aee428 100644 --- a/pyrit/backend/models/scenarios.py +++ b/pyrit/backend/models/scenarios.py @@ -5,9 +5,11 @@ Scenario API response models. Scenarios are multi-attack security testing campaigns. These models represent -the metadata about available scenarios (listing), not scenario execution results. +the metadata about available scenarios (listing) and scenario execution (runs). """ +from datetime import datetime +from enum import StrEnum from typing import Optional from pydantic import BaseModel, Field @@ -35,3 +37,112 @@ class ScenarioListResponse(BaseModel): items: list[ScenarioSummary] = Field(..., description="List of scenario summaries") pagination: PaginationInfo = Field(..., description="Pagination metadata") + + +# ============================================================================ +# Scenario Run Models +# ============================================================================ + + +class ScenarioRunStatus(StrEnum): + """Status of a scenario run.""" + + PENDING = "pending" + INITIALIZING = "initializing" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +class RunScenarioRequest(BaseModel): + """Request body for starting a scenario run.""" + + scenario_name: str = Field(..., description="Registry key of the scenario to run") + target_name: str = Field(..., description="Name of a registered target from the TargetRegistry") + initializers: list[str] | None = Field( + None, description="Initializer names to run before scenario (e.g., ['target', 'load_default_datasets'])" + ) + strategies: list[str] | None = Field(None, description="Strategy names to use (uses scenario default if omitted)") + dataset_names: list[str] | None = Field( + None, description="Dataset names to use (uses scenario default if omitted)" + ) + max_dataset_size: int | None = Field(None, ge=1, description="Maximum items per dataset") + max_concurrency: int = Field(10, ge=1, le=100, description="Maximum concurrent operations") + max_retries: int = Field(0, ge=0, le=10, description="Maximum retry attempts on failure") + memory_labels: dict[str, str] | None = Field(None, description="Labels to attach to memory entries") + + +class ScenarioRunResult(BaseModel): + """Summary of a completed scenario run's results.""" + + scenario_result_id: str = Field(..., description="UUID of the ScenarioResult in memory") + run_state: str = Field(..., description="Final scenario run state (COMPLETED, FAILED)") + strategies_used: list[str] = Field(..., description="Strategy names that were executed") + total_attacks: int = Field(..., ge=0, description="Total number of atomic attacks") + completed_attacks: int = Field(..., 
ge=0, description="Number of attacks that completed") + number_tries: int = Field(..., ge=0, description="Number of execution attempts") + completion_time: datetime | None = Field(None, description="When the scenario finished") + + +class ScenarioRunResponse(BaseModel): + """Response for a scenario run (status + optional result).""" + + run_id: str = Field(..., description="Unique identifier for this run") + scenario_name: str = Field(..., description="Registry key of the scenario being run") + status: ScenarioRunStatus = Field(..., description="Current run status") + created_at: datetime = Field(..., description="When the run was created") + updated_at: datetime = Field(..., description="When the run status last changed") + error: str | None = Field(None, description="Error message if status is FAILED") + result: ScenarioRunResult | None = Field(None, description="Result details if status is COMPLETED") + + +class ScenarioRunListResponse(BaseModel): + """Response for listing scenario runs.""" + + items: list[ScenarioRunResponse] = Field(..., description="List of scenario runs") + + +# ============================================================================ +# Scenario Results Detail Models +# ============================================================================ + + +class AttackResultDetail(BaseModel): + """Detailed result of a single attack within a scenario.""" + + attack_result_id: str = Field(..., description="Unique ID of this attack result") + conversation_id: str = Field(..., description="Conversation ID that produced this result") + objective: str = Field(..., description="Natural-language description of the attacker's objective") + outcome: str = Field(..., description="Attack outcome: success, failure, or undetermined") + outcome_reason: str | None = Field(None, description="Reason for the outcome") + last_response: str | None = Field(None, description="Model response from the final turn") + score_value: str | None = Field(None, description="Score value from the objective scorer") + executed_turns: int = Field(0, ge=0, description="Number of turns executed") + execution_time_ms: int = Field(0, ge=0, description="Execution time in milliseconds") + timestamp: datetime | None = Field(None, description="When the result was created") + + +class AtomicAttackResults(BaseModel): + """Results grouped by atomic attack name.""" + + atomic_attack_name: str = Field(..., description="Name of the atomic attack (strategy)") + display_group: str | None = Field(None, description="Display group label for UI grouping") + results: list[AttackResultDetail] = Field(..., description="Individual attack results") + success_count: int = Field(0, ge=0, description="Number of successful attacks") + failure_count: int = Field(0, ge=0, description="Number of failed attacks") + total_count: int = Field(0, ge=0, description="Total number of attack results") + + +class ScenarioResultDetailResponse(BaseModel): + """Full detailed results of a scenario run.""" + + scenario_result_id: str = Field(..., description="UUID of the ScenarioResult") + scenario_name: str = Field(..., description="Name of the scenario") + scenario_version: int = Field(..., description="Version of the scenario") + run_state: str = Field(..., description="Final run state (COMPLETED, FAILED, etc.)") + objective_achieved_rate: int = Field(..., ge=0, le=100, description="Success rate as percentage (0-100)") + number_tries: int = Field(..., ge=0, description="Number of execution attempts") + completion_time: datetime | None = 
Field(None, description="When the scenario finished") + labels: dict[str, str] = Field(default_factory=dict, description="Labels attached to this run") + attacks: list[AtomicAttackResults] = Field(..., description="Results grouped by atomic attack") diff --git a/pyrit/backend/routes/scenarios.py b/pyrit/backend/routes/scenarios.py index 9cd3e2ef43..010e6ad54b 100644 --- a/pyrit/backend/routes/scenarios.py +++ b/pyrit/backend/routes/scenarios.py @@ -4,7 +4,8 @@ """ Scenario API routes. -Provides endpoints for listing available scenarios and their metadata. +Provides endpoints for listing available scenarios, their metadata, +and managing scenario runs. """ from typing import Optional @@ -12,7 +13,15 @@ from fastapi import APIRouter, HTTPException, Query, status from pyrit.backend.models.common import ProblemDetail -from pyrit.backend.models.scenarios import ScenarioListResponse, ScenarioSummary +from pyrit.backend.models.scenarios import ( + RunScenarioRequest, + ScenarioListResponse, + ScenarioResultDetailResponse, + ScenarioRunListResponse, + ScenarioRunResponse, + ScenarioSummary, +) +from pyrit.backend.services.scenario_run_service import get_scenario_run_service from pyrit.backend.services.scenario_service import get_scenario_service router = APIRouter(prefix="/scenarios", tags=["scenarios"]) @@ -39,6 +48,152 @@ async def list_scenarios( return await service.list_scenarios_async(limit=limit, cursor=cursor) +# ============================================================================ +# Scenario Runs +# ============================================================================ + + +@router.post( + "/runs", + response_model=ScenarioRunResponse, + status_code=status.HTTP_202_ACCEPTED, + responses={ + 400: {"model": ProblemDetail, "description": "Invalid request (bad scenario/target/strategy)"}, + }, +) +async def start_scenario_run(request: RunScenarioRequest) -> ScenarioRunResponse: + """ + Start a new scenario run as a background task. + + Returns immediately with a run_id that can be polled for status. + + Args: + request: Scenario run configuration. + + Returns: + ScenarioRunResponse: Run metadata with PENDING status. + """ + service = get_scenario_run_service() + try: + return await service.start_run_async(request=request) + except ValueError as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from None + + +@router.get( + "/runs", + response_model=ScenarioRunListResponse, +) +async def list_scenario_runs() -> ScenarioRunListResponse: + """ + List all tracked scenario runs. + + Returns: + ScenarioRunListResponse: All runs, most recent first. + """ + service = get_scenario_run_service() + return service.list_runs() + + +@router.get( + "/runs/{run_id}", + response_model=ScenarioRunResponse, + responses={ + 404: {"model": ProblemDetail, "description": "Run not found"}, + }, +) +async def get_scenario_run(run_id: str) -> ScenarioRunResponse: + """ + Get the current status and result of a scenario run. + + Args: + run_id: The unique run identifier returned by POST /runs. + + Returns: + ScenarioRunResponse: Current run status (and result if completed). 
+ """ + service = get_scenario_run_service() + run = service.get_run(run_id=run_id) + if run is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Scenario run '{run_id}' not found", + ) + return run + + +@router.delete( + "/runs/{run_id}", + response_model=ScenarioRunResponse, + responses={ + 404: {"model": ProblemDetail, "description": "Run not found"}, + 409: {"model": ProblemDetail, "description": "Run already in terminal state"}, + }, +) +async def cancel_scenario_run(run_id: str) -> ScenarioRunResponse: + """ + Cancel a running scenario. + + Args: + run_id: The unique run identifier to cancel. + + Returns: + ScenarioRunResponse: Updated run with CANCELLED status. + """ + service = get_scenario_run_service() + try: + result = await service.cancel_run_async(run_id=run_id) + except ValueError as e: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from None + + if result is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Scenario run '{run_id}' not found", + ) + return result + + +@router.get( + "/runs/{run_id}/results", + response_model=ScenarioResultDetailResponse, + responses={ + 404: {"model": ProblemDetail, "description": "Run not found"}, + 409: {"model": ProblemDetail, "description": "Run not yet completed"}, + }, +) +async def get_scenario_run_results(run_id: str) -> ScenarioResultDetailResponse: + """ + Get detailed results for a completed scenario run. + + Returns per-attack outcomes including objectives, responses, scores, + and success/failure counts. + + Args: + run_id: The unique run identifier. + + Returns: + ScenarioResultDetailResponse: Full attack-level results. + """ + service = get_scenario_run_service() + try: + result = service.get_run_results(run_id=run_id) + except ValueError as e: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from None + + if result is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Scenario run '{run_id}' not found", + ) + return result + + +# ============================================================================ +# Scenario Detail (catch-all path — must be last) +# ============================================================================ + + @router.get( "/{scenario_name:path}", response_model=ScenarioSummary, diff --git a/pyrit/backend/services/__init__.py b/pyrit/backend/services/__init__.py index 29807150ae..646afb0bf0 100644 --- a/pyrit/backend/services/__init__.py +++ b/pyrit/backend/services/__init__.py @@ -19,6 +19,10 @@ ScenarioService, get_scenario_service, ) +from pyrit.backend.services.scenario_run_service import ( + ScenarioRunService, + get_scenario_run_service, +) from pyrit.backend.services.target_service import ( TargetService, get_target_service, @@ -31,6 +35,8 @@ "get_converter_service", "ScenarioService", "get_scenario_service", + "ScenarioRunService", + "get_scenario_run_service", "TargetService", "get_target_service", ] diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py new file mode 100644 index 0000000000..8ec96c5fc2 --- /dev/null +++ b/pyrit/backend/services/scenario_run_service.py @@ -0,0 +1,422 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Scenario run service for executing scenarios as background tasks. + +Manages the lifecycle of scenario runs: starting, tracking status, +retrieving results, and cancellation. 
+""" + +import asyncio +import logging +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timezone +from functools import lru_cache +from typing import Any + +from pyrit.backend.models.scenarios import ( + RunScenarioRequest, + ScenarioRunListResponse, + ScenarioRunResponse, + ScenarioRunResult, + ScenarioRunStatus, +) + +logger = logging.getLogger(__name__) + +MAX_CONCURRENT_RUNS = 3 +MAX_COMPLETED_RUNS = 50 + + +@dataclass +class _RunInfo: + """Internal tracking state for a scenario run.""" + + run_id: str + request: RunScenarioRequest + status: ScenarioRunStatus = ScenarioRunStatus.PENDING + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + task: asyncio.Task[None] | None = None + error: str | None = None + result: ScenarioRunResult | None = None + + +class ScenarioRunService: + """ + Service for managing scenario run lifecycle. + + Runs are tracked in-memory and executed as background asyncio tasks. + """ + + def __init__(self) -> None: + """Initialize the scenario run service.""" + self._runs: dict[str, _RunInfo] = {} + + async def start_run_async(self, *, request: RunScenarioRequest) -> ScenarioRunResponse: + """ + Start a new scenario run as a background task. + + Validates inputs synchronously, then spawns an asyncio task for execution. + + Args: + request: The run request with scenario name, target, and options. + + Returns: + ScenarioRunResponse with run_id and PENDING status. + + Raises: + ValueError: If scenario or target cannot be found, or concurrent limit exceeded. + """ + # Check concurrent run limit + active_count = sum( + 1 + for r in self._runs.values() + if r.status in (ScenarioRunStatus.PENDING, ScenarioRunStatus.INITIALIZING, ScenarioRunStatus.RUNNING) + ) + if active_count >= MAX_CONCURRENT_RUNS: + raise ValueError( + f"Maximum concurrent runs ({MAX_CONCURRENT_RUNS}) reached. " + "Wait for an existing run to complete or cancel one." + ) + + # Validate scenario exists + from pyrit.registry import ScenarioRegistry + + scenario_registry = ScenarioRegistry.get_registry_singleton() + try: + scenario_registry.get_class(request.scenario_name) + except KeyError as e: + raise ValueError(str(e)) from None + + # Create run info + run_id = str(uuid.uuid4()) + info = _RunInfo(run_id=run_id, request=request) + self._runs[run_id] = info + + # Evict old completed runs if over limit + self._evict_completed_runs() + + # Spawn background task + task = asyncio.create_task(self._execute_run_async(run_id=run_id)) + info.task = task + + return self._to_response(info) + + def get_run(self, *, run_id: str) -> ScenarioRunResponse | None: + """ + Get the current status of a scenario run. + + Args: + run_id: The unique run identifier. + + Returns: + ScenarioRunResponse if found, None otherwise. + """ + info = self._runs.get(run_id) + if info is None: + return None + return self._to_response(info) + + def list_runs(self) -> ScenarioRunListResponse: + """ + List all tracked scenario runs (most recent first). + + Returns: + ScenarioRunListResponse with all runs. + """ + items = [self._to_response(info) for info in reversed(self._runs.values())] + return ScenarioRunListResponse(items=items) + + async def cancel_run_async(self, *, run_id: str) -> ScenarioRunResponse | None: + """ + Cancel a running scenario. + + Args: + run_id: The unique run identifier. + + Returns: + Updated ScenarioRunResponse if found, None if run_id not found. 
+ + Raises: + ValueError: If the run is already in a terminal state. + """ + info = self._runs.get(run_id) + if info is None: + return None + + terminal_states = (ScenarioRunStatus.COMPLETED, ScenarioRunStatus.FAILED, ScenarioRunStatus.CANCELLED) + if info.status in terminal_states: + raise ValueError(f"Cannot cancel run in '{info.status}' state.") + + # Cancel the asyncio task + if info.task is not None and not info.task.done(): + info.task.cancel() + + info.status = ScenarioRunStatus.CANCELLED + info.updated_at = datetime.now(timezone.utc) + return self._to_response(info) + + async def _execute_run_async(self, *, run_id: str) -> None: + """ + Execute a scenario run (background task entry point). + + Mirrors the flow in pyrit.cli.frontend_core.run_scenario_async. + + Args: + run_id: The run to execute. + """ + info = self._runs[run_id] + request = info.request + + try: + # --- Phase 1: Initialize --- + info.status = ScenarioRunStatus.INITIALIZING + info.updated_at = datetime.now(timezone.utc) + + from pyrit.registry import InitializerRegistry, ScenarioRegistry, TargetRegistry + from pyrit.scenario.core import DatasetConfiguration + + # Run initializers if requested + if request.initializers: + initializer_registry = InitializerRegistry.get_registry_singleton() + for initializer_name in request.initializers: + try: + initializer_class = initializer_registry.get_class(initializer_name) + except KeyError as e: + raise ValueError(f"Initializer not found: {e}") from None + instance = initializer_class() + await instance.initialize_async() + + # Resolve target + target_registry = TargetRegistry.get_registry_singleton() + objective_target = target_registry.get_instance_by_name(request.target_name) + if objective_target is None: + available_names = target_registry.get_names() + if not available_names: + raise ValueError( + f"Target '{request.target_name}' not found. The target registry is empty. " + "Make sure to include an initializer that registers targets " + "(e.g., initializers: ['target'])." + ) + raise ValueError( + f"Target '{request.target_name}' not found in registry. " + f"Available targets: {', '.join(available_names)}" + ) + + # Resolve scenario class + scenario_registry = ScenarioRegistry.get_registry_singleton() + scenario_class = scenario_registry.get_class(request.scenario_name) + + # --- Phase 2: Run --- + info.status = ScenarioRunStatus.RUNNING + info.updated_at = datetime.now(timezone.utc) + + # Build init kwargs + init_kwargs: dict[str, Any] = { + "objective_target": objective_target, + "max_concurrency": request.max_concurrency, + "max_retries": request.max_retries, + } + + if request.memory_labels: + init_kwargs["memory_labels"] = request.memory_labels + + # Resolve strategies + if request.strategies: + strategy_class = scenario_class.get_strategy_class() + strategy_enums = [] + for name in request.strategies: + try: + strategy_enums.append(strategy_class(name)) + except ValueError: + available_strategies = [s.value for s in strategy_class] + raise ValueError( + f"Strategy '{name}' not found for scenario '{request.scenario_name}'. 
" + f"Available: {', '.join(available_strategies)}" + ) from None + init_kwargs["scenario_strategies"] = strategy_enums + + # Build dataset config + if request.dataset_names: + init_kwargs["dataset_config"] = DatasetConfiguration( + dataset_names=request.dataset_names, + max_dataset_size=request.max_dataset_size, + ) + elif request.max_dataset_size is not None: + default_config = scenario_class.default_dataset_config() + default_config.max_dataset_size = request.max_dataset_size + init_kwargs["dataset_config"] = default_config + + # Instantiate and execute + scenario = scenario_class() # type: ignore[call-arg] + await scenario.initialize_async(**init_kwargs) + scenario_result = await scenario.run_async() + + # --- Phase 3: Store result --- + info.status = ScenarioRunStatus.COMPLETED + info.updated_at = datetime.now(timezone.utc) + info.result = ScenarioRunResult( + scenario_result_id=str(scenario_result.id), + run_state=scenario_result.scenario_run_state, + strategies_used=scenario_result.get_strategies_used(), + total_attacks=len(scenario_result.attack_results), + completed_attacks=len(scenario_result.attack_results), + number_tries=scenario_result.number_tries, + completion_time=scenario_result.completion_time, + ) + + except asyncio.CancelledError: + info.status = ScenarioRunStatus.CANCELLED + info.updated_at = datetime.now(timezone.utc) + logger.info(f"Scenario run {run_id} was cancelled.") + + except Exception as e: + info.status = ScenarioRunStatus.FAILED + info.updated_at = datetime.now(timezone.utc) + info.error = str(e) + logger.exception(f"Scenario run {run_id} failed: {e}") + + def _evict_completed_runs(self) -> None: + """Remove oldest completed runs if over the retention limit.""" + terminal_states = (ScenarioRunStatus.COMPLETED, ScenarioRunStatus.FAILED, ScenarioRunStatus.CANCELLED) + completed = [r for r in self._runs.values() if r.status in terminal_states] + if len(completed) > MAX_COMPLETED_RUNS: + # Sort by creation time, remove oldest + completed.sort(key=lambda r: r.created_at) + for run_info in completed[: len(completed) - MAX_COMPLETED_RUNS]: + del self._runs[run_info.run_id] + + def get_run_results(self, *, run_id: str) -> "ScenarioResultDetailResponse | None": + """ + Get detailed results for a completed scenario run. + + Retrieves the full ScenarioResult from CentralMemory and maps it + to a detailed response model with per-attack outcomes. + + Args: + run_id: The unique run identifier. + + Returns: + ScenarioResultDetailResponse if the run is completed and results exist, None if run not found. + + Raises: + ValueError: If the run is not in a completed state or results not found in memory. + """ + from pyrit.backend.models.scenarios import ( + AtomicAttackResults, + AttackResultDetail, + ScenarioResultDetailResponse, + ) + from pyrit.memory import CentralMemory + from pyrit.models import AttackOutcome + + info = self._runs.get(run_id) + if info is None: + return None + + if info.status != ScenarioRunStatus.COMPLETED or info.result is None: + raise ValueError( + f"Results are only available for completed runs. Current status: '{info.status}'." + ) + + # Retrieve from CentralMemory + memory = CentralMemory.get_memory_instance() + results = memory.get_scenario_results(scenario_result_ids=[info.result.scenario_result_id]) + if not results: + raise ValueError( + f"Scenario result '{info.result.scenario_result_id}' not found in memory." 
+            )
+
+        scenario_result = results[0]
+
+        # Build per-attack detail
+        attacks: list[AtomicAttackResults] = []
+        for attack_name, attack_results in scenario_result.attack_results.items():
+            details: list[AttackResultDetail] = []
+            success_count = 0
+            failure_count = 0
+
+            for ar in attack_results:
+                score_value = None
+                if ar.last_score is not None:
+                    score_value = ar.last_score.get_value()
+
+                last_response_text = None
+                if ar.last_response is not None:
+                    last_response_text = (
+                        ar.last_response.value if hasattr(ar.last_response, "value") else str(ar.last_response)
+                    )
+
+                details.append(
+                    AttackResultDetail(
+                        attack_result_id=ar.attack_result_id,
+                        conversation_id=ar.conversation_id,
+                        objective=ar.objective,
+                        outcome=ar.outcome.value,
+                        outcome_reason=ar.outcome_reason,
+                        last_response=last_response_text,
+                        score_value=score_value,
+                        executed_turns=ar.executed_turns,
+                        execution_time_ms=ar.execution_time_ms,
+                        timestamp=ar.timestamp,
+                    )
+                )
+
+                if ar.outcome == AttackOutcome.SUCCESS:
+                    success_count += 1
+                elif ar.outcome == AttackOutcome.FAILURE:
+                    failure_count += 1
+
+            # Find display group for this attack
+            display_group = None
+            if hasattr(scenario_result, "_display_group_map") and scenario_result._display_group_map:
+                display_group = scenario_result._display_group_map.get(attack_name)
+
+            attacks.append(
+                AtomicAttackResults(
+                    atomic_attack_name=attack_name,
+                    display_group=display_group,
+                    results=details,
+                    success_count=success_count,
+                    failure_count=failure_count,
+                    total_count=len(details),
+                )
+            )
+
+        return ScenarioResultDetailResponse(
+            scenario_result_id=str(scenario_result.id),
+            scenario_name=scenario_result.scenario_identifier.name,
+            scenario_version=scenario_result.scenario_identifier.version,
+            run_state=scenario_result.scenario_run_state,
+            objective_achieved_rate=scenario_result.objective_achieved_rate(),
+            number_tries=scenario_result.number_tries,
+            completion_time=scenario_result.completion_time,
+            labels=scenario_result.labels,
+            attacks=attacks,
+        )
+
+    @staticmethod
+    def _to_response(info: _RunInfo) -> ScenarioRunResponse:
+        """Convert internal run info to API response model."""
+        return ScenarioRunResponse(
+            run_id=info.run_id,
+            scenario_name=info.request.scenario_name,
+            status=info.status,
+            created_at=info.created_at,
+            updated_at=info.updated_at,
+            error=info.error,
+            result=info.result,
+        )
+
+
+@lru_cache(maxsize=1)
+def get_scenario_run_service() -> ScenarioRunService:
+    """
+    Get the global scenario run service instance.
+
+    Returns:
+        The singleton ScenarioRunService instance.
+    """
+    return ScenarioRunService()
diff --git a/tests/unit/backend/test_scenario_run_routes.py b/tests/unit/backend/test_scenario_run_routes.py
new file mode 100644
index 0000000000..5498118594
--- /dev/null
+++ b/tests/unit/backend/test_scenario_run_routes.py
@@ -0,0 +1,315 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+Tests for scenario run API routes.
+""" + +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import status +from fastapi.testclient import TestClient + +from pyrit.backend.main import app +from pyrit.backend.models.scenarios import ( + ScenarioRunListResponse, + ScenarioRunResponse, + ScenarioRunStatus, +) +from pyrit.backend.services.scenario_run_service import get_scenario_run_service + + +@pytest.fixture +def client() -> TestClient: + """Create a test client for the FastAPI app.""" + return TestClient(app) + + +@pytest.fixture(autouse=True) +def clear_service_cache(): + """Clear the service singleton cache between tests.""" + get_scenario_run_service.cache_clear() + yield + get_scenario_run_service.cache_clear() + + +def _mock_run_response( + *, + run_id: str = "test-run-id", + scenario_name: str = "foundry.red_team_agent", + run_status: ScenarioRunStatus = ScenarioRunStatus.PENDING, +) -> ScenarioRunResponse: + """Create a mock ScenarioRunResponse.""" + return ScenarioRunResponse( + run_id=run_id, + scenario_name=scenario_name, + status=run_status, + created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + updated_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + error=None, + result=None, + ) + + +class TestStartScenarioRunRoute: + """Tests for POST /api/scenarios/runs.""" + + def test_start_run_returns_202(self, client: TestClient) -> None: + """Test that a valid request returns 202 Accepted.""" + mock_response = _mock_run_response() + + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.start_run_async = AsyncMock(return_value=mock_response) + mock_get.return_value = mock_service + + response = client.post( + "/api/scenarios/runs", + json={"scenario_name": "foundry.red_team_agent", "target_name": "my_target"}, + ) + + assert response.status_code == status.HTTP_202_ACCEPTED + data = response.json() + assert data["run_id"] == "test-run-id" + assert data["status"] == "pending" + + def test_start_run_invalid_scenario_returns_400(self, client: TestClient) -> None: + """Test that an invalid scenario returns 400.""" + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.start_run_async = AsyncMock( + side_effect=ValueError("'bad.scenario' not found in registry.") + ) + mock_get.return_value = mock_service + + response = client.post( + "/api/scenarios/runs", + json={"scenario_name": "bad.scenario", "target_name": "my_target"}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "not found" in response.json()["detail"] + + def test_start_run_missing_required_fields_returns_422(self, client: TestClient) -> None: + """Test that missing required fields returns 422.""" + response = client.post("/api/scenarios/runs", json={}) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_start_run_with_all_options(self, client: TestClient) -> None: + """Test that all optional fields are accepted.""" + mock_response = _mock_run_response() + + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.start_run_async = AsyncMock(return_value=mock_response) + mock_get.return_value = mock_service + + response = client.post( + "/api/scenarios/runs", + json={ + "scenario_name": "foundry.red_team_agent", + "target_name": "my_target", + "initializers": ["target", "load_default_datasets"], + "strategies": 
["base64", "rot13"], + "dataset_names": ["harmful_content"], + "max_dataset_size": 50, + "max_concurrency": 5, + "max_retries": 2, + "memory_labels": {"team": "red"}, + }, + ) + + assert response.status_code == status.HTTP_202_ACCEPTED + + +class TestListScenarioRunsRoute: + """Tests for GET /api/scenarios/runs.""" + + def test_list_runs_returns_200(self, client: TestClient) -> None: + """Test that list runs returns 200 with empty list.""" + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.list_runs.return_value = ScenarioRunListResponse(items=[]) + mock_get.return_value = mock_service + + response = client.get("/api/scenarios/runs") + + assert response.status_code == status.HTTP_200_OK + assert response.json()["items"] == [] + + def test_list_runs_returns_multiple_runs(self, client: TestClient) -> None: + """Test that list runs returns all tracked runs.""" + runs = [ + _mock_run_response(run_id="run-1"), + _mock_run_response(run_id="run-2", run_status=ScenarioRunStatus.RUNNING), + ] + + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.list_runs.return_value = ScenarioRunListResponse(items=runs) + mock_get.return_value = mock_service + + response = client.get("/api/scenarios/runs") + + assert response.status_code == status.HTTP_200_OK + assert len(response.json()["items"]) == 2 + + +class TestGetScenarioRunRoute: + """Tests for GET /api/scenarios/runs/{run_id}.""" + + def test_get_run_returns_200(self, client: TestClient) -> None: + """Test that getting an existing run returns 200.""" + mock_response = _mock_run_response(run_status=ScenarioRunStatus.RUNNING) + + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.get_run.return_value = mock_response + mock_get.return_value = mock_service + + response = client.get("/api/scenarios/runs/test-run-id") + + assert response.status_code == status.HTTP_200_OK + assert response.json()["status"] == "running" + + def test_get_run_not_found_returns_404(self, client: TestClient) -> None: + """Test that getting a non-existent run returns 404.""" + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.get_run.return_value = None + mock_get.return_value = mock_service + + response = client.get("/api/scenarios/runs/nonexistent") + + assert response.status_code == status.HTTP_404_NOT_FOUND + + +class TestCancelScenarioRunRoute: + """Tests for DELETE /api/scenarios/runs/{run_id}.""" + + def test_cancel_run_returns_200(self, client: TestClient) -> None: + """Test that cancelling a running scenario returns 200.""" + mock_response = _mock_run_response(run_status=ScenarioRunStatus.CANCELLED) + + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.cancel_run_async = AsyncMock(return_value=mock_response) + mock_get.return_value = mock_service + + response = client.delete("/api/scenarios/runs/test-run-id") + + assert response.status_code == status.HTTP_200_OK + assert response.json()["status"] == "cancelled" + + def test_cancel_run_not_found_returns_404(self, client: TestClient) -> None: + """Test that cancelling a non-existent run returns 404.""" + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.cancel_run_async = 
AsyncMock(return_value=None) + mock_get.return_value = mock_service + + response = client.delete("/api/scenarios/runs/nonexistent") + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_cancel_completed_run_returns_409(self, client: TestClient) -> None: + """Test that cancelling a completed run returns 409 Conflict.""" + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.cancel_run_async = AsyncMock( + side_effect=ValueError("Cannot cancel run in 'completed' state.") + ) + mock_get.return_value = mock_service + + response = client.delete("/api/scenarios/runs/test-run-id") + + assert response.status_code == status.HTTP_409_CONFLICT + assert "Cannot cancel" in response.json()["detail"] + + +class TestGetScenarioRunResultsRoute: + """Tests for GET /api/scenarios/runs/{run_id}/results.""" + + def test_get_results_returns_200(self, client: TestClient) -> None: + """Test that getting results of a completed run returns 200.""" + from pyrit.backend.models.scenarios import ( + AtomicAttackResults, + AttackResultDetail, + ScenarioResultDetailResponse, + ) + + mock_result = ScenarioResultDetailResponse( + scenario_result_id="result-uuid", + scenario_name="foundry.red_team_agent", + scenario_version=1, + run_state="COMPLETED", + objective_achieved_rate=50, + number_tries=1, + completion_time=datetime(2025, 1, 1, tzinfo=timezone.utc), + labels={"team": "red"}, + attacks=[ + AtomicAttackResults( + atomic_attack_name="base64_attack", + display_group="encoding", + results=[ + AttackResultDetail( + attack_result_id="ar-1", + conversation_id="conv-1", + objective="Extract sensitive info", + outcome="success", + outcome_reason="Model revealed data", + last_response="Here is the data...", + score_value="1.0", + executed_turns=3, + execution_time_ms=1500, + timestamp=datetime(2025, 1, 1, tzinfo=timezone.utc), + ), + ], + success_count=1, + failure_count=0, + total_count=1, + ), + ], + ) + + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.get_run_results.return_value = mock_result + mock_get.return_value = mock_service + + response = client.get("/api/scenarios/runs/test-run-id/results") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["scenario_result_id"] == "result-uuid" + assert data["objective_achieved_rate"] == 50 + assert len(data["attacks"]) == 1 + assert data["attacks"][0]["atomic_attack_name"] == "base64_attack" + assert data["attacks"][0]["results"][0]["outcome"] == "success" + + def test_get_results_not_found_returns_404(self, client: TestClient) -> None: + """Test that getting results of a non-existent run returns 404.""" + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.get_run_results.return_value = None + mock_get.return_value = mock_service + + response = client.get("/api/scenarios/runs/nonexistent/results") + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_get_results_not_completed_returns_409(self, client: TestClient) -> None: + """Test that getting results of a non-completed run returns 409.""" + with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: + mock_service = MagicMock() + mock_service.get_run_results.side_effect = ValueError( + "Results are only available for completed runs. Current status: 'running'." 
+            )
+            mock_get.return_value = mock_service
+
+            response = client.get("/api/scenarios/runs/test-run-id/results")
+
+        assert response.status_code == status.HTTP_409_CONFLICT
+        assert "only available for completed runs" in response.json()["detail"]
diff --git a/tests/unit/backend/test_scenario_run_service.py b/tests/unit/backend/test_scenario_run_service.py
new file mode 100644
index 0000000000..0cb449ea86
--- /dev/null
+++ b/tests/unit/backend/test_scenario_run_service.py
@@ -0,0 +1,424 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+Tests for ScenarioRunService.
+"""
+
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from pyrit.backend.models.scenarios import (
+    RunScenarioRequest,
+    ScenarioRunStatus,
+)
+from pyrit.backend.services.scenario_run_service import (
+    MAX_CONCURRENT_RUNS,
+    ScenarioRunService,
+    get_scenario_run_service,
+)
+
+# The service uses deferred imports inside methods, so we patch at the source module.
+_REGISTRY_PATCH_BASE = "pyrit.registry"
+
+
+@pytest.fixture(autouse=True)
+def clear_service_cache():
+    """Clear the singleton cache between tests."""
+    get_scenario_run_service.cache_clear()
+    yield
+    get_scenario_run_service.cache_clear()
+
+
+def _make_request(
+    *,
+    scenario_name: str = "foundry.red_team_agent",
+    target_name: str = "my_target",
+    initializers: list[str] | None = None,
+    strategies: list[str] | None = None,
+) -> RunScenarioRequest:
+    """Create a RunScenarioRequest for testing."""
+    return RunScenarioRequest(
+        scenario_name=scenario_name,
+        target_name=target_name,
+        initializers=initializers,
+        strategies=strategies,
+    )
+
+
+@pytest.fixture
+def mock_scenario_registry():
+    """Patch ScenarioRegistry.get_registry_singleton to return a mock."""
+    mock_registry = MagicMock()
+    mock_registry.get_class.return_value = MagicMock()
+    with patch(f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton", return_value=mock_registry):
+        yield mock_registry
+
+
+@pytest.fixture
+def mock_target_registry():
+    """Patch TargetRegistry.get_registry_singleton to return a mock."""
+    mock_registry = MagicMock()
+    mock_registry.get_instance_by_name.return_value = MagicMock()
+    mock_registry.get_names.return_value = ["my_target"]
+    with patch(f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton", return_value=mock_registry):
+        yield mock_registry
+
+
+@pytest.fixture
+def mock_initializer_registry():
+    """Patch InitializerRegistry.get_registry_singleton to return a mock."""
+    mock_instance = MagicMock()
+    mock_instance.initialize_async = AsyncMock()
+    mock_class = MagicMock(return_value=mock_instance)
+
+    mock_registry = MagicMock()
+    mock_registry.get_class.return_value = mock_class
+    with patch(f"{_REGISTRY_PATCH_BASE}.InitializerRegistry.get_registry_singleton", return_value=mock_registry):
+        yield mock_registry, mock_class, mock_instance
+
+
+class TestScenarioRunServiceStartRun:
+    """Tests for ScenarioRunService.start_run_async."""
+
+    async def test_start_run_returns_pending_status(self, mock_scenario_registry) -> None:
+        """Test that starting a run returns PENDING status with a run_id."""
+        service = ScenarioRunService()
+
+        with patch.object(service, "_execute_run_async", new_callable=AsyncMock):
+            response = await service.start_run_async(request=_make_request())
+
+        assert response.run_id is not None
+        assert response.status == ScenarioRunStatus.PENDING
+        assert response.scenario_name == "foundry.red_team_agent"
+        assert response.error is None
+        assert 
response.result is None + + async def test_start_run_invalid_scenario_raises_value_error(self) -> None: + """Test that an invalid scenario name raises ValueError.""" + service = ScenarioRunService() + + mock_registry = MagicMock() + mock_registry.get_class.side_effect = KeyError("'bad.scenario' not found in registry. Available: foo") + with patch( + f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton", return_value=mock_registry + ): + with pytest.raises(ValueError, match="not found in registry"): + await service.start_run_async(request=_make_request(scenario_name="bad.scenario")) + + async def test_start_run_exceeds_concurrent_limit(self, mock_scenario_registry) -> None: + """Test that exceeding concurrent run limit raises ValueError.""" + service = ScenarioRunService() + + with patch.object(service, "_execute_run_async", new_callable=AsyncMock): + # Fill up to the limit + for _ in range(MAX_CONCURRENT_RUNS): + await service.start_run_async(request=_make_request()) + + # Next one should fail + with pytest.raises(ValueError, match="Maximum concurrent runs"): + await service.start_run_async(request=_make_request()) + + +class TestScenarioRunServiceGetRun: + """Tests for ScenarioRunService.get_run.""" + + async def test_get_run_returns_none_for_unknown_id(self) -> None: + """Test that get_run returns None for non-existent run_id.""" + service = ScenarioRunService() + result = service.get_run(run_id="nonexistent-id") + assert result is None + + async def test_get_run_returns_existing_run(self, mock_scenario_registry) -> None: + """Test that get_run returns a started run.""" + service = ScenarioRunService() + + with patch.object(service, "_execute_run_async", new_callable=AsyncMock): + response = await service.start_run_async(request=_make_request()) + + fetched = service.get_run(run_id=response.run_id) + assert fetched is not None + assert fetched.run_id == response.run_id + assert fetched.scenario_name == "foundry.red_team_agent" + + +class TestScenarioRunServiceListRuns: + """Tests for ScenarioRunService.list_runs.""" + + async def test_list_runs_empty(self) -> None: + """Test that list_runs returns empty list initially.""" + service = ScenarioRunService() + result = service.list_runs() + assert result.items == [] + + async def test_list_runs_returns_all_runs(self, mock_scenario_registry) -> None: + """Test that list_runs returns all tracked runs.""" + service = ScenarioRunService() + + with patch.object(service, "_execute_run_async", new_callable=AsyncMock): + await service.start_run_async(request=_make_request()) + await service.start_run_async(request=_make_request()) + + result = service.list_runs() + assert len(result.items) == 2 + + +class TestScenarioRunServiceCancelRun: + """Tests for ScenarioRunService.cancel_run_async.""" + + async def test_cancel_run_returns_none_for_unknown_id(self) -> None: + """Test that cancel returns None for non-existent run_id.""" + service = ScenarioRunService() + result = await service.cancel_run_async(run_id="nonexistent-id") + assert result is None + + async def test_cancel_run_sets_cancelled_status(self, mock_scenario_registry) -> None: + """Test that cancelling a running scenario sets CANCELLED status.""" + service = ScenarioRunService() + + with patch.object(service, "_execute_run_async", new_callable=AsyncMock): + response = await service.start_run_async(request=_make_request()) + + result = await service.cancel_run_async(run_id=response.run_id) + assert result is not None + assert result.status == ScenarioRunStatus.CANCELLED + + async 
def test_cancel_completed_run_raises_value_error(self, mock_scenario_registry) -> None: + """Test that cancelling a completed run raises ValueError.""" + service = ScenarioRunService() + + with patch.object(service, "_execute_run_async", new_callable=AsyncMock): + response = await service.start_run_async(request=_make_request()) + + # Manually set to COMPLETED + service._runs[response.run_id].status = ScenarioRunStatus.COMPLETED + + with pytest.raises(ValueError, match="Cannot cancel run"): + await service.cancel_run_async(run_id=response.run_id) + + +class TestScenarioRunServiceExecution: + """Tests for the background execution logic.""" + + async def test_execute_run_completes_successfully(self) -> None: + """Test that a successful execution transitions to COMPLETED.""" + service = ScenarioRunService() + + mock_scenario_result = MagicMock() + mock_scenario_result.id = "result-uuid" + mock_scenario_result.scenario_run_state = "COMPLETED" + mock_scenario_result.get_strategies_used.return_value = ["base64"] + mock_scenario_result.attack_results = {"attack1": []} + mock_scenario_result.number_tries = 1 + mock_scenario_result.completion_time = None + + mock_scenario_instance = MagicMock() + mock_scenario_instance.initialize_async = AsyncMock() + mock_scenario_instance.run_async = AsyncMock(return_value=mock_scenario_result) + + mock_scenario_class = MagicMock(return_value=mock_scenario_instance) + mock_scenario_class.get_strategy_class.return_value = MagicMock() + mock_scenario_class.default_dataset_config.return_value = MagicMock() + + mock_target = MagicMock() + + with ( + patch( + f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton" + ) as mock_sr, + patch( + f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton" + ) as mock_tr, + patch( + f"{_REGISTRY_PATCH_BASE}.InitializerRegistry.get_registry_singleton" + ), + ): + mock_sr.return_value.get_class.return_value = mock_scenario_class + mock_tr.return_value.get_instance_by_name.return_value = mock_target + + response = await service.start_run_async(request=_make_request()) + + # Wait for the background task to complete + task = service._runs[response.run_id].task + assert task is not None + await task + + run = service.get_run(run_id=response.run_id) + assert run is not None + assert run.status == ScenarioRunStatus.COMPLETED + assert run.result is not None + assert run.result.scenario_result_id == "result-uuid" + assert run.result.strategies_used == ["base64"] + + async def test_execute_run_fails_with_error(self) -> None: + """Test that a failed execution transitions to FAILED with error message.""" + service = ScenarioRunService() + + with ( + patch( + f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton" + ) as mock_sr, + patch( + f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton" + ) as mock_tr, + patch( + f"{_REGISTRY_PATCH_BASE}.InitializerRegistry.get_registry_singleton" + ), + ): + mock_sr.return_value.get_class.return_value = MagicMock() + mock_tr.return_value.get_instance_by_name.return_value = None + mock_tr.return_value.get_names.return_value = ["other_target"] + + response = await service.start_run_async(request=_make_request()) + + # Wait for the background task + task = service._runs[response.run_id].task + assert task is not None + await task + + run = service.get_run(run_id=response.run_id) + assert run is not None + assert run.status == ScenarioRunStatus.FAILED + assert run.error is not None + assert "my_target" in run.error + + async def 
test_execute_run_with_initializers(self) -> None: + """Test that initializers are run before scenario execution.""" + service = ScenarioRunService() + + mock_scenario_result = MagicMock() + mock_scenario_result.id = "result-uuid" + mock_scenario_result.scenario_run_state = "COMPLETED" + mock_scenario_result.get_strategies_used.return_value = [] + mock_scenario_result.attack_results = {} + mock_scenario_result.number_tries = 1 + mock_scenario_result.completion_time = None + + mock_scenario_instance = MagicMock() + mock_scenario_instance.initialize_async = AsyncMock() + mock_scenario_instance.run_async = AsyncMock(return_value=mock_scenario_result) + + mock_scenario_class = MagicMock(return_value=mock_scenario_instance) + + mock_initializer_instance = MagicMock() + mock_initializer_instance.initialize_async = AsyncMock() + mock_initializer_class = MagicMock(return_value=mock_initializer_instance) + + mock_target = MagicMock() + + with ( + patch( + f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton" + ) as mock_sr, + patch( + f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton" + ) as mock_tr, + patch( + f"{_REGISTRY_PATCH_BASE}.InitializerRegistry.get_registry_singleton" + ) as mock_ir, + ): + mock_sr.return_value.get_class.return_value = mock_scenario_class + mock_tr.return_value.get_instance_by_name.return_value = mock_target + mock_ir.return_value.get_class.return_value = mock_initializer_class + + response = await service.start_run_async( + request=_make_request(initializers=["target", "load_default_datasets"]) + ) + + task = service._runs[response.run_id].task + assert task is not None + await task + + # Initializer should have been called twice (once per name) + assert mock_initializer_instance.initialize_async.await_count == 2 + + run = service.get_run(run_id=response.run_id) + assert run is not None + assert run.status == ScenarioRunStatus.COMPLETED + + +class TestScenarioRunServiceGetResults: + """Tests for ScenarioRunService.get_run_results.""" + + def test_get_results_returns_none_for_unknown_id(self) -> None: + """Test that get_run_results returns None for non-existent run_id.""" + service = ScenarioRunService() + result = service.get_run_results(run_id="nonexistent-id") + assert result is None + + async def test_get_results_raises_if_not_completed(self, mock_scenario_registry) -> None: + """Test that get_run_results raises ValueError if run is not completed.""" + service = ScenarioRunService() + + with patch.object(service, "_execute_run_async", new_callable=AsyncMock): + response = await service.start_run_async(request=_make_request()) + + # Run is in PENDING state + with pytest.raises(ValueError, match="only available for completed runs"): + service.get_run_results(run_id=response.run_id) + + async def test_get_results_returns_details_for_completed_run(self, mock_scenario_registry) -> None: + """Test that get_run_results returns full details for a completed run.""" + from pyrit.backend.models.scenarios import ScenarioRunResult + from pyrit.models import AttackOutcome + + service = ScenarioRunService() + + with patch.object(service, "_execute_run_async", new_callable=AsyncMock): + response = await service.start_run_async(request=_make_request()) + + # Manually set run to completed with a result + info = service._runs[response.run_id] + info.status = ScenarioRunStatus.COMPLETED + info.result = ScenarioRunResult( + scenario_result_id="sr-123", + run_state="COMPLETED", + strategies_used=["base64"], + total_attacks=1, + completed_attacks=1, + 
number_tries=1, + completion_time=None, + ) + + # Mock CentralMemory and ScenarioResult + mock_attack_result = MagicMock() + mock_attack_result.attack_result_id = "ar-1" + mock_attack_result.conversation_id = "conv-1" + mock_attack_result.objective = "Extract info" + mock_attack_result.outcome = AttackOutcome.SUCCESS + mock_attack_result.outcome_reason = "Model complied" + mock_attack_result.last_response = MagicMock(value="Here is the data") + mock_attack_result.last_score = MagicMock() + mock_attack_result.last_score.get_value.return_value = "1.0" + mock_attack_result.executed_turns = 3 + mock_attack_result.execution_time_ms = 1500 + mock_attack_result.timestamp = None + + mock_scenario_result = MagicMock() + mock_scenario_result.id = "sr-123" + mock_scenario_result.scenario_identifier.name = "foundry.red_team_agent" + mock_scenario_result.scenario_identifier.version = 1 + mock_scenario_result.scenario_run_state = "COMPLETED" + mock_scenario_result.objective_achieved_rate.return_value = 100 + mock_scenario_result.number_tries = 1 + mock_scenario_result.completion_time = None + mock_scenario_result.labels = {} + mock_scenario_result.attack_results = {"base64_attack": [mock_attack_result]} + mock_scenario_result.get_display_groups.return_value = {"base64_attack": [mock_attack_result]} + mock_scenario_result._display_group_map = {} + + mock_memory = MagicMock() + mock_memory.get_scenario_results.return_value = [mock_scenario_result] + + with patch("pyrit.memory.CentralMemory.get_memory_instance", return_value=mock_memory): + detail = service.get_run_results(run_id=response.run_id) + + assert detail is not None + assert detail.scenario_result_id == "sr-123" + assert detail.objective_achieved_rate == 100 + assert len(detail.attacks) == 1 + assert detail.attacks[0].atomic_attack_name == "base64_attack" + assert detail.attacks[0].success_count == 1 + assert detail.attacks[0].results[0].objective == "Extract info" + assert detail.attacks[0].results[0].outcome == "success" From 88d71e69bd10612c84e34e64785bff1861ccbb22 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Wed, 6 May 2026 15:21:34 -0700 Subject: [PATCH 02/10] refactoring --- pyrit/backend/models/scenarios.py | 24 +- pyrit/backend/routes/scenarios.py | 76 ++--- pyrit/backend/services/__init__.py | 8 +- .../backend/services/scenario_run_service.py | 222 +++++++------ .../unit/backend/test_scenario_run_routes.py | 10 +- .../unit/backend/test_scenario_run_service.py | 300 +++++++++++------- tests/unit/backend/test_scenario_service.py | 20 +- 7 files changed, 376 insertions(+), 284 deletions(-) diff --git a/pyrit/backend/models/scenarios.py b/pyrit/backend/models/scenarios.py index 5f16aee428..494972ddfb 100644 --- a/pyrit/backend/models/scenarios.py +++ b/pyrit/backend/models/scenarios.py @@ -10,7 +10,7 @@ from datetime import datetime from enum import StrEnum -from typing import Optional +from typing import Any, Optional from pydantic import BaseModel, Field @@ -64,13 +64,27 @@ class RunScenarioRequest(BaseModel): None, description="Initializer names to run before scenario (e.g., ['target', 'load_default_datasets'])" ) strategies: list[str] | None = Field(None, description="Strategy names to use (uses scenario default if omitted)") - dataset_names: list[str] | None = Field( - None, description="Dataset names to use (uses scenario default if omitted)" - ) + dataset_names: list[str] | None = Field(None, description="Dataset names to use (uses scenario default if omitted)") max_dataset_size: int | None = Field(None, ge=1, 
description="Maximum items per dataset") max_concurrency: int = Field(10, ge=1, le=100, description="Maximum concurrent operations") - max_retries: int = Field(0, ge=0, le=10, description="Maximum retry attempts on failure") + max_retries: int = Field(0, ge=0, le=20, description="Maximum retry attempts on failure") memory_labels: dict[str, str] | None = Field(None, description="Labels to attach to memory entries") + scenario_params: dict[str, Any] | None = Field( + None, + description="Custom parameters for the scenario (passed to scenario.set_params_from_args). " + "Keys are parameter names declared by the scenario's supported_parameters().", + ) + initializer_args: dict[str, dict[str, Any]] | None = Field( + None, + description="Per-initializer arguments keyed by initializer name. " + "Each value is a dict of args passed to that initializer's set_params_from_args(). " + "Example: {'target': {'endpoint': 'https://...'}}.", + ) + scenario_result_id: str | None = Field( + None, + description="Optional ID of an existing ScenarioResult to resume. " + "If provided, the scenario will resume from prior progress instead of starting fresh.", + ) class ScenarioRunResult(BaseModel): diff --git a/pyrit/backend/routes/scenarios.py b/pyrit/backend/routes/scenarios.py index 010e6ad54b..6c1fde6c47 100644 --- a/pyrit/backend/routes/scenarios.py +++ b/pyrit/backend/routes/scenarios.py @@ -6,6 +6,10 @@ Provides endpoints for listing available scenarios, their metadata, and managing scenario runs. + +Route structure: + /api/scenarios/catalog — scenario catalog (list + detail) + /api/scenarios/runs — scenario execution lifecycle """ from typing import Optional @@ -27,8 +31,13 @@ router = APIRouter(prefix="/scenarios", tags=["scenarios"]) +# ============================================================================ +# Scenario Catalog +# ============================================================================ + + @router.get( - "", + "/catalog", response_model=ScenarioListResponse, ) async def list_scenarios( @@ -39,7 +48,7 @@ async def list_scenarios( List all available scenarios. Returns scenario metadata including strategies, datasets, and defaults. - Use GET /api/scenarios/{scenario_name} for full details on a specific scenario. + Use GET /api/scenarios/catalog/{scenario_name} for full details on a specific scenario. Returns: ScenarioListResponse: Paginated list of scenario summaries. @@ -48,6 +57,35 @@ async def list_scenarios( return await service.list_scenarios_async(limit=limit, cursor=cursor) +@router.get( + "/catalog/{scenario_name:path}", + response_model=ScenarioSummary, + responses={ + 404: {"model": ProblemDetail, "description": "Scenario not found"}, + }, +) +async def get_scenario(scenario_name: str) -> ScenarioSummary: + """ + Get details for a specific scenario. + + Args: + scenario_name: Registry name of the scenario (e.g., 'foundry.red_team_agent'). + + Returns: + ScenarioSummary: Full scenario metadata. 
+ """ + service = get_scenario_service() + + scenario = await service.get_scenario_async(scenario_name=scenario_name) + if not scenario: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Scenario '{scenario_name}' not found", + ) + + return scenario + + # ============================================================================ # Scenario Runs # ============================================================================ @@ -187,37 +225,3 @@ async def get_scenario_run_results(run_id: str) -> ScenarioResultDetailResponse: detail=f"Scenario run '{run_id}' not found", ) return result - - -# ============================================================================ -# Scenario Detail (catch-all path — must be last) -# ============================================================================ - - -@router.get( - "/{scenario_name:path}", - response_model=ScenarioSummary, - responses={ - 404: {"model": ProblemDetail, "description": "Scenario not found"}, - }, -) -async def get_scenario(scenario_name: str) -> ScenarioSummary: - """ - Get details for a specific scenario. - - Args: - scenario_name: Registry name of the scenario (e.g., 'foundry.red_team_agent'). - - Returns: - ScenarioSummary: Full scenario metadata. - """ - service = get_scenario_service() - - scenario = await service.get_scenario_async(scenario_name=scenario_name) - if not scenario: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Scenario '{scenario_name}' not found", - ) - - return scenario diff --git a/pyrit/backend/services/__init__.py b/pyrit/backend/services/__init__.py index 646afb0bf0..d36f69a830 100644 --- a/pyrit/backend/services/__init__.py +++ b/pyrit/backend/services/__init__.py @@ -15,14 +15,14 @@ ConverterService, get_converter_service, ) -from pyrit.backend.services.scenario_service import ( - ScenarioService, - get_scenario_service, -) from pyrit.backend.services.scenario_run_service import ( ScenarioRunService, get_scenario_run_service, ) +from pyrit.backend.services.scenario_service import ( + ScenarioService, + get_scenario_service, +) from pyrit.backend.services.target_service import ( TargetService, get_target_service, diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py index 8ec96c5fc2..3da690b3a1 100644 --- a/pyrit/backend/services/scenario_run_service.py +++ b/pyrit/backend/services/scenario_run_service.py @@ -42,6 +42,7 @@ class _RunInfo: task: asyncio.Task[None] | None = None error: str | None = None result: ScenarioRunResult | None = None + scenario: Any = None class ScenarioRunService: @@ -59,16 +60,20 @@ async def start_run_async(self, *, request: RunScenarioRequest) -> ScenarioRunRe """ Start a new scenario run as a background task. - Validates inputs synchronously, then spawns an asyncio task for execution. + Performs all validation and initialization eagerly (initializers, target + resolution, strategy validation, scenario.initialize_async) so errors are + returned immediately. On success, spawns a background task that only + executes scenario.run_async. Args: request: The run request with scenario name, target, and options. Returns: - ScenarioRunResponse with run_id and PENDING status. + ScenarioRunResponse with run_id and RUNNING status. Raises: - ValueError: If scenario or target cannot be found, or concurrent limit exceeded. + ValueError: If scenario, target, initializer, or strategy cannot be found, + or concurrent limit exceeded. 
""" # Check concurrent run limit active_count = sum( @@ -82,24 +87,18 @@ async def start_run_async(self, *, request: RunScenarioRequest) -> ScenarioRunRe "Wait for an existing run to complete or cancel one." ) - # Validate scenario exists - from pyrit.registry import ScenarioRegistry - - scenario_registry = ScenarioRegistry.get_registry_singleton() - try: - scenario_registry.get_class(request.scenario_name) - except KeyError as e: - raise ValueError(str(e)) from None + # Perform all initialization eagerly — errors propagate to caller + scenario = await self._initialize_run_async(request=request) - # Create run info + # Create run info in RUNNING state (initialization already complete) run_id = str(uuid.uuid4()) - info = _RunInfo(run_id=run_id, request=request) + info = _RunInfo(run_id=run_id, request=request, status=ScenarioRunStatus.RUNNING, scenario=scenario) self._runs[run_id] = info # Evict old completed runs if over limit self._evict_completed_runs() - # Spawn background task + # Spawn background task (only runs scenario.run_async) task = asyncio.create_task(self._execute_run_async(run_id=run_id)) info.task = task @@ -159,103 +158,122 @@ async def cancel_run_async(self, *, run_id: str) -> ScenarioRunResponse | None: info.updated_at = datetime.now(timezone.utc) return self._to_response(info) - async def _execute_run_async(self, *, run_id: str) -> None: + async def _initialize_run_async(self, *, request: RunScenarioRequest) -> Any: """ - Execute a scenario run (background task entry point). + Validate inputs and initialize the scenario eagerly. - Mirrors the flow in pyrit.cli.frontend_core.run_scenario_async. + Performs all validation (scenario, initializers, target, strategies) and + calls scenario.initialize_async so that any errors are raised immediately + to the caller. Args: - run_id: The run to execute. + request: The run request with scenario name, target, and options. + + Returns: + The fully initialized Scenario instance ready for run_async. + + Raises: + ValueError: If any validation fails (bad scenario name, missing target, + invalid strategy, unknown initializer, etc.). """ - info = self._runs[run_id] - request = info.request + from pyrit.registry import InitializerRegistry, ScenarioRegistry, TargetRegistry + from pyrit.scenario.core import DatasetConfiguration + # Validate scenario exists + scenario_registry = ScenarioRegistry.get_registry_singleton() try: - # --- Phase 1: Initialize --- - info.status = ScenarioRunStatus.INITIALIZING - info.updated_at = datetime.now(timezone.utc) + scenario_class = scenario_registry.get_class(request.scenario_name) + except KeyError as e: + raise ValueError(str(e)) from None - from pyrit.registry import InitializerRegistry, ScenarioRegistry, TargetRegistry - from pyrit.scenario.core import DatasetConfiguration - - # Run initializers if requested - if request.initializers: - initializer_registry = InitializerRegistry.get_registry_singleton() - for initializer_name in request.initializers: - try: - initializer_class = initializer_registry.get_class(initializer_name) - except KeyError as e: - raise ValueError(f"Initializer not found: {e}") from None - instance = initializer_class() - await instance.initialize_async() - - # Resolve target - target_registry = TargetRegistry.get_registry_singleton() - objective_target = target_registry.get_instance_by_name(request.target_name) - if objective_target is None: - available_names = target_registry.get_names() - if not available_names: - raise ValueError( - f"Target '{request.target_name}' not found. 
The target registry is empty. " - "Make sure to include an initializer that registers targets " - "(e.g., initializers: ['target'])." - ) + # Validate and run initializers + if request.initializers: + initializer_registry = InitializerRegistry.get_registry_singleton() + for initializer_name in request.initializers: + try: + initializer_class = initializer_registry.get_class(initializer_name) + except KeyError as e: + raise ValueError(f"Initializer not found: {e}") from None + instance = initializer_class() + if request.initializer_args and initializer_name in request.initializer_args: + instance.set_params_from_args(args=request.initializer_args[initializer_name]) + await instance.initialize_async() + + # Resolve target + target_registry = TargetRegistry.get_registry_singleton() + objective_target = target_registry.get_instance_by_name(request.target_name) + if objective_target is None: + available_names = target_registry.get_names() + if not available_names: raise ValueError( - f"Target '{request.target_name}' not found in registry. " - f"Available targets: {', '.join(available_names)}" + f"Target '{request.target_name}' not found. The target registry is empty. " + "Make sure to include an initializer that registers targets " + "(e.g., initializers: ['target'])." ) + raise ValueError( + f"Target '{request.target_name}' not found in registry. " + f"Available targets: {', '.join(available_names)}" + ) - # Resolve scenario class - scenario_registry = ScenarioRegistry.get_registry_singleton() - scenario_class = scenario_registry.get_class(request.scenario_name) + # Build init kwargs + init_kwargs: dict[str, Any] = { + "objective_target": objective_target, + "max_concurrency": request.max_concurrency, + "max_retries": request.max_retries, + } + + if request.memory_labels: + init_kwargs["memory_labels"] = request.memory_labels + + # Validate and resolve strategies + if request.strategies: + strategy_class = scenario_class.get_strategy_class() + strategy_enums = [] + for name in request.strategies: + try: + strategy_enums.append(strategy_class(name)) + except ValueError: + available_strategies = [s.value for s in strategy_class] + raise ValueError( + f"Strategy '{name}' not found for scenario '{request.scenario_name}'. " + f"Available: {', '.join(available_strategies)}" + ) from None + init_kwargs["scenario_strategies"] = strategy_enums + + # Build dataset config + if request.dataset_names: + init_kwargs["dataset_config"] = DatasetConfiguration( + dataset_names=request.dataset_names, + max_dataset_size=request.max_dataset_size, + ) + elif request.max_dataset_size is not None: + default_config = scenario_class.default_dataset_config() + default_config.max_dataset_size = request.max_dataset_size + init_kwargs["dataset_config"] = default_config + + # Instantiate and initialize scenario + constructor_kwargs: dict[str, Any] = {} + if request.scenario_result_id: + constructor_kwargs["scenario_result_id"] = request.scenario_result_id + scenario = scenario_class(**constructor_kwargs) # type: ignore[call-arg] + scenario.set_params_from_args(args=request.scenario_params or {}) + await scenario.initialize_async(**init_kwargs) + return scenario - # --- Phase 2: Run --- - info.status = ScenarioRunStatus.RUNNING - info.updated_at = datetime.now(timezone.utc) + async def _execute_run_async(self, *, run_id: str) -> None: + """ + Execute a scenario run (background task entry point). 
- # Build init kwargs - init_kwargs: dict[str, Any] = { - "objective_target": objective_target, - "max_concurrency": request.max_concurrency, - "max_retries": request.max_retries, - } - - if request.memory_labels: - init_kwargs["memory_labels"] = request.memory_labels - - # Resolve strategies - if request.strategies: - strategy_class = scenario_class.get_strategy_class() - strategy_enums = [] - for name in request.strategies: - try: - strategy_enums.append(strategy_class(name)) - except ValueError: - available_strategies = [s.value for s in strategy_class] - raise ValueError( - f"Strategy '{name}' not found for scenario '{request.scenario_name}'. " - f"Available: {', '.join(available_strategies)}" - ) from None - init_kwargs["scenario_strategies"] = strategy_enums - - # Build dataset config - if request.dataset_names: - init_kwargs["dataset_config"] = DatasetConfiguration( - dataset_names=request.dataset_names, - max_dataset_size=request.max_dataset_size, - ) - elif request.max_dataset_size is not None: - default_config = scenario_class.default_dataset_config() - default_config.max_dataset_size = request.max_dataset_size - init_kwargs["dataset_config"] = default_config + Only calls scenario.run_async on the already-initialized scenario. - # Instantiate and execute - scenario = scenario_class() # type: ignore[call-arg] - await scenario.initialize_async(**init_kwargs) - scenario_result = await scenario.run_async() + Args: + run_id: The run to execute. + """ + info = self._runs[run_id] + + try: + scenario_result = await info.scenario.run_async() - # --- Phase 3: Store result --- info.status = ScenarioRunStatus.COMPLETED info.updated_at = datetime.now(timezone.utc) info.result = ScenarioRunResult( @@ -318,17 +336,13 @@ def get_run_results(self, *, run_id: str) -> "ScenarioResultDetailResponse | Non return None if info.status != ScenarioRunStatus.COMPLETED or info.result is None: - raise ValueError( - f"Results are only available for completed runs. Current status: '{info.status}'." - ) + raise ValueError(f"Results are only available for completed runs. Current status: '{info.status}'.") # Retrieve from CentralMemory memory = CentralMemory.get_memory_instance() results = memory.get_scenario_results(scenario_result_ids=[info.result.scenario_result_id]) if not results: - raise ValueError( - f"Scenario result '{info.result.scenario_result_id}' not found in memory." 
- ) + raise ValueError(f"Scenario result '{info.result.scenario_result_id}' not found in memory.") scenario_result = results[0] display_groups = scenario_result.get_display_groups() @@ -347,7 +361,9 @@ def get_run_results(self, *, run_id: str) -> "ScenarioResultDetailResponse | Non last_response_text = None if ar.last_response is not None: - last_response_text = ar.last_response.value if hasattr(ar.last_response, "value") else str(ar.last_response) + last_response_text = ( + ar.last_response.value if hasattr(ar.last_response, "value") else str(ar.last_response) + ) details.append( AttackResultDetail( diff --git a/tests/unit/backend/test_scenario_run_routes.py b/tests/unit/backend/test_scenario_run_routes.py index 5498118594..b2340fdcdd 100644 --- a/tests/unit/backend/test_scenario_run_routes.py +++ b/tests/unit/backend/test_scenario_run_routes.py @@ -79,9 +79,7 @@ def test_start_run_invalid_scenario_returns_400(self, client: TestClient) -> Non """Test that an invalid scenario returns 400.""" with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: mock_service = MagicMock() - mock_service.start_run_async = AsyncMock( - side_effect=ValueError("'bad.scenario' not found in registry.") - ) + mock_service.start_run_async = AsyncMock(side_effect=ValueError("'bad.scenario' not found in registry.")) mock_get.return_value = mock_service response = client.post( @@ -118,6 +116,8 @@ def test_start_run_with_all_options(self, client: TestClient) -> None: "max_concurrency": 5, "max_retries": 2, "memory_labels": {"team": "red"}, + "scenario_params": {"max_turns": 10, "threshold": 0.8}, + "initializer_args": {"target": {"endpoint": "https://example.com"}}, }, ) @@ -218,9 +218,7 @@ def test_cancel_completed_run_returns_409(self, client: TestClient) -> None: """Test that cancelling a completed run returns 409 Conflict.""" with patch("pyrit.backend.routes.scenarios.get_scenario_run_service") as mock_get: mock_service = MagicMock() - mock_service.cancel_run_async = AsyncMock( - side_effect=ValueError("Cannot cancel run in 'completed' state.") - ) + mock_service.cancel_run_async = AsyncMock(side_effect=ValueError("Cannot cancel run in 'completed' state.")) mock_get.return_value = mock_service response = client.delete("/api/scenarios/runs/test-run-id") diff --git a/tests/unit/backend/test_scenario_run_service.py b/tests/unit/backend/test_scenario_run_service.py index 0cb449ea86..7b031ddf51 100644 --- a/tests/unit/backend/test_scenario_run_service.py +++ b/tests/unit/backend/test_scenario_run_service.py @@ -5,7 +5,6 @@ Tests for ScenarioRunService. 
""" -import asyncio from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -38,6 +37,7 @@ def _make_request( target_name: str = "my_target", initializers: list[str] | None = None, strategies: list[str] | None = None, + scenario_result_id: str | None = None, ) -> RunScenarioRequest: """Create a RunScenarioRequest for testing.""" return RunScenarioRequest( @@ -45,6 +45,7 @@ def _make_request( target_name=target_name, initializers=initializers, strategies=strategies, + scenario_result_id=scenario_result_id, ) @@ -80,47 +81,176 @@ def mock_initializer_registry(): yield mock_registry, mock_class, mock_instance +@pytest.fixture +def mock_all_registries(): + """Patch all registries with valid defaults for start_run_async tests.""" + mock_scenario_instance = MagicMock() + mock_scenario_instance.initialize_async = AsyncMock() + mock_scenario_instance.run_async = AsyncMock() + + mock_scenario_class = MagicMock(return_value=mock_scenario_instance) + mock_scenario_class.get_strategy_class.return_value = MagicMock() + mock_scenario_class.default_dataset_config.return_value = MagicMock() + + mock_sr = MagicMock() + mock_sr.get_class.return_value = mock_scenario_class + + mock_tr = MagicMock() + mock_tr.get_instance_by_name.return_value = MagicMock() + mock_tr.get_names.return_value = ["my_target"] + + mock_ir = MagicMock() + mock_ir.get_class.return_value = MagicMock(return_value=MagicMock(initialize_async=AsyncMock())) + + with ( + patch(f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton", return_value=mock_sr), + patch(f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton", return_value=mock_tr), + patch(f"{_REGISTRY_PATCH_BASE}.InitializerRegistry.get_registry_singleton", return_value=mock_ir), + ): + yield { + "scenario_registry": mock_sr, + "target_registry": mock_tr, + "initializer_registry": mock_ir, + "scenario_class": mock_scenario_class, + "scenario_instance": mock_scenario_instance, + } + + class TestScenarioRunServiceStartRun: """Tests for ScenarioRunService.start_run_async.""" - async def test_start_run_returns_pending_status(self, mock_scenario_registry) -> None: - """Test that starting a run returns PENDING status with a run_id.""" + async def test_start_run_returns_running_status(self, mock_all_registries) -> None: + """Test that starting a run returns RUNNING status with a run_id.""" service = ScenarioRunService() - - with patch.object(service, "_execute_run_async", new_callable=AsyncMock): - response = await service.start_run_async(request=_make_request()) + response = await service.start_run_async(request=_make_request()) assert response.run_id is not None - assert response.status == ScenarioRunStatus.PENDING + assert response.status == ScenarioRunStatus.RUNNING assert response.scenario_name == "foundry.red_team_agent" assert response.error is None assert response.result is None async def test_start_run_invalid_scenario_raises_value_error(self) -> None: - """Test that an invalid scenario name raises ValueError.""" + """Test that an invalid scenario name raises ValueError immediately.""" service = ScenarioRunService() - mock_registry = MagicMock() - mock_registry.get_class.side_effect = KeyError("'bad.scenario' not found in registry. Available: foo") - with patch( - f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton", return_value=mock_registry + mock_sr = MagicMock() + mock_sr.get_class.side_effect = KeyError("'bad.scenario' not found in registry. 
Available: foo") + with ( + patch(f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton", return_value=mock_sr), + patch(f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton"), + patch(f"{_REGISTRY_PATCH_BASE}.InitializerRegistry.get_registry_singleton"), ): with pytest.raises(ValueError, match="not found in registry"): await service.start_run_async(request=_make_request(scenario_name="bad.scenario")) - async def test_start_run_exceeds_concurrent_limit(self, mock_scenario_registry) -> None: - """Test that exceeding concurrent run limit raises ValueError.""" + async def test_start_run_invalid_target_raises_value_error(self) -> None: + """Test that an invalid target name raises ValueError immediately.""" service = ScenarioRunService() - with patch.object(service, "_execute_run_async", new_callable=AsyncMock): - # Fill up to the limit - for _ in range(MAX_CONCURRENT_RUNS): - await service.start_run_async(request=_make_request()) + mock_sr = MagicMock() + mock_sr.get_class.return_value = MagicMock() + + mock_tr = MagicMock() + mock_tr.get_instance_by_name.return_value = None + mock_tr.get_names.return_value = ["other_target"] - # Next one should fail - with pytest.raises(ValueError, match="Maximum concurrent runs"): + with ( + patch(f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton", return_value=mock_sr), + patch(f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton", return_value=mock_tr), + patch(f"{_REGISTRY_PATCH_BASE}.InitializerRegistry.get_registry_singleton"), + ): + with pytest.raises(ValueError, match="my_target.*not found in registry"): await service.start_run_async(request=_make_request()) + async def test_start_run_invalid_initializer_raises_value_error(self) -> None: + """Test that an invalid initializer name raises ValueError immediately.""" + service = ScenarioRunService() + + mock_sr = MagicMock() + mock_sr.get_class.return_value = MagicMock() + + mock_ir = MagicMock() + mock_ir.get_class.side_effect = KeyError("'bad_init' not found") + + with ( + patch(f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton", return_value=mock_sr), + patch(f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton"), + patch(f"{_REGISTRY_PATCH_BASE}.InitializerRegistry.get_registry_singleton", return_value=mock_ir), + ): + with pytest.raises(ValueError, match="Initializer not found"): + await service.start_run_async(request=_make_request(initializers=["bad_init"])) + + async def test_start_run_invalid_strategy_raises_value_error(self) -> None: + """Test that an invalid strategy name raises ValueError immediately.""" + service = ScenarioRunService() + + mock_strategy_class = MagicMock(side_effect=ValueError("not a valid strategy")) + mock_strategy_class.__iter__ = MagicMock(return_value=iter([MagicMock(value="valid_strat")])) + + mock_scenario_class = MagicMock() + mock_scenario_class.get_strategy_class.return_value = mock_strategy_class + + mock_sr = MagicMock() + mock_sr.get_class.return_value = mock_scenario_class + + mock_tr = MagicMock() + mock_tr.get_instance_by_name.return_value = MagicMock() + + with ( + patch(f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton", return_value=mock_sr), + patch(f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton", return_value=mock_tr), + patch(f"{_REGISTRY_PATCH_BASE}.InitializerRegistry.get_registry_singleton"), + ): + with pytest.raises(ValueError, match="Strategy.*not found for scenario"): + await 
service.start_run_async(request=_make_request(strategies=["bad_strategy"])) + + async def test_start_run_exceeds_concurrent_limit(self, mock_all_registries) -> None: + """Test that exceeding concurrent run limit raises ValueError.""" + service = ScenarioRunService() + + # Fill up to the limit + for _ in range(MAX_CONCURRENT_RUNS): + await service.start_run_async(request=_make_request()) + + # Next one should fail + with pytest.raises(ValueError, match="Maximum concurrent runs"): + await service.start_run_async(request=_make_request()) + + async def test_start_run_runs_initializers(self, mock_all_registries) -> None: + """Test that initializers are run during start_run_async.""" + service = ScenarioRunService() + mock_ir = mock_all_registries["initializer_registry"] + mock_init_instance = mock_ir.get_class.return_value.return_value + + response = await service.start_run_async( + request=_make_request(initializers=["target", "load_default_datasets"]) + ) + + assert response.status == ScenarioRunStatus.RUNNING + assert mock_init_instance.initialize_async.await_count == 2 + + async def test_start_run_passes_scenario_result_id_for_resume(self, mock_all_registries) -> None: + """Test that scenario_result_id is passed to the scenario constructor for resumption.""" + service = ScenarioRunService() + mock_scenario_class = mock_all_registries["scenario_class"] + + response = await service.start_run_async( + request=_make_request(scenario_result_id="existing-result-uuid") + ) + + assert response.status == ScenarioRunStatus.RUNNING + mock_scenario_class.assert_called_once_with(scenario_result_id="existing-result-uuid") + + async def test_start_run_omits_scenario_result_id_when_none(self, mock_all_registries) -> None: + """Test that scenario_result_id is not passed to constructor when not provided.""" + service = ScenarioRunService() + mock_scenario_class = mock_all_registries["scenario_class"] + + await service.start_run_async(request=_make_request()) + + mock_scenario_class.assert_called_once_with() + class TestScenarioRunServiceGetRun: """Tests for ScenarioRunService.get_run.""" @@ -131,12 +261,10 @@ async def test_get_run_returns_none_for_unknown_id(self) -> None: result = service.get_run(run_id="nonexistent-id") assert result is None - async def test_get_run_returns_existing_run(self, mock_scenario_registry) -> None: + async def test_get_run_returns_existing_run(self, mock_all_registries) -> None: """Test that get_run returns a started run.""" service = ScenarioRunService() - - with patch.object(service, "_execute_run_async", new_callable=AsyncMock): - response = await service.start_run_async(request=_make_request()) + response = await service.start_run_async(request=_make_request()) fetched = service.get_run(run_id=response.run_id) assert fetched is not None @@ -153,13 +281,12 @@ async def test_list_runs_empty(self) -> None: result = service.list_runs() assert result.items == [] - async def test_list_runs_returns_all_runs(self, mock_scenario_registry) -> None: + async def test_list_runs_returns_all_runs(self, mock_all_registries) -> None: """Test that list_runs returns all tracked runs.""" service = ScenarioRunService() - with patch.object(service, "_execute_run_async", new_callable=AsyncMock): - await service.start_run_async(request=_make_request()) - await service.start_run_async(request=_make_request()) + await service.start_run_async(request=_make_request()) + await service.start_run_async(request=_make_request()) result = service.list_runs() assert len(result.items) == 2 @@ -174,23 +301,19 
@@ async def test_cancel_run_returns_none_for_unknown_id(self) -> None: result = await service.cancel_run_async(run_id="nonexistent-id") assert result is None - async def test_cancel_run_sets_cancelled_status(self, mock_scenario_registry) -> None: + async def test_cancel_run_sets_cancelled_status(self, mock_all_registries) -> None: """Test that cancelling a running scenario sets CANCELLED status.""" service = ScenarioRunService() - - with patch.object(service, "_execute_run_async", new_callable=AsyncMock): - response = await service.start_run_async(request=_make_request()) + response = await service.start_run_async(request=_make_request()) result = await service.cancel_run_async(run_id=response.run_id) assert result is not None assert result.status == ScenarioRunStatus.CANCELLED - async def test_cancel_completed_run_raises_value_error(self, mock_scenario_registry) -> None: + async def test_cancel_completed_run_raises_value_error(self, mock_all_registries) -> None: """Test that cancelling a completed run raises ValueError.""" service = ScenarioRunService() - - with patch.object(service, "_execute_run_async", new_callable=AsyncMock): - response = await service.start_run_async(request=_make_request()) + response = await service.start_run_async(request=_make_request()) # Manually set to COMPLETED service._runs[response.run_id].status = ScenarioRunStatus.COMPLETED @@ -225,15 +348,9 @@ async def test_execute_run_completes_successfully(self) -> None: mock_target = MagicMock() with ( - patch( - f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton" - ) as mock_sr, - patch( - f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton" - ) as mock_tr, - patch( - f"{_REGISTRY_PATCH_BASE}.InitializerRegistry.get_registry_singleton" - ), + patch(f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton") as mock_sr, + patch(f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton") as mock_tr, + patch(f"{_REGISTRY_PATCH_BASE}.InitializerRegistry.get_registry_singleton"), ): mock_sr.return_value.get_class.return_value = mock_scenario_class mock_tr.return_value.get_instance_by_name.return_value = mock_target @@ -253,90 +370,37 @@ async def test_execute_run_completes_successfully(self) -> None: assert run.result.strategies_used == ["base64"] async def test_execute_run_fails_with_error(self) -> None: - """Test that a failed execution transitions to FAILED with error message.""" - service = ScenarioRunService() - - with ( - patch( - f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton" - ) as mock_sr, - patch( - f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton" - ) as mock_tr, - patch( - f"{_REGISTRY_PATCH_BASE}.InitializerRegistry.get_registry_singleton" - ), - ): - mock_sr.return_value.get_class.return_value = MagicMock() - mock_tr.return_value.get_instance_by_name.return_value = None - mock_tr.return_value.get_names.return_value = ["other_target"] - - response = await service.start_run_async(request=_make_request()) - - # Wait for the background task - task = service._runs[response.run_id].task - assert task is not None - await task - - run = service.get_run(run_id=response.run_id) - assert run is not None - assert run.status == ScenarioRunStatus.FAILED - assert run.error is not None - assert "my_target" in run.error - - async def test_execute_run_with_initializers(self) -> None: - """Test that initializers are run before scenario execution.""" + """Test that a run_async failure transitions to FAILED with error message.""" service = 
ScenarioRunService() - mock_scenario_result = MagicMock() - mock_scenario_result.id = "result-uuid" - mock_scenario_result.scenario_run_state = "COMPLETED" - mock_scenario_result.get_strategies_used.return_value = [] - mock_scenario_result.attack_results = {} - mock_scenario_result.number_tries = 1 - mock_scenario_result.completion_time = None - mock_scenario_instance = MagicMock() mock_scenario_instance.initialize_async = AsyncMock() - mock_scenario_instance.run_async = AsyncMock(return_value=mock_scenario_result) + mock_scenario_instance.run_async = AsyncMock(side_effect=RuntimeError("scenario exploded")) mock_scenario_class = MagicMock(return_value=mock_scenario_instance) - - mock_initializer_instance = MagicMock() - mock_initializer_instance.initialize_async = AsyncMock() - mock_initializer_class = MagicMock(return_value=mock_initializer_instance) - - mock_target = MagicMock() + mock_scenario_class.get_strategy_class.return_value = MagicMock() + mock_scenario_class.default_dataset_config.return_value = MagicMock() with ( - patch( - f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton" - ) as mock_sr, - patch( - f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton" - ) as mock_tr, - patch( - f"{_REGISTRY_PATCH_BASE}.InitializerRegistry.get_registry_singleton" - ) as mock_ir, + patch(f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton") as mock_sr, + patch(f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton") as mock_tr, + patch(f"{_REGISTRY_PATCH_BASE}.InitializerRegistry.get_registry_singleton"), ): mock_sr.return_value.get_class.return_value = mock_scenario_class - mock_tr.return_value.get_instance_by_name.return_value = mock_target - mock_ir.return_value.get_class.return_value = mock_initializer_class + mock_tr.return_value.get_instance_by_name.return_value = MagicMock() - response = await service.start_run_async( - request=_make_request(initializers=["target", "load_default_datasets"]) - ) + response = await service.start_run_async(request=_make_request()) + # Wait for the background task task = service._runs[response.run_id].task assert task is not None await task - # Initializer should have been called twice (once per name) - assert mock_initializer_instance.initialize_async.await_count == 2 - run = service.get_run(run_id=response.run_id) assert run is not None - assert run.status == ScenarioRunStatus.COMPLETED + assert run.status == ScenarioRunStatus.FAILED + assert run.error is not None + assert "scenario exploded" in run.error class TestScenarioRunServiceGetResults: @@ -348,26 +412,22 @@ def test_get_results_returns_none_for_unknown_id(self) -> None: result = service.get_run_results(run_id="nonexistent-id") assert result is None - async def test_get_results_raises_if_not_completed(self, mock_scenario_registry) -> None: + async def test_get_results_raises_if_not_completed(self, mock_all_registries) -> None: """Test that get_run_results raises ValueError if run is not completed.""" service = ScenarioRunService() + response = await service.start_run_async(request=_make_request()) - with patch.object(service, "_execute_run_async", new_callable=AsyncMock): - response = await service.start_run_async(request=_make_request()) - - # Run is in PENDING state + # Run is in RUNNING state with pytest.raises(ValueError, match="only available for completed runs"): service.get_run_results(run_id=response.run_id) - async def test_get_results_returns_details_for_completed_run(self, mock_scenario_registry) -> None: + async def 
test_get_results_returns_details_for_completed_run(self, mock_all_registries) -> None: """Test that get_run_results returns full details for a completed run.""" from pyrit.backend.models.scenarios import ScenarioRunResult from pyrit.models import AttackOutcome service = ScenarioRunService() - - with patch.object(service, "_execute_run_async", new_callable=AsyncMock): - response = await service.start_run_async(request=_make_request()) + response = await service.start_run_async(request=_make_request()) # Manually set run to completed with a result info = service._runs[response.run_id] diff --git a/tests/unit/backend/test_scenario_service.py b/tests/unit/backend/test_scenario_service.py index 7f435d76a5..1a56086c25 100644 --- a/tests/unit/backend/test_scenario_service.py +++ b/tests/unit/backend/test_scenario_service.py @@ -210,7 +210,7 @@ class TestScenarioRoutes: """Tests for scenario API routes.""" def test_list_scenarios_returns_200(self, client: TestClient) -> None: - """Test that GET /api/scenarios returns 200.""" + """Test that GET /api/scenarios/catalog returns 200.""" with patch("pyrit.backend.routes.scenarios.get_scenario_service") as mock_get_service: mock_service = MagicMock() mock_service.list_scenarios_async = AsyncMock( @@ -221,7 +221,7 @@ def test_list_scenarios_returns_200(self, client: TestClient) -> None: ) mock_get_service.return_value = mock_service - response = client.get("/api/scenarios") + response = client.get("/api/scenarios/catalog") assert response.status_code == status.HTTP_200_OK data = response.json() @@ -229,7 +229,7 @@ def test_list_scenarios_returns_200(self, client: TestClient) -> None: assert data["pagination"]["has_more"] is False def test_list_scenarios_with_items(self, client: TestClient) -> None: - """Test that GET /api/scenarios returns scenario data.""" + """Test that GET /api/scenarios/catalog returns scenario data.""" summary = ScenarioSummary( scenario_name="foundry.red_team_agent", scenario_type="RedTeamAgentScenario", @@ -251,7 +251,7 @@ def test_list_scenarios_with_items(self, client: TestClient) -> None: ) mock_get_service.return_value = mock_service - response = client.get("/api/scenarios") + response = client.get("/api/scenarios/catalog") assert response.status_code == status.HTTP_200_OK data = response.json() @@ -277,13 +277,13 @@ def test_list_scenarios_passes_pagination_params(self, client: TestClient) -> No ) mock_get_service.return_value = mock_service - response = client.get("/api/scenarios?limit=10&cursor=test.scenario_1") + response = client.get("/api/scenarios/catalog?limit=10&cursor=test.scenario_1") assert response.status_code == status.HTTP_200_OK mock_service.list_scenarios_async.assert_called_once_with(limit=10, cursor="test.scenario_1") def test_get_scenario_returns_200(self, client: TestClient) -> None: - """Test that GET /api/scenarios/{name} returns 200 when found.""" + """Test that GET /api/scenarios/catalog/{name} returns 200 when found.""" summary = ScenarioSummary( scenario_name="foundry.red_team_agent", scenario_type="RedTeamAgentScenario", @@ -300,20 +300,20 @@ def test_get_scenario_returns_200(self, client: TestClient) -> None: mock_service.get_scenario_async = AsyncMock(return_value=summary) mock_get_service.return_value = mock_service - response = client.get("/api/scenarios/foundry.red_team_agent") + response = client.get("/api/scenarios/catalog/foundry.red_team_agent") assert response.status_code == status.HTTP_200_OK data = response.json() assert data["scenario_name"] == "foundry.red_team_agent" def 
test_get_scenario_returns_404_when_not_found(self, client: TestClient) -> None: - """Test that GET /api/scenarios/{name} returns 404 when not found.""" + """Test that GET /api/scenarios/catalog/{name} returns 404 when not found.""" with patch("pyrit.backend.routes.scenarios.get_scenario_service") as mock_get_service: mock_service = MagicMock() mock_service.get_scenario_async = AsyncMock(return_value=None) mock_get_service.return_value = mock_service - response = client.get("/api/scenarios/nonexistent") + response = client.get("/api/scenarios/catalog/nonexistent") assert response.status_code == status.HTTP_404_NOT_FOUND @@ -335,7 +335,7 @@ def test_get_scenario_with_dotted_name(self, client: TestClient) -> None: mock_service.get_scenario_async = AsyncMock(return_value=summary) mock_get_service.return_value = mock_service - response = client.get("/api/scenarios/garak.encoding") + response = client.get("/api/scenarios/catalog/garak.encoding") assert response.status_code == status.HTTP_200_OK mock_service.get_scenario_async.assert_called_once_with(scenario_name="garak.encoding") From eed310bf817948eb2513c3fe4b019614c6b7dd8a Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Wed, 6 May 2026 17:17:57 -0700 Subject: [PATCH 03/10] refactoring again --- .pyrit_conf_example | 6 + pyrit/backend/routes/scenarios.py | 11 +- .../backend/services/scenario_run_service.py | 328 +++++++++-------- pyrit/cli/frontend_core.py | 2 + pyrit/cli/pyrit_backend.py | 1 + .../attack/multi_turn/tree_of_attacks.py | 2 +- pyrit/memory/azure_sql_memory.py | 12 +- pyrit/memory/memory_interface.py | 29 +- pyrit/memory/memory_models.py | 2 +- pyrit/memory/sqlite_memory.py | 10 +- pyrit/models/scenario_result.py | 7 +- .../random_translation_converter.py | 2 +- .../azure_content_filter_scorer.py | 2 +- pyrit/setup/configuration_loader.py | 1 + .../unit/backend/test_scenario_run_service.py | 331 +++++++++--------- 15 files changed, 424 insertions(+), 322 deletions(-) diff --git a/.pyrit_conf_example b/.pyrit_conf_example index c45bb390ce..9d9e66305d 100644 --- a/.pyrit_conf_example +++ b/.pyrit_conf_example @@ -111,6 +111,12 @@ operation: op_trash_panda # - /path/to/.env # - /path/to/.env.local +# Max Concurrent Scenario Runs +# ---------------------------- +# Maximum number of scenario runs that can execute concurrently in the backend. +# Applies only to the pyrit_backend server. +max_concurrent_scenario_runs: 3 + # Silent Mode # ----------- # If true, suppresses print statements during initialization. diff --git a/pyrit/backend/routes/scenarios.py b/pyrit/backend/routes/scenarios.py index 6c1fde6c47..77f73a38c5 100644 --- a/pyrit/backend/routes/scenarios.py +++ b/pyrit/backend/routes/scenarios.py @@ -122,15 +122,18 @@ async def start_scenario_run(request: RunScenarioRequest) -> ScenarioRunResponse "/runs", response_model=ScenarioRunListResponse, ) -async def list_scenario_runs() -> ScenarioRunListResponse: +async def list_scenario_runs(limit: int = Query(100, ge=1)) -> ScenarioRunListResponse: """ - List all tracked scenario runs. + List tracked scenario runs (most recent first). + + Args: + limit (int): Maximum number of runs to return. Defaults to 100. Returns: - ScenarioRunListResponse: All runs, most recent first. + ScenarioRunListResponse: Runs, most recent first. 
""" service = get_scenario_run_service() - return service.list_runs() + return service.list_runs(limit=limit) @router.get( diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py index 3da690b3a1..bca74d862f 100644 --- a/pyrit/backend/services/scenario_run_service.py +++ b/pyrit/backend/services/scenario_run_service.py @@ -10,51 +10,61 @@ import asyncio import logging -import uuid -from dataclasses import dataclass, field -from datetime import datetime, timezone -from functools import lru_cache +from dataclasses import dataclass from typing import Any from pyrit.backend.models.scenarios import ( + AtomicAttackResults, + AttackResultDetail, RunScenarioRequest, + ScenarioResultDetailResponse, ScenarioRunListResponse, ScenarioRunResponse, ScenarioRunResult, ScenarioRunStatus, ) +from pyrit.memory import CentralMemory +from pyrit.models import AttackOutcome, ScenarioResult +from pyrit.registry import InitializerRegistry, ScenarioRegistry, TargetRegistry +from pyrit.scenario import Scenario +from pyrit.scenario.core import DatasetConfiguration logger = logging.getLogger(__name__) -MAX_CONCURRENT_RUNS = 3 -MAX_COMPLETED_RUNS = 50 +_DEFAULT_MAX_CONCURRENT_RUNS = 3 + +# Maps DB ScenarioRunState values to API ScenarioRunStatus +_STATE_TO_STATUS: dict[str, ScenarioRunStatus] = { + "CREATED": ScenarioRunStatus.INITIALIZING, + "IN_PROGRESS": ScenarioRunStatus.RUNNING, + "COMPLETED": ScenarioRunStatus.COMPLETED, + "FAILED": ScenarioRunStatus.FAILED, + "CANCELLED": ScenarioRunStatus.CANCELLED, +} @dataclass -class _RunInfo: - """Internal tracking state for a scenario run.""" - - run_id: str - request: RunScenarioRequest - status: ScenarioRunStatus = ScenarioRunStatus.PENDING - created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) - updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) +class _ActiveTask: + """Tracks an in-flight scenario run's asyncio task.""" + + scenario_result_id: str task: asyncio.Task[None] | None = None + scenario: Scenario | None = None error: str | None = None - result: ScenarioRunResult | None = None - scenario: Any = None class ScenarioRunService: """ Service for managing scenario run lifecycle. - Runs are tracked in-memory and executed as background asyncio tasks. + Uses CentralMemory (database) as the source of truth for run state. + Keeps an in-memory dict only for active asyncio tasks (cancellation support). """ - def __init__(self) -> None: + def __init__(self, *, max_concurrent_runs: int = _DEFAULT_MAX_CONCURRENT_RUNS) -> None: """Initialize the scenario run service.""" - self._runs: dict[str, _RunInfo] = {} + self._max_concurrent_runs = max_concurrent_runs + self._active_tasks: dict[str, _ActiveTask] = {} async def start_run_async(self, *, request: RunScenarioRequest) -> ScenarioRunResponse: """ @@ -75,58 +85,63 @@ async def start_run_async(self, *, request: RunScenarioRequest) -> ScenarioRunRe ValueError: If scenario, target, initializer, or strategy cannot be found, or concurrent limit exceeded. """ - # Check concurrent run limit - active_count = sum( - 1 - for r in self._runs.values() - if r.status in (ScenarioRunStatus.PENDING, ScenarioRunStatus.INITIALIZING, ScenarioRunStatus.RUNNING) - ) - if active_count >= MAX_CONCURRENT_RUNS: + if len(self._active_tasks) >= self._max_concurrent_runs: raise ValueError( - f"Maximum concurrent runs ({MAX_CONCURRENT_RUNS}) reached. " + f"Maximum concurrent runs ({self._max_concurrent_runs}) reached. 
" "Wait for an existing run to complete or cancel one." ) # Perform all initialization eagerly — errors propagate to caller scenario = await self._initialize_run_async(request=request) - # Create run info in RUNNING state (initialization already complete) - run_id = str(uuid.uuid4()) - info = _RunInfo(run_id=run_id, request=request, status=ScenarioRunStatus.RUNNING, scenario=scenario) - self._runs[run_id] = info + # scenario_result_id is set during initialize_async + scenario_result_id = scenario._scenario_result_id + if scenario_result_id is None: + raise ValueError("Scenario did not produce a scenario_result_id during initialization.") - # Evict old completed runs if over limit - self._evict_completed_runs() + # Track active task + active = _ActiveTask(scenario_result_id=scenario_result_id, scenario=scenario) + self._active_tasks[scenario_result_id] = active # Spawn background task (only runs scenario.run_async) - task = asyncio.create_task(self._execute_run_async(run_id=run_id)) - info.task = task + task = asyncio.create_task(self._execute_run_async(scenario_result_id=scenario_result_id)) + active.task = task - return self._to_response(info) + response = self._build_response(scenario_result_id=scenario_result_id) + assert response is not None # guaranteed: we just inserted into DB via initialize_async + return response def get_run(self, *, run_id: str) -> ScenarioRunResponse | None: """ - Get the current status of a scenario run. + Get the current status of a scenario run by querying the database. Args: - run_id: The unique run identifier. + run_id: The scenario result ID (run identifier). Returns: ScenarioRunResponse if found, None otherwise. """ - info = self._runs.get(run_id) - if info is None: - return None - return self._to_response(info) + return self._build_response(scenario_result_id=run_id) - def list_runs(self) -> ScenarioRunListResponse: + def list_runs(self, *, limit: int = 100) -> ScenarioRunListResponse: """ - List all tracked scenario runs (most recent first). + List scenario runs by querying the database (most recent first). + + Args: + limit (int): Maximum number of runs to return. Defaults to 100. Returns: - ScenarioRunListResponse with all runs. + ScenarioRunListResponse with runs. """ - items = [self._to_response(info) for info in reversed(self._runs.values())] + memory = CentralMemory.get_memory_instance() + + # This is expensive, and we don't need all the data. At some point + # we may want to add a lightweight "list" query to the DB layer that only + results = memory.get_scenario_results(limit=limit) + items = [ + self._build_response_from_db(scenario_result=sr) + for sr in results + ] return ScenarioRunListResponse(items=items) async def cancel_run_async(self, *, run_id: str) -> ScenarioRunResponse | None: @@ -134,37 +149,45 @@ async def cancel_run_async(self, *, run_id: str) -> ScenarioRunResponse | None: Cancel a running scenario. Args: - run_id: The unique run identifier. + run_id: The scenario result ID (run identifier). Returns: Updated ScenarioRunResponse if found, None if run_id not found. Raises: - ValueError: If the run is already in a terminal state. + ValueError: If the run is already in a terminal state or not active. 
""" - info = self._runs.get(run_id) - if info is None: + # Verify run exists in DB + memory = CentralMemory.get_memory_instance() + results = memory.get_scenario_results(scenario_result_ids=[run_id]) + if not results: return None - terminal_states = (ScenarioRunStatus.COMPLETED, ScenarioRunStatus.FAILED, ScenarioRunStatus.CANCELLED) - if info.status in terminal_states: - raise ValueError(f"Cannot cancel run in '{info.status}' state.") + scenario_result = results[0] + db_status = _STATE_TO_STATUS.get(scenario_result.scenario_run_state, ScenarioRunStatus.FAILED) + + if db_status in (ScenarioRunStatus.COMPLETED, ScenarioRunStatus.FAILED, ScenarioRunStatus.CANCELLED): + raise ValueError(f"Cannot cancel run in '{db_status}' state.") - # Cancel the asyncio task - if info.task is not None and not info.task.done(): - info.task.cancel() + # Cancel the asyncio task if active + active = self._active_tasks.get(run_id) + if active is not None: + if active.task is not None and not active.task.done(): + active.task.cancel() - info.status = ScenarioRunStatus.CANCELLED - info.updated_at = datetime.now(timezone.utc) - return self._to_response(info) + # Persist cancelled state to DB + memory.update_scenario_run_state(scenario_result_id=run_id, scenario_run_state="CANCELLED") - async def _initialize_run_async(self, *, request: RunScenarioRequest) -> Any: + return self._build_response(scenario_result_id=run_id) + + async def _initialize_run_async(self, *, request: RunScenarioRequest) -> Scenario: """ Validate inputs and initialize the scenario eagerly. Performs all validation (scenario, initializers, target, strategies) and calls scenario.initialize_async so that any errors are raised immediately - to the caller. + to the caller. Running initialization on creation simplifies error handling and ensures + that the scenario is fully ready to run when we spawn the background task. Args: request: The run request with scenario name, target, and options. @@ -176,9 +199,6 @@ async def _initialize_run_async(self, *, request: RunScenarioRequest) -> Any: ValueError: If any validation fails (bad scenario name, missing target, invalid strategy, unknown initializer, etc.). """ - from pyrit.registry import InitializerRegistry, ScenarioRegistry, TargetRegistry - from pyrit.scenario.core import DatasetConfiguration - # Validate scenario exists scenario_registry = ScenarioRegistry.get_registry_singleton() try: @@ -211,8 +231,7 @@ async def _initialize_run_async(self, *, request: RunScenarioRequest) -> Any: "(e.g., initializers: ['target'])." ) raise ValueError( - f"Target '{request.target_name}' not found in registry. " - f"Available targets: {', '.join(available_names)}" + f"Target '{request.target_name}' not found in registry. Available targets: {', '.join(available_names)}" ) # Build init kwargs @@ -260,54 +279,95 @@ async def _initialize_run_async(self, *, request: RunScenarioRequest) -> Any: await scenario.initialize_async(**init_kwargs) return scenario - async def _execute_run_async(self, *, run_id: str) -> None: + async def _execute_run_async(self, *, scenario_result_id: str) -> None: """ Execute a scenario run (background task entry point). Only calls scenario.run_async on the already-initialized scenario. + Removes the task from _active_tasks when done. Args: - run_id: The run to execute. + scenario_result_id: The scenario result ID for this run. 
""" - info = self._runs[run_id] + active = self._active_tasks[scenario_result_id] + assert active.scenario is not None try: - scenario_result = await info.scenario.run_async() + await active.scenario.run_async() + + except asyncio.CancelledError: + logger.info(f"Scenario run {scenario_result_id} was cancelled.") + + except Exception as e: + active.error = str(e) + logger.exception(f"Scenario run {scenario_result_id} failed: {e}") + + finally: + del self._active_tasks[scenario_result_id] + + def _build_response(self, *, scenario_result_id: str) -> ScenarioRunResponse | None: + """ + Build a ScenarioRunResponse by querying the database and merging active task state. + + Args: + scenario_result_id: The scenario result ID. + + Returns: + ScenarioRunResponse if found in the database, None otherwise. + """ + memory = CentralMemory.get_memory_instance() + results = memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) + if not results: + return None + return self._build_response_from_db(scenario_result=results[0]) - info.status = ScenarioRunStatus.COMPLETED - info.updated_at = datetime.now(timezone.utc) - info.result = ScenarioRunResult( - scenario_result_id=str(scenario_result.id), + def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> ScenarioRunResponse: + """ + Build a ScenarioRunResponse from a database ScenarioResult, merged with active task info. + + Args: + scenario_result: A ScenarioResult retrieved from CentralMemory. + + Returns: + The API response model. + """ + scenario_result_id = str(scenario_result.id) + active = self._active_tasks.get(scenario_result_id) + + status = _STATE_TO_STATUS.get(scenario_result.scenario_run_state, ScenarioRunStatus.FAILED) + + # Build result summary for completed runs + result = None + if status == ScenarioRunStatus.COMPLETED: + completed_attacks = sum( + 1 + for results in scenario_result.attack_results.values() + for ar in results + if ar.outcome in (AttackOutcome.SUCCESS, AttackOutcome.FAILURE) + ) + result = ScenarioRunResult( + scenario_result_id=scenario_result_id, run_state=scenario_result.scenario_run_state, strategies_used=scenario_result.get_strategies_used(), total_attacks=len(scenario_result.attack_results), - completed_attacks=len(scenario_result.attack_results), + completed_attacks=completed_attacks, number_tries=scenario_result.number_tries, completion_time=scenario_result.completion_time, ) - except asyncio.CancelledError: - info.status = ScenarioRunStatus.CANCELLED - info.updated_at = datetime.now(timezone.utc) - logger.info(f"Scenario run {run_id} was cancelled.") + error = active.error if active else None - except Exception as e: - info.status = ScenarioRunStatus.FAILED - info.updated_at = datetime.now(timezone.utc) - info.error = str(e) - logger.exception(f"Scenario run {run_id} failed: {e}") - - def _evict_completed_runs(self) -> None: - """Remove oldest completed runs if over the retention limit.""" - terminal_states = (ScenarioRunStatus.COMPLETED, ScenarioRunStatus.FAILED, ScenarioRunStatus.CANCELLED) - completed = [r for r in self._runs.values() if r.status in terminal_states] - if len(completed) > MAX_COMPLETED_RUNS: - # Sort by creation time, remove oldest - completed.sort(key=lambda r: r.created_at) - for run_info in completed[: len(completed) - MAX_COMPLETED_RUNS]: - del self._runs[run_info.run_id] - - def get_run_results(self, *, run_id: str) -> "ScenarioResultDetailResponse | None": + return ScenarioRunResponse( + run_id=scenario_result_id, + 
scenario_name=scenario_result.scenario_identifier.name, + status=status, + created_at=scenario_result.completion_time, + updated_at=scenario_result.completion_time, + error=error, + result=result, + ) + + def get_run_results(self, *, run_id: str) -> ScenarioResultDetailResponse | None: """ Get detailed results for a completed scenario run. @@ -315,40 +375,28 @@ def get_run_results(self, *, run_id: str) -> "ScenarioResultDetailResponse | Non to a detailed response model with per-attack outcomes. Args: - run_id: The unique run identifier. + run_id: The scenario result ID (run identifier). Returns: ScenarioResultDetailResponse if the run is completed and results exist, None if run not found. Raises: - ValueError: If the run is not in a completed state or results not found in memory. + ValueError: If the run is not in a completed state. """ - from pyrit.backend.models.scenarios import ( - AtomicAttackResults, - AttackResultDetail, - ScenarioResultDetailResponse, - ) - from pyrit.memory import CentralMemory - from pyrit.models import AttackOutcome - - info = self._runs.get(run_id) - if info is None: - return None - - if info.status != ScenarioRunStatus.COMPLETED or info.result is None: - raise ValueError(f"Results are only available for completed runs. Current status: '{info.status}'.") - - # Retrieve from CentralMemory memory = CentralMemory.get_memory_instance() - results = memory.get_scenario_results(scenario_result_ids=[info.result.scenario_result_id]) + results = memory.get_scenario_results(scenario_result_ids=[run_id]) if not results: - raise ValueError(f"Scenario result '{info.result.scenario_result_id}' not found in memory.") + return None scenario_result = results[0] - display_groups = scenario_result.get_display_groups() + + if scenario_result.scenario_run_state != "COMPLETED": + status = _STATE_TO_STATUS.get(scenario_result.scenario_run_state, ScenarioRunStatus.FAILED) + raise ValueError(f"Results are only available for completed runs. 
Current status: '{status}'.") # Build per-attack detail attacks: list[AtomicAttackResults] = [] + display_group_map = scenario_result.display_group_map for attack_name, attack_results in scenario_result.attack_results.items(): details: list[AttackResultDetail] = [] success_count = 0 @@ -357,13 +405,11 @@ def get_run_results(self, *, run_id: str) -> "ScenarioResultDetailResponse | Non for ar in attack_results: score_value = None if ar.last_score is not None: - score_value = ar.last_score.get_value() + score_value = str(ar.last_score.get_value()) last_response_text = None if ar.last_response is not None: - last_response_text = ( - ar.last_response.value if hasattr(ar.last_response, "value") else str(ar.last_response) - ) + last_response_text = str(ar.last_response) details.append( AttackResultDetail( @@ -385,15 +431,10 @@ def get_run_results(self, *, run_id: str) -> "ScenarioResultDetailResponse | Non elif ar.outcome == AttackOutcome.FAILURE: failure_count += 1 - # Find display group for this attack - display_group = None - if hasattr(scenario_result, "_display_group_map") and scenario_result._display_group_map: - display_group = scenario_result._display_group_map.get(attack_name) - attacks.append( AtomicAttackResults( atomic_attack_name=attack_name, - display_group=display_group, + display_group=display_group_map.get(attack_name), results=details, success_count=success_count, failure_count=failure_count, @@ -413,26 +454,31 @@ def get_run_results(self, *, run_id: str) -> "ScenarioResultDetailResponse | Non attacks=attacks, ) - @staticmethod - def _to_response(info: _RunInfo) -> ScenarioRunResponse: - """Convert internal run info to API response model.""" - return ScenarioRunResponse( - run_id=info.run_id, - scenario_name=info.request.scenario_name, - status=info.status, - created_at=info.created_at, - updated_at=info.updated_at, - error=info.error, - result=info.result, - ) + +_service_instance: ScenarioRunService | None = None -@lru_cache(maxsize=1) def get_scenario_run_service() -> ScenarioRunService: """ Get the global scenario run service instance. + On first call, reads ``max_concurrent_scenario_runs`` from ``app.state`` + (set by ``pyrit_backend`` CLI) if available, otherwise uses the default. + Returns: The singleton ScenarioRunService instance. 
""" - return ScenarioRunService() + global _service_instance + if _service_instance is not None: + return _service_instance + + max_runs = _DEFAULT_MAX_CONCURRENT_RUNS + try: + from pyrit.backend.main import app + + max_runs = getattr(app.state, "max_concurrent_scenario_runs", _DEFAULT_MAX_CONCURRENT_RUNS) + except Exception: + pass + + _service_instance = ScenarioRunService(max_concurrent_runs=max_runs) + return _service_instance diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index b75634891a..c17eb83b54 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -146,6 +146,7 @@ def __init__( self._env_files = config._resolve_env_files() self._operator = config.operator self._operation = config.operation + self._max_concurrent_scenario_runs = config.max_concurrent_scenario_runs # Lazy-loaded registries self._scenario_registry: Optional[ScenarioRegistry] = None @@ -221,6 +222,7 @@ def with_overrides( derived._env_files = self._env_files derived._operator = self._operator derived._operation = self._operation + derived._max_concurrent_scenario_runs = self._max_concurrent_scenario_runs derived._scenario_config = self._scenario_config # Apply overrides or inherit diff --git a/pyrit/cli/pyrit_backend.py b/pyrit/cli/pyrit_backend.py index f45cc0c448..8eed2cc929 100644 --- a/pyrit/cli/pyrit_backend.py +++ b/pyrit/cli/pyrit_backend.py @@ -198,6 +198,7 @@ async def initialize_and_run_async(*, parsed_args: Namespace) -> int: if context._operation: default_labels["operation"] = context._operation app.state.default_labels = default_labels + app.state.max_concurrent_scenario_runs = context._max_concurrent_scenario_runs display_host = parsed_args.host print(f"🚀 Starting PyRIT backend on http://{display_host}:{parsed_args.port}") diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index e23af1eabf..eebc771295 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -1354,7 +1354,7 @@ class TreeOfAttacksWithPruningAttack(AttackStrategy[TAPAttackContext, TAPAttackR def __init__( self, *, - objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-assignment, ty:invalid-parameter-default] + objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] attack_adversarial_config: AttackAdversarialConfig, attack_converter_config: Optional[AttackConverterConfig] = None, attack_scoring_config: Optional[AttackScoringConfig] = None, diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index dbb228b435..753236968d 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -786,6 +786,8 @@ def _query_entries( conditions: Optional[Any] = None, distinct: bool = False, join_scores: bool = False, + order_by: Optional[Any] = None, + limit: int | None = None, ) -> MutableSequence[Model]: """ Fetch data from the specified table model with optional conditions. @@ -795,6 +797,8 @@ def _query_entries( conditions: SQLAlchemy filter conditions (Optional). distinct: Flag to return distinct rows (defaults to False). join_scores: Flag to join the scores table with entries (defaults to False). + order_by: SQLAlchemy order_by clause (Optional). + limit (int | None): Maximum number of rows to return. Defaults to None (no limit). Returns: List of model instances representing the rows fetched from the table. 
@@ -814,8 +818,12 @@ def _query_entries( ) if conditions is not None: query = query.filter(conditions) + if order_by is not None: + query = query.order_by(order_by) if distinct: - return query.distinct().all() + query = query.distinct() + if limit is not None: + query = query.limit(limit) return query.all() except SQLAlchemyError as e: logger.exception(f"Error fetching data from table {model_class.__tablename__}: {e}") # type: ignore[ty:unresolved-attribute] @@ -846,7 +854,7 @@ def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict # attributes from the (potentially stale) detached object # and silently overwrite concurrent updates to columns # that are NOT in update_fields. - entry_in_session = session.get(type(entry), entry.id) # type: ignore[ty:unresolved-attribute] + entry_in_session = session.get(type(entry), entry.id) if entry_in_session is None: entry_in_session = session.merge(entry) for field, value in update_fields.items(): diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 0c3310c0ee..b28d05976e 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -354,6 +354,8 @@ def _query_entries( conditions: Optional[Any] = None, distinct: bool = False, join_scores: bool = False, + order_by: Optional[Any] = None, + limit: int | None = None, ) -> MutableSequence[Model]: """ Fetch data from the specified table model with optional conditions. @@ -363,6 +365,8 @@ def _query_entries( conditions: SQLAlchemy filter conditions (Optional). distinct: Whether to return distinct rows only. Defaults to False. join_scores: Whether to join the scores table. Defaults to False. + order_by: SQLAlchemy order_by clause (Optional). + limit (int | None): Maximum number of rows to return. Defaults to None (no limit). Returns: List of model instances representing the rows fetched from the table. @@ -378,6 +382,8 @@ def _execute_batched_query( distinct: bool = False, join_scores: bool = False, batch_size: int | None = None, + order_by: Optional[Any] = None, + limit: int | None = None, ) -> MutableSequence[Model]: """ Execute queries in batches to avoid exceeding database bind variable limits. @@ -394,6 +400,8 @@ def _execute_batched_query( join_scores: Whether to join the scores table. batch_size: Override for the number of values per batch. Defaults to ``_MAX_BIND_VARS`` when not specified. + order_by: SQLAlchemy order_by clause (Optional). + limit (int | None): Maximum number of rows to return. Defaults to None (no limit). Returns: MutableSequence[Model]: Merged and deduplicated results from all batched queries. @@ -411,6 +419,8 @@ def _execute_batched_query( conditions=and_(*conditions) if conditions else None, distinct=distinct, join_scores=join_scores, + order_by=order_by, + limit=limit, ) # Execute multiple separate queries and merge results @@ -426,6 +436,7 @@ def _execute_batched_query( conditions=and_(*conditions) if conditions else None, distinct=distinct, join_scores=join_scores, + order_by=order_by, ) # Deduplicate by primary key (id) @@ -2062,10 +2073,13 @@ def get_scenario_results( objective_target_endpoint: Optional[str] = None, objective_target_model_name: Optional[str] = None, identifier_filters: Optional[Sequence[IdentifierFilter]] = None, + limit: int | None = None, ) -> Sequence[ScenarioResult]: """ Retrieve a list of ScenarioResult objects based on the specified filters. + Results are always ordered by completion_time descending (most recent first). 
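
Concretely, the fixed descending order plus the new limit parameter makes "most recent N runs" a one-line query for callers. A hedged usage sketch, assuming CentralMemory has already been initialized by the application (all names here appear elsewhere in this patch):

    from pyrit.memory import CentralMemory

    memory = CentralMemory.get_memory_instance()
    # Newest five scenario results, ordered by completion_time descending.
    for sr in memory.get_scenario_results(limit=5):
        print(sr.id, sr.scenario_run_state, sr.completion_time)
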
+ Args: scenario_result_ids (Optional[Sequence[str]], optional): A list of scenario result IDs. Defaults to None. @@ -2088,9 +2102,11 @@ def get_scenario_results( identifier_filters (Optional[Sequence[IdentifierFilter]], optional): A sequence of IdentifierFilter objects that allows filtering by identifier JSON properties. Defaults to None. + limit (int | None): Maximum number of results to return. Defaults to None (no limit). Returns: - Sequence[ScenarioResult]: A list of ScenarioResult objects that match the specified filters. + Sequence[ScenarioResult]: A list of ScenarioResult objects that match the specified filters, + ordered by completion_time descending. """ if scenario_result_ids is not None and len(scenario_result_ids) == 0: return [] @@ -2149,6 +2165,8 @@ def get_scenario_results( ) try: + order_by_clause = ScenarioResultEntry.completion_time.desc() + # Handle scenario_result_ids with batched queries if needed if scenario_result_ids: entries = self._execute_batched_query( @@ -2156,9 +2174,16 @@ def get_scenario_results( batch_column=ScenarioResultEntry.id, batch_values=list(scenario_result_ids), other_conditions=conditions, + order_by=order_by_clause, + limit=limit, ) else: - entries = self._query_entries(ScenarioResultEntry, conditions=and_(*conditions) if conditions else None) + entries = self._query_entries( + ScenarioResultEntry, + conditions=and_(*conditions) if conditions else None, + order_by=order_by_clause, + limit=limit, + ) # Convert entries to ScenarioResults and populate attack_results efficiently scenario_results = [] diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 6b89313ba3..ce54730677 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -949,7 +949,7 @@ class ScenarioResultEntry(Base): scenario_init_data: Mapped[Optional[dict[str, Any]]] = mapped_column(JSON, nullable=True) objective_target_identifier: Mapped[dict[str, str]] = mapped_column(JSON, nullable=False) objective_scorer_identifier: Mapped[Optional[dict[str, str]]] = mapped_column(JSON, nullable=True) - scenario_run_state: Mapped[Literal["CREATED", "IN_PROGRESS", "COMPLETED", "FAILED"]] = mapped_column( + scenario_run_state: Mapped[Literal["CREATED", "IN_PROGRESS", "COMPLETED", "FAILED", "CANCELLED"]] = mapped_column( String, nullable=False, default="CREATED" ) attack_results_json: Mapped[str] = mapped_column(Unicode, nullable=False) diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 53f6ce9134..0874428878 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -326,6 +326,8 @@ def _query_entries( conditions: Optional[Any] = None, distinct: bool = False, join_scores: bool = False, + order_by: Optional[Any] = None, + limit: int | None = None, ) -> MutableSequence[Model]: """ Fetch data from the specified table model with optional conditions. @@ -335,6 +337,8 @@ def _query_entries( conditions: SQLAlchemy filter conditions (Optional). distinct: Flag to return distinct rows (default is False). join_scores: Flag to join the scores table (default is False). + order_by: SQLAlchemy order_by clause (Optional). + limit (int | None): Maximum number of rows to return. Defaults to None (no limit). Returns: List of model instances representing the rows fetched from the table. 
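
The new CANCELLED state added to ScenarioResultEntry above also has to surface through the backend API. Per the service code and the cancel tests later in this series, DB run states map onto ScenarioRunStatus values roughly as sketched below; the first few entries are taken from _STATE_TO_STATUS in scenario_run_service.py, while the CANCELLED entry is inferred from test behavior, so treat this as an illustration rather than the authoritative mapping:

    from pyrit.backend.models.scenarios import ScenarioRunStatus

    # Unknown states fall back to FAILED, matching the service's
    # _STATE_TO_STATUS.get(state, ScenarioRunStatus.FAILED) usage.
    state_to_status = {
        "CREATED": ScenarioRunStatus.INITIALIZING,
        "IN_PROGRESS": ScenarioRunStatus.RUNNING,
        "COMPLETED": ScenarioRunStatus.COMPLETED,
        "FAILED": ScenarioRunStatus.FAILED,
        "CANCELLED": ScenarioRunStatus.CANCELLED,  # inferred from the cancel tests
    }

    def to_api_status(scenario_run_state: str) -> ScenarioRunStatus:
        return state_to_status.get(scenario_run_state, ScenarioRunStatus.FAILED)
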
@@ -354,8 +358,12 @@ def _query_entries( ) if conditions is not None: query = query.filter(conditions) + if order_by is not None: + query = query.order_by(order_by) if distinct: - return query.distinct().all() + query = query.distinct() + if limit is not None: + query = query.limit(limit) return query.all() except SQLAlchemyError as e: logger.exception(f"Error fetching data from table {model_class.__tablename__}: {e}") # type: ignore[ty:unresolved-attribute] diff --git a/pyrit/models/scenario_result.py b/pyrit/models/scenario_result.py index 3b159846d2..a0eca690cf 100644 --- a/pyrit/models/scenario_result.py +++ b/pyrit/models/scenario_result.py @@ -47,7 +47,7 @@ def __init__( self.init_data = init_data -ScenarioRunState = Literal["CREATED", "IN_PROGRESS", "COMPLETED", "FAILED"] +ScenarioRunState = Literal["CREATED", "IN_PROGRESS", "COMPLETED", "FAILED", "CANCELLED"] class ScenarioResult: @@ -101,6 +101,11 @@ def __init__( self.number_tries = number_tries self._display_group_map = display_group_map or {} + @property + def display_group_map(self) -> dict[str, str]: + """Mapping of atomic_attack_name → display group label.""" + return self._display_group_map + def get_strategies_used(self) -> list[str]: """ Get the list of strategies used in this scenario. diff --git a/pyrit/prompt_converter/random_translation_converter.py b/pyrit/prompt_converter/random_translation_converter.py index 4b81d2041b..769cb51611 100644 --- a/pyrit/prompt_converter/random_translation_converter.py +++ b/pyrit/prompt_converter/random_translation_converter.py @@ -35,7 +35,7 @@ class RandomTranslationConverter(LLMGenericTextConverter, WordLevelConverter): def __init__( self, *, - converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-assignment, ty:invalid-parameter-default] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] system_prompt_template: Optional[SeedPrompt] = None, languages: Optional[list[str]] = None, word_selection_strategy: Optional[WordSelectionStrategy] = None, diff --git a/pyrit/score/float_scale/azure_content_filter_scorer.py b/pyrit/score/float_scale/azure_content_filter_scorer.py index 754a0269ce..34168b41ae 100644 --- a/pyrit/score/float_scale/azure_content_filter_scorer.py +++ b/pyrit/score/float_scale/azure_content_filter_scorer.py @@ -149,7 +149,7 @@ def __init__( if callable(self._api_key): # Token provider - create an AsyncTokenCredential wrapper credential = AsyncTokenProviderCredential(self._api_key) # type: ignore[ty:invalid-argument-type] - self._azure_cf_client = ContentSafetyClient(self._endpoint, credential=credential) # type: ignore[ty:invalid-argument-type] + self._azure_cf_client = ContentSafetyClient(self._endpoint, credential=credential) else: # String API key if not isinstance(self._api_key, str): diff --git a/pyrit/setup/configuration_loader.py b/pyrit/setup/configuration_loader.py index 184a235f65..e2e18e2350 100644 --- a/pyrit/setup/configuration_loader.py +++ b/pyrit/setup/configuration_loader.py @@ -132,6 +132,7 @@ class ConfigurationLoader(YamlLoadable): operator: Optional[str] = None operation: Optional[str] = None scenario: Optional[Union[str, dict[str, Any]]] = None + max_concurrent_scenario_runs: int = 3 def __post_init__(self) -> None: """Validate and normalize the configuration after loading.""" diff --git a/tests/unit/backend/test_scenario_run_service.py b/tests/unit/backend/test_scenario_run_service.py index 7b031ddf51..1725fb3755 100644 --- a/tests/unit/backend/test_scenario_run_service.py +++ 
b/tests/unit/backend/test_scenario_run_service.py @@ -5,6 +5,7 @@ Tests for ScenarioRunService. """ +from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -14,21 +15,23 @@ ScenarioRunStatus, ) from pyrit.backend.services.scenario_run_service import ( - MAX_CONCURRENT_RUNS, + _DEFAULT_MAX_CONCURRENT_RUNS, ScenarioRunService, get_scenario_run_service, ) -# The service uses deferred imports inside methods, so we patch at the source module. _REGISTRY_PATCH_BASE = "pyrit.registry" +_MEMORY_PATCH = "pyrit.memory.CentralMemory.get_memory_instance" @pytest.fixture(autouse=True) def clear_service_cache(): - """Clear the singleton cache between tests.""" - get_scenario_run_service.cache_clear() + """Clear the singleton instance between tests.""" + import pyrit.backend.services.scenario_run_service as svc_mod + + svc_mod._service_instance = None yield - get_scenario_run_service.cache_clear() + svc_mod._service_instance = None def _make_request( @@ -49,44 +52,46 @@ def _make_request( ) -@pytest.fixture -def mock_scenario_registry(): - """Patch ScenarioRegistry.get_registry_singleton to return a mock.""" - mock_registry = MagicMock() - mock_registry.get_class.return_value = MagicMock() - with patch(f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton", return_value=mock_registry): - yield mock_registry - - -@pytest.fixture -def mock_target_registry(): - """Patch TargetRegistry.get_registry_singleton to return a mock.""" - mock_registry = MagicMock() - mock_registry.get_instance_by_name.return_value = MagicMock() - mock_registry.get_names.return_value = ["my_target"] - with patch(f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton", return_value=mock_registry): - yield mock_registry +def _make_db_scenario_result( + *, + result_id: str = "sr-uuid-1", + scenario_name: str = "foundry.red_team_agent", + run_state: str = "IN_PROGRESS", + attack_results: dict | None = None, +) -> MagicMock: + """Create a mock ScenarioResult as returned by CentralMemory.""" + sr = MagicMock() + sr.id = result_id + sr.scenario_identifier.name = scenario_name + sr.scenario_identifier.version = 1 + sr.scenario_run_state = run_state + sr.get_strategies_used.return_value = [] + sr.attack_results = attack_results or {} + sr.number_tries = 1 + sr.completion_time = datetime(2025, 1, 1, tzinfo=timezone.utc) + sr.labels = {} + sr.objective_achieved_rate.return_value = 0 + sr.get_display_groups.return_value = {} + sr._display_group_map = {} + return sr @pytest.fixture -def mock_initializer_registry(): - """Patch InitializerRegistry.get_registry_singleton to return a mock.""" - mock_instance = MagicMock() - mock_instance.initialize_async = AsyncMock() - mock_class = MagicMock(return_value=mock_instance) - - mock_registry = MagicMock() - mock_registry.get_class.return_value = mock_class - with patch(f"{_REGISTRY_PATCH_BASE}.InitializerRegistry.get_registry_singleton", return_value=mock_registry): - yield mock_registry, mock_class, mock_instance +def mock_memory(): + """Patch CentralMemory.get_memory_instance to return a mock.""" + mock = MagicMock() + mock.get_scenario_results.return_value = [] + with patch(_MEMORY_PATCH, return_value=mock): + yield mock @pytest.fixture -def mock_all_registries(): - """Patch all registries with valid defaults for start_run_async tests.""" +def mock_all_registries(mock_memory): + """Patch all registries and CentralMemory with valid defaults.""" mock_scenario_instance = MagicMock() mock_scenario_instance.initialize_async = 
AsyncMock() mock_scenario_instance.run_async = AsyncMock() + mock_scenario_instance._scenario_result_id = "sr-uuid-1" mock_scenario_class = MagicMock(return_value=mock_scenario_instance) mock_scenario_class.get_strategy_class.return_value = MagicMock() @@ -102,6 +107,10 @@ def mock_all_registries(): mock_ir = MagicMock() mock_ir.get_class.return_value = MagicMock(return_value=MagicMock(initialize_async=AsyncMock())) + # By default, return a matching DB result for get_run / list_runs queries + db_result = _make_db_scenario_result() + mock_memory.get_scenario_results.return_value = [db_result] + with ( patch(f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton", return_value=mock_sr), patch(f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton", return_value=mock_tr), @@ -113,6 +122,8 @@ def mock_all_registries(): "initializer_registry": mock_ir, "scenario_class": mock_scenario_class, "scenario_instance": mock_scenario_instance, + "memory": mock_memory, + "db_result": db_result, } @@ -120,17 +131,16 @@ class TestScenarioRunServiceStartRun: """Tests for ScenarioRunService.start_run_async.""" async def test_start_run_returns_running_status(self, mock_all_registries) -> None: - """Test that starting a run returns RUNNING status with a run_id.""" + """Test that starting a run returns RUNNING status with run_id = scenario_result_id.""" service = ScenarioRunService() response = await service.start_run_async(request=_make_request()) - assert response.run_id is not None + assert response.run_id == "sr-uuid-1" assert response.status == ScenarioRunStatus.RUNNING assert response.scenario_name == "foundry.red_team_agent" assert response.error is None - assert response.result is None - async def test_start_run_invalid_scenario_raises_value_error(self) -> None: + async def test_start_run_invalid_scenario_raises_value_error(self, mock_memory) -> None: """Test that an invalid scenario name raises ValueError immediately.""" service = ScenarioRunService() @@ -144,7 +154,7 @@ async def test_start_run_invalid_scenario_raises_value_error(self) -> None: with pytest.raises(ValueError, match="not found in registry"): await service.start_run_async(request=_make_request(scenario_name="bad.scenario")) - async def test_start_run_invalid_target_raises_value_error(self) -> None: + async def test_start_run_invalid_target_raises_value_error(self, mock_memory) -> None: """Test that an invalid target name raises ValueError immediately.""" service = ScenarioRunService() @@ -163,7 +173,7 @@ async def test_start_run_invalid_target_raises_value_error(self) -> None: with pytest.raises(ValueError, match="my_target.*not found in registry"): await service.start_run_async(request=_make_request()) - async def test_start_run_invalid_initializer_raises_value_error(self) -> None: + async def test_start_run_invalid_initializer_raises_value_error(self, mock_memory) -> None: """Test that an invalid initializer name raises ValueError immediately.""" service = ScenarioRunService() @@ -181,7 +191,7 @@ async def test_start_run_invalid_initializer_raises_value_error(self) -> None: with pytest.raises(ValueError, match="Initializer not found"): await service.start_run_async(request=_make_request(initializers=["bad_init"])) - async def test_start_run_invalid_strategy_raises_value_error(self) -> None: + async def test_start_run_invalid_strategy_raises_value_error(self, mock_memory) -> None: """Test that an invalid strategy name raises ValueError immediately.""" service = ScenarioRunService() @@ -208,9 +218,21 @@ async def 
test_start_run_invalid_strategy_raises_value_error(self) -> None: async def test_start_run_exceeds_concurrent_limit(self, mock_all_registries) -> None: """Test that exceeding concurrent run limit raises ValueError.""" service = ScenarioRunService() + scenario_instance = mock_all_registries["scenario_instance"] + + # Each call needs a unique scenario_result_id + call_count = 0 + original_init = scenario_instance.initialize_async + + async def _set_unique_id(**kwargs: object) -> None: + nonlocal call_count + call_count += 1 + scenario_instance._scenario_result_id = f"sr-uuid-{call_count}" + + scenario_instance.initialize_async = AsyncMock(side_effect=_set_unique_id) # Fill up to the limit - for _ in range(MAX_CONCURRENT_RUNS): + for _ in range(_DEFAULT_MAX_CONCURRENT_RUNS): await service.start_run_async(request=_make_request()) # Next one should fail @@ -235,9 +257,7 @@ async def test_start_run_passes_scenario_result_id_for_resume(self, mock_all_reg service = ScenarioRunService() mock_scenario_class = mock_all_registries["scenario_class"] - response = await service.start_run_async( - request=_make_request(scenario_result_id="existing-result-uuid") - ) + response = await service.start_run_async(request=_make_request(scenario_result_id="existing-result-uuid")) assert response.status == ScenarioRunStatus.RUNNING mock_scenario_class.assert_called_once_with(scenario_result_id="existing-result-uuid") @@ -255,194 +275,179 @@ async def test_start_run_omits_scenario_result_id_when_none(self, mock_all_regis class TestScenarioRunServiceGetRun: """Tests for ScenarioRunService.get_run.""" - async def test_get_run_returns_none_for_unknown_id(self) -> None: + def test_get_run_returns_none_for_unknown_id(self, mock_memory) -> None: """Test that get_run returns None for non-existent run_id.""" + mock_memory.get_scenario_results.return_value = [] service = ScenarioRunService() result = service.get_run(run_id="nonexistent-id") assert result is None - async def test_get_run_returns_existing_run(self, mock_all_registries) -> None: - """Test that get_run returns a started run.""" + def test_get_run_returns_existing_run(self, mock_memory) -> None: + """Test that get_run returns a run from the database.""" + db_result = _make_db_scenario_result(result_id="sr-123", run_state="IN_PROGRESS") + mock_memory.get_scenario_results.return_value = [db_result] + service = ScenarioRunService() - response = await service.start_run_async(request=_make_request()) + fetched = service.get_run(run_id="sr-123") - fetched = service.get_run(run_id=response.run_id) assert fetched is not None - assert fetched.run_id == response.run_id + assert fetched.run_id == "sr-123" assert fetched.scenario_name == "foundry.red_team_agent" + assert fetched.status == ScenarioRunStatus.RUNNING class TestScenarioRunServiceListRuns: """Tests for ScenarioRunService.list_runs.""" - async def test_list_runs_empty(self) -> None: - """Test that list_runs returns empty list initially.""" + def test_list_runs_empty(self, mock_memory) -> None: + """Test that list_runs returns empty list when DB has no results.""" + mock_memory.get_scenario_results.return_value = [] service = ScenarioRunService() result = service.list_runs() assert result.items == [] + mock_memory.get_scenario_results.assert_called_once_with(limit=100) - async def test_list_runs_returns_all_runs(self, mock_all_registries) -> None: - """Test that list_runs returns all tracked runs.""" - service = ScenarioRunService() - - await service.start_run_async(request=_make_request()) - await 
service.start_run_async(request=_make_request()) + def test_list_runs_returns_all_runs(self, mock_memory) -> None: + """Test that list_runs returns all runs from the database.""" + db_results = [ + _make_db_scenario_result(result_id="sr-1", run_state="COMPLETED"), + _make_db_scenario_result(result_id="sr-2", run_state="IN_PROGRESS"), + ] + mock_memory.get_scenario_results.return_value = db_results + service = ScenarioRunService() result = service.list_runs() assert len(result.items) == 2 + mock_memory.get_scenario_results.assert_called_once_with(limit=100) + + def test_list_runs_passes_custom_limit(self, mock_memory) -> None: + """Test that list_runs passes a custom limit to the memory query.""" + mock_memory.get_scenario_results.return_value = [] + service = ScenarioRunService() + service.list_runs(limit=10) + mock_memory.get_scenario_results.assert_called_once_with(limit=10) class TestScenarioRunServiceCancelRun: """Tests for ScenarioRunService.cancel_run_async.""" - async def test_cancel_run_returns_none_for_unknown_id(self) -> None: + async def test_cancel_run_returns_none_for_unknown_id(self, mock_memory) -> None: """Test that cancel returns None for non-existent run_id.""" + mock_memory.get_scenario_results.return_value = [] service = ScenarioRunService() result = await service.cancel_run_async(run_id="nonexistent-id") assert result is None async def test_cancel_run_sets_cancelled_status(self, mock_all_registries) -> None: - """Test that cancelling a running scenario sets CANCELLED status.""" + """Test that cancelling a running scenario persists CANCELLED to DB.""" service = ScenarioRunService() + mock_memory = mock_all_registries["memory"] response = await service.start_run_async(request=_make_request()) + # After update_scenario_run_state, the next DB query should return CANCELLED + running_result = mock_all_registries["db_result"] + cancelled_result = _make_db_scenario_result(result_id=response.run_id, run_state="CANCELLED") + mock_memory.get_scenario_results.side_effect = [[running_result], [cancelled_result]] + result = await service.cancel_run_async(run_id=response.run_id) + + mock_memory.update_scenario_run_state.assert_called_once_with( + scenario_result_id=response.run_id, scenario_run_state="CANCELLED" + ) assert result is not None assert result.status == ScenarioRunStatus.CANCELLED - async def test_cancel_completed_run_raises_value_error(self, mock_all_registries) -> None: + async def test_cancel_completed_run_raises_value_error(self, mock_memory) -> None: """Test that cancelling a completed run raises ValueError.""" + db_result = _make_db_scenario_result(result_id="sr-done", run_state="COMPLETED") + mock_memory.get_scenario_results.return_value = [db_result] + service = ScenarioRunService() - response = await service.start_run_async(request=_make_request()) + with pytest.raises(ValueError, match="Cannot cancel run"): + await service.cancel_run_async(run_id="sr-done") - # Manually set to COMPLETED - service._runs[response.run_id].status = ScenarioRunStatus.COMPLETED + async def test_cancel_already_cancelled_run_raises_value_error(self, mock_memory) -> None: + """Test that cancelling an already-cancelled run raises ValueError.""" + db_result = _make_db_scenario_result(result_id="sr-cancelled", run_state="CANCELLED") + mock_memory.get_scenario_results.return_value = [db_result] + service = ScenarioRunService() with pytest.raises(ValueError, match="Cannot cancel run"): - await service.cancel_run_async(run_id=response.run_id) + await 
service.cancel_run_async(run_id="sr-cancelled") class TestScenarioRunServiceExecution: """Tests for the background execution logic.""" - async def test_execute_run_completes_successfully(self) -> None: - """Test that a successful execution transitions to COMPLETED.""" + async def test_execute_run_completes_successfully(self, mock_all_registries) -> None: + """Test that a successful execution removes active task and DB reflects COMPLETED.""" service = ScenarioRunService() + mock_instance = mock_all_registries["scenario_instance"] + mock_memory = mock_all_registries["memory"] mock_scenario_result = MagicMock() - mock_scenario_result.id = "result-uuid" + mock_scenario_result.id = "sr-uuid-1" mock_scenario_result.scenario_run_state = "COMPLETED" mock_scenario_result.get_strategies_used.return_value = ["base64"] mock_scenario_result.attack_results = {"attack1": []} mock_scenario_result.number_tries = 1 - mock_scenario_result.completion_time = None + mock_scenario_result.completion_time = datetime(2025, 1, 1, tzinfo=timezone.utc) - mock_scenario_instance = MagicMock() - mock_scenario_instance.initialize_async = AsyncMock() - mock_scenario_instance.run_async = AsyncMock(return_value=mock_scenario_result) + mock_instance.run_async = AsyncMock(return_value=mock_scenario_result) - mock_scenario_class = MagicMock(return_value=mock_scenario_instance) - mock_scenario_class.get_strategy_class.return_value = MagicMock() - mock_scenario_class.default_dataset_config.return_value = MagicMock() - - mock_target = MagicMock() - - with ( - patch(f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton") as mock_sr, - patch(f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton") as mock_tr, - patch(f"{_REGISTRY_PATCH_BASE}.InitializerRegistry.get_registry_singleton"), - ): - mock_sr.return_value.get_class.return_value = mock_scenario_class - mock_tr.return_value.get_instance_by_name.return_value = mock_target - - response = await service.start_run_async(request=_make_request()) + response = await service.start_run_async(request=_make_request()) - # Wait for the background task to complete - task = service._runs[response.run_id].task - assert task is not None - await task + # Wait for the background task to complete + active = service._active_tasks.get(response.run_id) + assert active is not None + assert active.task is not None + await active.task - run = service.get_run(run_id=response.run_id) - assert run is not None - assert run.status == ScenarioRunStatus.COMPLETED - assert run.result is not None - assert run.result.scenario_result_id == "result-uuid" - assert run.result.strategies_used == ["base64"] + # Active task should be cleaned up after completion + assert response.run_id not in service._active_tasks - async def test_execute_run_fails_with_error(self) -> None: - """Test that a run_async failure transitions to FAILED with error message.""" + async def test_execute_run_fails_with_error(self, mock_all_registries) -> None: + """Test that a run_async failure stores error and removes active task.""" service = ScenarioRunService() + mock_instance = mock_all_registries["scenario_instance"] - mock_scenario_instance = MagicMock() - mock_scenario_instance.initialize_async = AsyncMock() - mock_scenario_instance.run_async = AsyncMock(side_effect=RuntimeError("scenario exploded")) - - mock_scenario_class = MagicMock(return_value=mock_scenario_instance) - mock_scenario_class.get_strategy_class.return_value = MagicMock() - mock_scenario_class.default_dataset_config.return_value = MagicMock() + 
mock_instance.run_async = AsyncMock(side_effect=RuntimeError("scenario exploded")) - with ( - patch(f"{_REGISTRY_PATCH_BASE}.ScenarioRegistry.get_registry_singleton") as mock_sr, - patch(f"{_REGISTRY_PATCH_BASE}.TargetRegistry.get_registry_singleton") as mock_tr, - patch(f"{_REGISTRY_PATCH_BASE}.InitializerRegistry.get_registry_singleton"), - ): - mock_sr.return_value.get_class.return_value = mock_scenario_class - mock_tr.return_value.get_instance_by_name.return_value = MagicMock() - - response = await service.start_run_async(request=_make_request()) + response = await service.start_run_async(request=_make_request()) - # Wait for the background task - task = service._runs[response.run_id].task - assert task is not None - await task + # Wait for the background task + active = service._active_tasks.get(response.run_id) + assert active is not None + assert active.task is not None + await active.task - run = service.get_run(run_id=response.run_id) - assert run is not None - assert run.status == ScenarioRunStatus.FAILED - assert run.error is not None - assert "scenario exploded" in run.error + # Active task should be cleaned up + assert response.run_id not in service._active_tasks class TestScenarioRunServiceGetResults: """Tests for ScenarioRunService.get_run_results.""" - def test_get_results_returns_none_for_unknown_id(self) -> None: + def test_get_results_returns_none_for_unknown_id(self, mock_memory) -> None: """Test that get_run_results returns None for non-existent run_id.""" + mock_memory.get_scenario_results.return_value = [] service = ScenarioRunService() result = service.get_run_results(run_id="nonexistent-id") assert result is None - async def test_get_results_raises_if_not_completed(self, mock_all_registries) -> None: + def test_get_results_raises_if_not_completed(self, mock_memory) -> None: """Test that get_run_results raises ValueError if run is not completed.""" - service = ScenarioRunService() - response = await service.start_run_async(request=_make_request()) + db_result = _make_db_scenario_result(result_id="sr-running", run_state="IN_PROGRESS") + mock_memory.get_scenario_results.return_value = [db_result] - # Run is in RUNNING state + service = ScenarioRunService() with pytest.raises(ValueError, match="only available for completed runs"): - service.get_run_results(run_id=response.run_id) + service.get_run_results(run_id="sr-running") - async def test_get_results_returns_details_for_completed_run(self, mock_all_registries) -> None: + def test_get_results_returns_details_for_completed_run(self, mock_memory) -> None: """Test that get_run_results returns full details for a completed run.""" - from pyrit.backend.models.scenarios import ScenarioRunResult from pyrit.models import AttackOutcome - service = ScenarioRunService() - response = await service.start_run_async(request=_make_request()) - - # Manually set run to completed with a result - info = service._runs[response.run_id] - info.status = ScenarioRunStatus.COMPLETED - info.result = ScenarioRunResult( - scenario_result_id="sr-123", - run_state="COMPLETED", - strategies_used=["base64"], - total_attacks=1, - completed_attacks=1, - number_tries=1, - completion_time=None, - ) - - # Mock CentralMemory and ScenarioResult mock_attack_result = MagicMock() mock_attack_result.attack_result_id = "ar-1" mock_attack_result.conversation_id = "conv-1" @@ -456,24 +461,16 @@ async def test_get_results_returns_details_for_completed_run(self, mock_all_regi mock_attack_result.execution_time_ms = 1500 mock_attack_result.timestamp = None - 
mock_scenario_result = MagicMock() - mock_scenario_result.id = "sr-123" - mock_scenario_result.scenario_identifier.name = "foundry.red_team_agent" - mock_scenario_result.scenario_identifier.version = 1 - mock_scenario_result.scenario_run_state = "COMPLETED" - mock_scenario_result.objective_achieved_rate.return_value = 100 - mock_scenario_result.number_tries = 1 - mock_scenario_result.completion_time = None - mock_scenario_result.labels = {} - mock_scenario_result.attack_results = {"base64_attack": [mock_attack_result]} - mock_scenario_result.get_display_groups.return_value = {"base64_attack": [mock_attack_result]} - mock_scenario_result._display_group_map = {} - - mock_memory = MagicMock() - mock_memory.get_scenario_results.return_value = [mock_scenario_result] + db_result = _make_db_scenario_result( + result_id="sr-123", + run_state="COMPLETED", + attack_results={"base64_attack": [mock_attack_result]}, + ) + db_result.objective_achieved_rate.return_value = 100 + mock_memory.get_scenario_results.return_value = [db_result] - with patch("pyrit.memory.CentralMemory.get_memory_instance", return_value=mock_memory): - detail = service.get_run_results(run_id=response.run_id) + service = ScenarioRunService() + detail = service.get_run_results(run_id="sr-123") assert detail is not None assert detail.scenario_result_id == "sr-123" From 15bdf3566300bba08da85cdb0d2c159abac18cfd Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Wed, 6 May 2026 17:42:47 -0700 Subject: [PATCH 04/10] self pr review --- .../backend/services/scenario_run_service.py | 24 +++++++------- pyrit/memory/memory_models.py | 1 + pyrit/models/scenario_result.py | 3 ++ .../unit/backend/test_scenario_run_routes.py | 8 ++--- .../unit/backend/test_scenario_run_service.py | 32 ++++++++++++------- 5 files changed, 41 insertions(+), 27 deletions(-) diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py index bca74d862f..85fe619fa4 100644 --- a/pyrit/backend/services/scenario_run_service.py +++ b/pyrit/backend/services/scenario_run_service.py @@ -85,7 +85,7 @@ async def start_run_async(self, *, request: RunScenarioRequest) -> ScenarioRunRe ValueError: If scenario, target, initializer, or strategy cannot be found, or concurrent limit exceeded. """ - if len(self._active_tasks) >= self._max_concurrent_runs: + if sum(1 for a in self._active_tasks.values() if a.task is not None and not a.task.done()) >= self._max_concurrent_runs: raise ValueError( f"Maximum concurrent runs ({self._max_concurrent_runs}) reached. " "Wait for an existing run to complete or cancel one." @@ -138,10 +138,7 @@ def list_runs(self, *, limit: int = 100) -> ScenarioRunListResponse: # This is expensive, and we don't need all the data. 
At some point
         # we may want to add a lightweight "list" query to the DB layer that only fetches the summary fields needed here.
         results = memory.get_scenario_results(limit=limit)
-        items = [
-            self._build_response_from_db(scenario_result=sr)
-            for sr in results
-        ]
+        items = [self._build_response_from_db(scenario_result=sr) for sr in results]
         return ScenarioRunListResponse(items=items)

     async def cancel_run_async(self, *, run_id: str) -> ScenarioRunResponse | None:
@@ -302,9 +299,6 @@ async def _execute_run_async(self, *, scenario_result_id: str) -> None:
             active.error = str(e)
             logger.exception(f"Scenario run {scenario_result_id} failed: {e}")

-        finally:
-            del self._active_tasks[scenario_result_id]
-
     def _build_response(self, *, scenario_result_id: str) -> ScenarioRunResponse | None:
         """
         Build a ScenarioRunResponse by querying the database and merging active task state.
@@ -334,6 +328,13 @@ def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> Scenari
         scenario_result_id = str(scenario_result.id)
         active = self._active_tasks.get(scenario_result_id)

+        # Clean up finished active tasks after reading the error
+        error = None
+        if active is not None:
+            error = active.error
+            if active.task is not None and active.task.done():
+                del self._active_tasks[scenario_result_id]
+
         status = _STATE_TO_STATUS.get(scenario_result.scenario_run_state, ScenarioRunStatus.FAILED)

         # Build result summary for completed runs
@@ -345,23 +346,22 @@ def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> Scenari
                 for ar in results
                 if ar.outcome in (AttackOutcome.SUCCESS, AttackOutcome.FAILURE)
             )
+            total_attacks = sum(len(results) for results in scenario_result.attack_results.values())

             result = ScenarioRunResult(
                 scenario_result_id=scenario_result_id,
                 run_state=scenario_result.scenario_run_state,
                 strategies_used=scenario_result.get_strategies_used(),
-                total_attacks=len(scenario_result.attack_results),
+                total_attacks=total_attacks,
                 completed_attacks=completed_attacks,
                 number_tries=scenario_result.number_tries,
                 completion_time=scenario_result.completion_time,
             )

-        error = active.error if active else None
-
         return ScenarioRunResponse(
             run_id=scenario_result_id,
             scenario_name=scenario_result.scenario_identifier.name,
             status=status,
-            created_at=scenario_result.completion_time,
+            created_at=scenario_result.created_at,
             updated_at=scenario_result.completion_time,
             error=error,
             result=result,
diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py
index ce54730677..e9b6b6659e 100644
--- a/pyrit/memory/memory_models.py
+++ b/pyrit/memory/memory_models.py
@@ -1053,6 +1053,7 @@ def get_scenario_result(self) -> ScenarioResult:
             objective_scorer_identifier=scorer_identifier,  # type: ignore[ty:invalid-argument-type]
             scenario_run_state=self.scenario_run_state,
             labels=self.labels,
+            created_at=self.timestamp,
             number_tries=self.number_tries,
             completion_time=self.completion_time,
             display_group_map=display_group_map,
diff --git a/pyrit/models/scenario_result.py b/pyrit/models/scenario_result.py
index a0eca690cf..7f237e8c9f 100644
--- a/pyrit/models/scenario_result.py
+++ b/pyrit/models/scenario_result.py
@@ -64,6 +64,7 @@ def __init__(
         objective_scorer_identifier: "ComponentIdentifier",
         scenario_run_state: ScenarioRunState = "CREATED",
         labels: Optional[dict[str, str]] = None,
+        created_at: Optional[datetime] = None,
         completion_time: Optional[datetime] = None,
         number_tries: int = 0,
         id: Optional[uuid.UUID] = None,  # noqa: A002
@@ -79,6 +80,7 @@ def __init__(
             objective_scorer_identifier (ComponentIdentifier): Objective scorer 
identifier. scenario_run_state (ScenarioRunState): Current scenario run state. labels (Optional[dict[str, str]]): Optional labels. + created_at (Optional[datetime]): When the scenario result was created. completion_time (Optional[datetime]): Optional completion timestamp. number_tries (int): Number of run attempts. id (Optional[uuid.UUID]): Optional scenario result ID. @@ -97,6 +99,7 @@ def __init__( self.scenario_run_state = scenario_run_state self.attack_results = attack_results self.labels = labels if labels is not None else {} + self.created_at = created_at if created_at is not None else datetime.now(timezone.utc) self.completion_time = completion_time if completion_time is not None else datetime.now(timezone.utc) self.number_tries = number_tries self._display_group_map = display_group_map or {} diff --git a/tests/unit/backend/test_scenario_run_routes.py b/tests/unit/backend/test_scenario_run_routes.py index b2340fdcdd..9099214e58 100644 --- a/tests/unit/backend/test_scenario_run_routes.py +++ b/tests/unit/backend/test_scenario_run_routes.py @@ -18,7 +18,7 @@ ScenarioRunResponse, ScenarioRunStatus, ) -from pyrit.backend.services.scenario_run_service import get_scenario_run_service +import pyrit.backend.services.scenario_run_service as _svc_mod @pytest.fixture @@ -29,10 +29,10 @@ def client() -> TestClient: @pytest.fixture(autouse=True) def clear_service_cache(): - """Clear the service singleton cache between tests.""" - get_scenario_run_service.cache_clear() + """Clear the service singleton between tests.""" + _svc_mod._service_instance = None yield - get_scenario_run_service.cache_clear() + _svc_mod._service_instance = None def _mock_run_response( diff --git a/tests/unit/backend/test_scenario_run_service.py b/tests/unit/backend/test_scenario_run_service.py index 1725fb3755..d1e653332d 100644 --- a/tests/unit/backend/test_scenario_run_service.py +++ b/tests/unit/backend/test_scenario_run_service.py @@ -17,8 +17,8 @@ from pyrit.backend.services.scenario_run_service import ( _DEFAULT_MAX_CONCURRENT_RUNS, ScenarioRunService, - get_scenario_run_service, ) +import pyrit.backend.services.scenario_run_service as _svc_mod _REGISTRY_PATCH_BASE = "pyrit.registry" _MEMORY_PATCH = "pyrit.memory.CentralMemory.get_memory_instance" @@ -27,11 +27,9 @@ @pytest.fixture(autouse=True) def clear_service_cache(): """Clear the singleton instance between tests.""" - import pyrit.backend.services.scenario_run_service as svc_mod - - svc_mod._service_instance = None + _svc_mod._service_instance = None yield - svc_mod._service_instance = None + _svc_mod._service_instance = None def _make_request( @@ -68,11 +66,12 @@ def _make_db_scenario_result( sr.get_strategies_used.return_value = [] sr.attack_results = attack_results or {} sr.number_tries = 1 - sr.completion_time = datetime(2025, 1, 1, tzinfo=timezone.utc) + sr.created_at = datetime(2025, 1, 1, tzinfo=timezone.utc) + sr.completion_time = datetime(2025, 1, 1, 0, 5, tzinfo=timezone.utc) sr.labels = {} sr.objective_achieved_rate.return_value = 0 sr.get_display_groups.return_value = {} - sr._display_group_map = {} + sr.display_group_map = {} return sr @@ -391,7 +390,8 @@ async def test_execute_run_completes_successfully(self, mock_all_registries) -> mock_scenario_result.get_strategies_used.return_value = ["base64"] mock_scenario_result.attack_results = {"attack1": []} mock_scenario_result.number_tries = 1 - mock_scenario_result.completion_time = datetime(2025, 1, 1, tzinfo=timezone.utc) + mock_scenario_result.created_at = datetime(2025, 1, 1, tzinfo=timezone.utc) 
+ mock_scenario_result.completion_time = datetime(2025, 1, 1, 0, 5, tzinfo=timezone.utc) mock_instance.run_async = AsyncMock(return_value=mock_scenario_result) @@ -403,11 +403,14 @@ async def test_execute_run_completes_successfully(self, mock_all_registries) -> assert active.task is not None await active.task - # Active task should be cleaned up after completion + # Active task is cleaned up on next get_run (deferred cleanup) + assert response.run_id in service._active_tasks + fetched = service.get_run(run_id=response.run_id) + assert fetched is not None assert response.run_id not in service._active_tasks async def test_execute_run_fails_with_error(self, mock_all_registries) -> None: - """Test that a run_async failure stores error and removes active task.""" + """Test that a run_async failure stores error and surfaces it via get_run.""" service = ScenarioRunService() mock_instance = mock_all_registries["scenario_instance"] @@ -421,7 +424,14 @@ async def test_execute_run_fails_with_error(self, mock_all_registries) -> None: assert active.task is not None await active.task - # Active task should be cleaned up + # Error is stored on the active task until get_run reads it + assert active.error == "scenario exploded" + assert response.run_id in service._active_tasks + + # get_run should surface the error and clean up + fetched = service.get_run(run_id=response.run_id) + assert fetched is not None + assert fetched.error == "scenario exploded" assert response.run_id not in service._active_tasks From 33f36b290fee554e334398f2ba0e207b936f6207 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 7 May 2026 09:39:33 -0700 Subject: [PATCH 05/10] pre-commit --- pyrit/backend/models/attacks.py | 15 +++++++++++++++ pyrit/backend/models/scenarios.py | 16 +--------------- .../backend/services/scenario_run_service.py | 19 +++++++++++++------ .../unit/backend/test_scenario_run_routes.py | 11 ++++------- .../unit/backend/test_scenario_run_service.py | 2 +- tests/unit/backend/test_scenario_service.py | 8 -------- 6 files changed, 34 insertions(+), 37 deletions(-) diff --git a/pyrit/backend/models/attacks.py b/pyrit/backend/models/attacks.py index 95b98a8f49..24f2855318 100644 --- a/pyrit/backend/models/attacks.py +++ b/pyrit/backend/models/attacks.py @@ -77,6 +77,21 @@ class Message(BaseModel): created_at: datetime = Field(..., description="Message creation timestamp") +class AttackResultDetail(BaseModel): + """Detailed result of a single attack within a scenario.""" + + attack_result_id: str = Field(..., description="Unique ID of this attack result") + conversation_id: str = Field(..., description="Conversation ID that produced this result") + objective: str = Field(..., description="Natural-language description of the attacker's objective") + outcome: str = Field(..., description="Attack outcome: success, failure, or undetermined") + outcome_reason: str | None = Field(None, description="Reason for the outcome") + last_response: str | None = Field(None, description="Model response from the final turn") + score_value: str | None = Field(None, description="Score value from the objective scorer") + executed_turns: int = Field(0, ge=0, description="Number of turns executed") + execution_time_ms: int = Field(0, ge=0, description="Execution time in milliseconds") + timestamp: datetime | None = Field(None, description="When the result was created") + + # ============================================================================ # Attack Summary (List View) # 
============================================================================ diff --git a/pyrit/backend/models/scenarios.py b/pyrit/backend/models/scenarios.py index 494972ddfb..5b36c0f6a0 100644 --- a/pyrit/backend/models/scenarios.py +++ b/pyrit/backend/models/scenarios.py @@ -14,6 +14,7 @@ from pydantic import BaseModel, Field +from pyrit.backend.models.attacks import AttackResultDetail from pyrit.backend.models.common import PaginationInfo @@ -122,21 +123,6 @@ class ScenarioRunListResponse(BaseModel): # ============================================================================ -class AttackResultDetail(BaseModel): - """Detailed result of a single attack within a scenario.""" - - attack_result_id: str = Field(..., description="Unique ID of this attack result") - conversation_id: str = Field(..., description="Conversation ID that produced this result") - objective: str = Field(..., description="Natural-language description of the attacker's objective") - outcome: str = Field(..., description="Attack outcome: success, failure, or undetermined") - outcome_reason: str | None = Field(None, description="Reason for the outcome") - last_response: str | None = Field(None, description="Model response from the final turn") - score_value: str | None = Field(None, description="Score value from the objective scorer") - executed_turns: int = Field(0, ge=0, description="Number of turns executed") - execution_time_ms: int = Field(0, ge=0, description="Execution time in milliseconds") - timestamp: datetime | None = Field(None, description="When the result was created") - - class AtomicAttackResults(BaseModel): """Results grouped by atomic attack name.""" diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py index 85fe619fa4..52554d7120 100644 --- a/pyrit/backend/services/scenario_run_service.py +++ b/pyrit/backend/services/scenario_run_service.py @@ -34,7 +34,7 @@ _DEFAULT_MAX_CONCURRENT_RUNS = 3 # Maps DB ScenarioRunState values to API ScenarioRunStatus -_STATE_TO_STATUS: dict[str, ScenarioRunStatus] = { +_STATE_TO_STATUS = { "CREATED": ScenarioRunStatus.INITIALIZING, "IN_PROGRESS": ScenarioRunStatus.RUNNING, "COMPLETED": ScenarioRunStatus.COMPLETED, @@ -85,7 +85,10 @@ async def start_run_async(self, *, request: RunScenarioRequest) -> ScenarioRunRe ValueError: If scenario, target, initializer, or strategy cannot be found, or concurrent limit exceeded. """ - if sum(1 for a in self._active_tasks.values() if a.task is not None and not a.task.done()) >= self._max_concurrent_runs: + if ( + sum(1 for a in self._active_tasks.values() if a.task is not None and not a.task.done()) + >= self._max_concurrent_runs + ): raise ValueError( f"Maximum concurrent runs ({self._max_concurrent_runs}) reached. " "Wait for an existing run to complete or cancel one." @@ -168,9 +171,8 @@ async def cancel_run_async(self, *, run_id: str) -> ScenarioRunResponse | None: # Cancel the asyncio task if active active = self._active_tasks.get(run_id) - if active is not None: - if active.task is not None and not active.task.done(): - active.task.cancel() + if active is not None and active.task is not None and not active.task.done(): + active.task.cancel() # Persist cancelled state to DB memory.update_scenario_run_state(scenario_result_id=run_id, scenario_run_state="CANCELLED") @@ -281,7 +283,12 @@ async def _execute_run_async(self, *, scenario_result_id: str) -> None: Execute a scenario run (background task entry point). 
Only calls scenario.run_async on the already-initialized scenario. - Removes the task from _active_tasks when done. + + Note: this method intentionally does NOT remove the entry from + ``_active_tasks`` on completion. The entry must stay so that + ``_build_response_from_db`` can read ``active.error`` when the + caller next polls the run status. Cleanup happens lazily there + once the error has been surfaced. Args: scenario_result_id: The scenario result ID for this run. diff --git a/tests/unit/backend/test_scenario_run_routes.py b/tests/unit/backend/test_scenario_run_routes.py index 9099214e58..bf6664c7a5 100644 --- a/tests/unit/backend/test_scenario_run_routes.py +++ b/tests/unit/backend/test_scenario_run_routes.py @@ -12,13 +12,16 @@ from fastapi import status from fastapi.testclient import TestClient +import pyrit.backend.services.scenario_run_service as _svc_mod from pyrit.backend.main import app +from pyrit.backend.models.attacks import AttackResultDetail from pyrit.backend.models.scenarios import ( + AtomicAttackResults, + ScenarioResultDetailResponse, ScenarioRunListResponse, ScenarioRunResponse, ScenarioRunStatus, ) -import pyrit.backend.services.scenario_run_service as _svc_mod @pytest.fixture @@ -232,12 +235,6 @@ class TestGetScenarioRunResultsRoute: def test_get_results_returns_200(self, client: TestClient) -> None: """Test that getting results of a completed run returns 200.""" - from pyrit.backend.models.scenarios import ( - AtomicAttackResults, - AttackResultDetail, - ScenarioResultDetailResponse, - ) - mock_result = ScenarioResultDetailResponse( scenario_result_id="result-uuid", scenario_name="foundry.red_team_agent", diff --git a/tests/unit/backend/test_scenario_run_service.py b/tests/unit/backend/test_scenario_run_service.py index d1e653332d..b7c852f30e 100644 --- a/tests/unit/backend/test_scenario_run_service.py +++ b/tests/unit/backend/test_scenario_run_service.py @@ -10,6 +10,7 @@ import pytest +import pyrit.backend.services.scenario_run_service as _svc_mod from pyrit.backend.models.scenarios import ( RunScenarioRequest, ScenarioRunStatus, @@ -18,7 +19,6 @@ _DEFAULT_MAX_CONCURRENT_RUNS, ScenarioRunService, ) -import pyrit.backend.services.scenario_run_service as _svc_mod _REGISTRY_PATCH_BASE = "pyrit.registry" _MEMORY_PATCH = "pyrit.memory.CentralMemory.get_memory_instance" diff --git a/tests/unit/backend/test_scenario_service.py b/tests/unit/backend/test_scenario_service.py index 1a56086c25..04d61e622a 100644 --- a/tests/unit/backend/test_scenario_service.py +++ b/tests/unit/backend/test_scenario_service.py @@ -65,7 +65,6 @@ def _make_scenario_metadata( class TestScenarioServiceListScenarios: """Tests for ScenarioService.list_scenarios_async.""" - @pytest.mark.asyncio async def test_list_scenarios_returns_empty_when_no_scenarios(self) -> None: """Test that list returns empty list when no scenarios are registered.""" with patch.object(ScenarioService, "__init__", lambda self: None): @@ -78,7 +77,6 @@ async def test_list_scenarios_returns_empty_when_no_scenarios(self) -> None: assert result.items == [] assert result.pagination.has_more is False - @pytest.mark.asyncio async def test_list_scenarios_returns_scenarios_from_registry(self) -> None: """Test that list returns scenarios from registry.""" metadata = _make_scenario_metadata() @@ -100,7 +98,6 @@ async def test_list_scenarios_returns_scenarios_from_registry(self) -> None: assert result.items[0].default_datasets == ["test_dataset"] assert result.items[0].max_dataset_size is None - @pytest.mark.asyncio async def 
test_list_scenarios_paginates_with_limit(self) -> None: """Test that list respects the limit parameter.""" metadata_list = [ @@ -118,7 +115,6 @@ async def test_list_scenarios_paginates_with_limit(self) -> None: assert result.pagination.has_more is True assert result.pagination.next_cursor == "test.scenario_2" - @pytest.mark.asyncio async def test_list_scenarios_paginates_with_cursor(self) -> None: """Test that list uses cursor for pagination.""" metadata_list = [ @@ -137,7 +133,6 @@ async def test_list_scenarios_paginates_with_cursor(self) -> None: assert result.items[1].scenario_name == "test.scenario_3" assert result.pagination.has_more is True - @pytest.mark.asyncio async def test_list_scenarios_last_page_has_more_false(self) -> None: """Test that last page shows has_more=False.""" metadata_list = [ @@ -155,7 +150,6 @@ async def test_list_scenarios_last_page_has_more_false(self) -> None: assert result.pagination.has_more is False assert result.pagination.next_cursor is None - @pytest.mark.asyncio async def test_list_scenarios_includes_max_dataset_size(self) -> None: """Test that max_dataset_size is included in response.""" metadata = _make_scenario_metadata(max_dataset_size=10) @@ -173,7 +167,6 @@ async def test_list_scenarios_includes_max_dataset_size(self) -> None: class TestScenarioServiceGetScenario: """Tests for ScenarioService.get_scenario_async.""" - @pytest.mark.asyncio async def test_get_scenario_returns_matching_scenario(self) -> None: """Test that get returns the matching scenario.""" metadata = _make_scenario_metadata(registry_name="foundry.red_team_agent") @@ -188,7 +181,6 @@ async def test_get_scenario_returns_matching_scenario(self) -> None: assert result is not None assert result.scenario_name == "foundry.red_team_agent" - @pytest.mark.asyncio async def test_get_scenario_returns_none_for_missing(self) -> None: """Test that get returns None when scenario not found.""" with patch.object(ScenarioService, "__init__", lambda self: None): From f70a3b556ce44ac825d88b04c5fa6b08d4bc2d25 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 7 May 2026 09:59:03 -0700 Subject: [PATCH 06/10] test fix --- pyrit/backend/models/scenarios.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyrit/backend/models/scenarios.py b/pyrit/backend/models/scenarios.py index 5b36c0f6a0..48e4392a93 100644 --- a/pyrit/backend/models/scenarios.py +++ b/pyrit/backend/models/scenarios.py @@ -9,7 +9,7 @@ """ from datetime import datetime -from enum import StrEnum +from enum import Enum from typing import Any, Optional from pydantic import BaseModel, Field @@ -45,7 +45,7 @@ class ScenarioListResponse(BaseModel): # ============================================================================ -class ScenarioRunStatus(StrEnum): +class ScenarioRunStatus(str, Enum): """Status of a scenario run.""" PENDING = "pending" From 492961c07169f85886a5ae09fe132a981b2fc289 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 7 May 2026 12:07:14 -0700 Subject: [PATCH 07/10] pr feedback --- pyrit/backend/models/__init__.py | 8 +-- pyrit/backend/models/scenarios.py | 39 ++++-------- pyrit/backend/routes/scenarios.py | 36 +++++------ .../backend/services/scenario_run_service.py | 60 ++++++++----------- pyrit/backend/services/scenario_service.py | 16 ++--- .../unit/backend/test_scenario_run_routes.py | 38 ++++++------ .../unit/backend/test_scenario_run_service.py | 28 ++++----- 7 files changed, 103 insertions(+), 122 deletions(-) diff --git a/pyrit/backend/models/__init__.py 
b/pyrit/backend/models/__init__.py index d606d89eb0..aeebe087fd 100644 --- a/pyrit/backend/models/__init__.py +++ b/pyrit/backend/models/__init__.py @@ -48,8 +48,8 @@ PreviewStep, ) from pyrit.backend.models.scenarios import ( - ScenarioListResponse, - ScenarioSummary, + ListRegisteredScenarioResponse, + RegisteredScenario, ) from pyrit.backend.models.targets import ( CreateTargetRequest, @@ -96,8 +96,8 @@ "CreateConverterResponse", "PreviewStep", # Scenarios - "ScenarioListResponse", - "ScenarioSummary", + "ListRegisteredScenarioResponse", + "RegisteredScenario", # Targets "CreateTargetRequest", "TargetInstance", diff --git a/pyrit/backend/models/scenarios.py b/pyrit/backend/models/scenarios.py index 48e4392a93..d76cb66111 100644 --- a/pyrit/backend/models/scenarios.py +++ b/pyrit/backend/models/scenarios.py @@ -18,7 +18,7 @@ from pyrit.backend.models.common import PaginationInfo -class ScenarioSummary(BaseModel): +class RegisteredScenario(BaseModel): """Summary of a registered scenario.""" scenario_name: str = Field(..., description="Registry key (e.g., 'foundry.red_team_agent')") @@ -33,10 +33,10 @@ class ScenarioSummary(BaseModel): max_dataset_size: Optional[int] = Field(None, description="Maximum items per dataset (None means unlimited)") -class ScenarioListResponse(BaseModel): +class ListRegisteredScenarioResponse(BaseModel): """Response for listing scenarios.""" - items: list[ScenarioSummary] = Field(..., description="List of scenario summaries") + items: list[RegisteredScenario] = Field(..., description="List of scenario summaries") pagination: PaginationInfo = Field(..., description="Pagination metadata") @@ -69,7 +69,7 @@ class RunScenarioRequest(BaseModel): max_dataset_size: int | None = Field(None, ge=1, description="Maximum items per dataset") max_concurrency: int = Field(10, ge=1, le=100, description="Maximum concurrent operations") max_retries: int = Field(0, ge=0, le=20, description="Maximum retry attempts on failure") - memory_labels: dict[str, str] | None = Field(None, description="Labels to attach to memory entries") + labels: dict[str, str] | None = Field(None, description="Labels to attach to memory entries") scenario_params: dict[str, Any] | None = Field( None, description="Custom parameters for the scenario (passed to scenario.set_params_from_args). 
" @@ -88,34 +88,25 @@ class RunScenarioRequest(BaseModel): ) -class ScenarioRunResult(BaseModel): - """Summary of a completed scenario run's results.""" +class ScenarioRunSummary(BaseModel): + """Response for a scenario run (status + result details).""" scenario_result_id: str = Field(..., description="UUID of the ScenarioResult in memory") - run_state: str = Field(..., description="Final scenario run state (COMPLETED, FAILED)") - strategies_used: list[str] = Field(..., description="Strategy names that were executed") - total_attacks: int = Field(..., ge=0, description="Total number of atomic attacks") - completed_attacks: int = Field(..., ge=0, description="Number of attacks that completed") - number_tries: int = Field(..., ge=0, description="Number of execution attempts") - completion_time: datetime | None = Field(None, description="When the scenario finished") - - -class ScenarioRunResponse(BaseModel): - """Response for a scenario run (status + optional result).""" - - run_id: str = Field(..., description="Unique identifier for this run") scenario_name: str = Field(..., description="Registry key of the scenario being run") status: ScenarioRunStatus = Field(..., description="Current run status") created_at: datetime = Field(..., description="When the run was created") updated_at: datetime = Field(..., description="When the run status last changed") error: str | None = Field(None, description="Error message if status is FAILED") - result: ScenarioRunResult | None = Field(None, description="Result details if status is COMPLETED") + strategies_used: list[str] = Field(default_factory=list, description="Strategy names that were executed") + total_attacks: int = Field(0, ge=0, description="Total number of atomic attacks") + completed_attacks: int = Field(0, ge=0, description="Number of attacks that completed") + completed_at: datetime | None = Field(None, description="When the scenario finished") class ScenarioRunListResponse(BaseModel): """Response for listing scenario runs.""" - items: list[ScenarioRunResponse] = Field(..., description="List of scenario runs") + items: list[ScenarioRunSummary] = Field(..., description="List of scenario runs") # ============================================================================ @@ -134,15 +125,11 @@ class AtomicAttackResults(BaseModel): total_count: int = Field(0, ge=0, description="Total number of attack results") -class ScenarioResultDetailResponse(BaseModel): +class ScenarioRunDetail(BaseModel): """Full detailed results of a scenario run.""" - scenario_result_id: str = Field(..., description="UUID of the ScenarioResult") - scenario_name: str = Field(..., description="Name of the scenario") + run: ScenarioRunSummary = Field(..., description="The scenario run summary") scenario_version: int = Field(..., description="Version of the scenario") - run_state: str = Field(..., description="Final run state (COMPLETED, FAILED, etc.)") objective_achieved_rate: int = Field(..., ge=0, le=100, description="Success rate as percentage (0-100)") - number_tries: int = Field(..., ge=0, description="Number of execution attempts") - completion_time: datetime | None = Field(None, description="When the scenario finished") labels: dict[str, str] = Field(default_factory=dict, description="Labels attached to this run") attacks: list[AtomicAttackResults] = Field(..., description="Results grouped by atomic attack") diff --git a/pyrit/backend/routes/scenarios.py b/pyrit/backend/routes/scenarios.py index 77f73a38c5..ab825b9f9f 100644 --- a/pyrit/backend/routes/scenarios.py 
+++ b/pyrit/backend/routes/scenarios.py @@ -18,12 +18,12 @@ from pyrit.backend.models.common import ProblemDetail from pyrit.backend.models.scenarios import ( + ListRegisteredScenarioResponse, + RegisteredScenario, RunScenarioRequest, - ScenarioListResponse, - ScenarioResultDetailResponse, + ScenarioRunDetail, ScenarioRunListResponse, - ScenarioRunResponse, - ScenarioSummary, + ScenarioRunSummary, ) from pyrit.backend.services.scenario_run_service import get_scenario_run_service from pyrit.backend.services.scenario_service import get_scenario_service @@ -38,12 +38,12 @@ @router.get( "/catalog", - response_model=ScenarioListResponse, + response_model=ListRegisteredScenarioResponse, ) async def list_scenarios( limit: int = Query(50, ge=1, le=200, description="Maximum items per page"), cursor: Optional[str] = Query(None, description="Pagination cursor (scenario_name to start after)"), -) -> ScenarioListResponse: +) -> ListRegisteredScenarioResponse: """ List all available scenarios. @@ -59,12 +59,12 @@ async def list_scenarios( @router.get( "/catalog/{scenario_name:path}", - response_model=ScenarioSummary, + response_model=RegisteredScenario, responses={ 404: {"model": ProblemDetail, "description": "Scenario not found"}, }, ) -async def get_scenario(scenario_name: str) -> ScenarioSummary: +async def get_scenario(scenario_name: str) -> RegisteredScenario: """ Get details for a specific scenario. @@ -93,13 +93,13 @@ async def get_scenario(scenario_name: str) -> ScenarioSummary: @router.post( "/runs", - response_model=ScenarioRunResponse, + response_model=ScenarioRunSummary, status_code=status.HTTP_202_ACCEPTED, responses={ 400: {"model": ProblemDetail, "description": "Invalid request (bad scenario/target/strategy)"}, }, ) -async def start_scenario_run(request: RunScenarioRequest) -> ScenarioRunResponse: +async def start_scenario_run(request: RunScenarioRequest) -> ScenarioRunSummary: """ Start a new scenario run as a background task. @@ -138,12 +138,12 @@ async def list_scenario_runs(limit: int = Query(100, ge=1)) -> ScenarioRunListRe @router.get( "/runs/{run_id}", - response_model=ScenarioRunResponse, + response_model=ScenarioRunSummary, responses={ 404: {"model": ProblemDetail, "description": "Run not found"}, }, ) -async def get_scenario_run(run_id: str) -> ScenarioRunResponse: +async def get_scenario_run(run_id: str) -> ScenarioRunSummary: """ Get the current status and result of a scenario run. @@ -163,15 +163,15 @@ async def get_scenario_run(run_id: str) -> ScenarioRunResponse: return run -@router.delete( - "/runs/{run_id}", - response_model=ScenarioRunResponse, +@router.post( + "/runs/{run_id}/cancel", + response_model=ScenarioRunSummary, responses={ 404: {"model": ProblemDetail, "description": "Run not found"}, 409: {"model": ProblemDetail, "description": "Run already in terminal state"}, }, ) -async def cancel_scenario_run(run_id: str) -> ScenarioRunResponse: +async def cancel_scenario_run(run_id: str) -> ScenarioRunSummary: """ Cancel a running scenario. 
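The cancel route change above turns DELETE /runs/{run_id} into an explicit POST /runs/{run_id}/cancel action. A minimal client sketch of the resulting flow, assuming a locally running backend (the base URL, target name, and printed fields are illustrative, not part of this patch):

import requests  # hypothetical client code, not part of the patch

BASE = "http://localhost:8000/api/scenarios"

# Start a run; the 202 response body carries the id used for polling/cancelling.
run = requests.post(
    f"{BASE}/runs",
    json={"scenario_name": "foundry.red_team_agent", "target_name": "my_target"},
).json()
run_id = run["scenario_result_id"]

# Cancel via the new action endpoint (previously DELETE /runs/{run_id}).
resp = requests.post(f"{BASE}/runs/{run_id}/cancel")
print(resp.status_code, resp.json().get("status"))  # 200 and "cancelled" on success
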
@@ -197,13 +197,13 @@ async def cancel_scenario_run(run_id: str) -> ScenarioRunResponse: @router.get( "/runs/{run_id}/results", - response_model=ScenarioResultDetailResponse, + response_model=ScenarioRunDetail, responses={ 404: {"model": ProblemDetail, "description": "Run not found"}, 409: {"model": ProblemDetail, "description": "Run not yet completed"}, }, ) -async def get_scenario_run_results(run_id: str) -> ScenarioResultDetailResponse: +async def get_scenario_run_results(run_id: str) -> ScenarioRunDetail: """ Get detailed results for a completed scenario run. diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py index 52554d7120..daff5702ea 100644 --- a/pyrit/backend/services/scenario_run_service.py +++ b/pyrit/backend/services/scenario_run_service.py @@ -17,11 +17,10 @@ AtomicAttackResults, AttackResultDetail, RunScenarioRequest, - ScenarioResultDetailResponse, + ScenarioRunDetail, ScenarioRunListResponse, - ScenarioRunResponse, - ScenarioRunResult, ScenarioRunStatus, + ScenarioRunSummary, ) from pyrit.memory import CentralMemory from pyrit.models import AttackOutcome, ScenarioResult @@ -66,7 +65,7 @@ def __init__(self, *, max_concurrent_runs: int = _DEFAULT_MAX_CONCURRENT_RUNS) - self._max_concurrent_runs = max_concurrent_runs self._active_tasks: dict[str, _ActiveTask] = {} - async def start_run_async(self, *, request: RunScenarioRequest) -> ScenarioRunResponse: + async def start_run_async(self, *, request: RunScenarioRequest) -> ScenarioRunSummary: """ Start a new scenario run as a background task. @@ -114,7 +113,7 @@ async def start_run_async(self, *, request: RunScenarioRequest) -> ScenarioRunRe assert response is not None # guaranteed: we just inserted into DB via initialize_async return response - def get_run(self, *, run_id: str) -> ScenarioRunResponse | None: + def get_run(self, *, run_id: str) -> ScenarioRunSummary | None: """ Get the current status of a scenario run by querying the database. @@ -144,7 +143,7 @@ def list_runs(self, *, limit: int = 100) -> ScenarioRunListResponse: items = [self._build_response_from_db(scenario_result=sr) for sr in results] return ScenarioRunListResponse(items=items) - async def cancel_run_async(self, *, run_id: str) -> ScenarioRunResponse | None: + async def cancel_run_async(self, *, run_id: str) -> ScenarioRunSummary | None: """ Cancel a running scenario. @@ -240,8 +239,8 @@ async def _initialize_run_async(self, *, request: RunScenarioRequest) -> Scenari "max_retries": request.max_retries, } - if request.memory_labels: - init_kwargs["memory_labels"] = request.memory_labels + if request.labels: + init_kwargs["memory_labels"] = request.labels # Validate and resolve strategies if request.strategies: @@ -306,7 +305,7 @@ async def _execute_run_async(self, *, scenario_result_id: str) -> None: active.error = str(e) logger.exception(f"Scenario run {scenario_result_id} failed: {e}") - def _build_response(self, *, scenario_result_id: str) -> ScenarioRunResponse | None: + def _build_response(self, *, scenario_result_id: str) -> ScenarioRunSummary | None: """ Build a ScenarioRunResponse by querying the database and merging active task state. 
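The docstring above covers the merge of database state with in-memory task state. The deferred-cleanup contract it relies on (an error parked on the active task survives until a status poll has surfaced it) reduces to a standalone sketch; names are simplified and this is not the service's actual implementation:

# Simplified sketch of the deferred-cleanup contract; illustrative only.
from dataclasses import dataclass


@dataclass
class _ActiveTask:
    error: str | None = None
    done: bool = False


class _Runs:
    def __init__(self) -> None:
        self._active_tasks: dict[str, _ActiveTask] = {}

    def get_run(self, run_id: str) -> dict:
        active = self._active_tasks.get(run_id)
        error = active.error if active else None
        if active is not None and active.done:
            # Cleanup happens only now, after the error has been surfaced.
            del self._active_tasks[run_id]
        return {"scenario_result_id": run_id, "error": error}


runs = _Runs()
runs._active_tasks["sr-1"] = _ActiveTask(error="scenario exploded", done=True)
assert runs.get_run("sr-1")["error"] == "scenario exploded"
assert "sr-1" not in runs._active_tasks
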
@@ -322,7 +321,7 @@ def _build_response(self, *, scenario_result_id: str) -> ScenarioRunResponse | N return None return self._build_response_from_db(scenario_result=results[0]) - def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> ScenarioRunResponse: + def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> ScenarioRunSummary: """ Build a ScenarioRunResponse from a database ScenarioResult, merged with active task info. @@ -344,8 +343,10 @@ def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> Scenari status = _STATE_TO_STATUS.get(scenario_result.scenario_run_state, ScenarioRunStatus.FAILED) - # Build result summary for completed runs - result = None + # Build result fields for completed runs + strategies_used: list[str] = [] + total_attacks = 0 + completed_attacks = 0 if status == ScenarioRunStatus.COMPLETED: completed_attacks = sum( 1 @@ -354,27 +355,22 @@ def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> Scenari if ar.outcome in (AttackOutcome.SUCCESS, AttackOutcome.FAILURE) ) total_attacks = sum(len(results) for results in scenario_result.attack_results.values()) - result = ScenarioRunResult( - scenario_result_id=scenario_result_id, - run_state=scenario_result.scenario_run_state, - strategies_used=scenario_result.get_strategies_used(), - total_attacks=total_attacks, - completed_attacks=completed_attacks, - number_tries=scenario_result.number_tries, - completion_time=scenario_result.completion_time, - ) + strategies_used = scenario_result.get_strategies_used() - return ScenarioRunResponse( - run_id=scenario_result_id, + return ScenarioRunSummary( + scenario_result_id=scenario_result_id, scenario_name=scenario_result.scenario_identifier.name, status=status, created_at=scenario_result.created_at, updated_at=scenario_result.completion_time, error=error, - result=result, + strategies_used=strategies_used, + total_attacks=total_attacks, + completed_attacks=completed_attacks, + completed_at=scenario_result.completion_time, ) - def get_run_results(self, *, run_id: str) -> ScenarioResultDetailResponse | None: + def get_run_results(self, *, run_id: str) -> ScenarioRunDetail | None: """ Get detailed results for a completed scenario run. @@ -396,10 +392,10 @@ def get_run_results(self, *, run_id: str) -> ScenarioResultDetailResponse | None return None scenario_result = results[0] + run_response = self._build_response_from_db(scenario_result=scenario_result) - if scenario_result.scenario_run_state != "COMPLETED": - status = _STATE_TO_STATUS.get(scenario_result.scenario_run_state, ScenarioRunStatus.FAILED) - raise ValueError(f"Results are only available for completed runs. Current status: '{status}'.") + if run_response.status != ScenarioRunStatus.COMPLETED: + raise ValueError(f"Results are only available for completed runs. 
Current status: '{run_response.status}'.") # Build per-attack detail attacks: list[AtomicAttackResults] = [] @@ -449,14 +445,10 @@ def get_run_results(self, *, run_id: str) -> ScenarioResultDetailResponse | None ) ) - return ScenarioResultDetailResponse( - scenario_result_id=str(scenario_result.id), - scenario_name=scenario_result.scenario_identifier.name, + return ScenarioRunDetail( + run=run_response, scenario_version=scenario_result.scenario_identifier.version, - run_state=scenario_result.scenario_run_state, objective_achieved_rate=scenario_result.objective_achieved_rate(), - number_tries=scenario_result.number_tries, - completion_time=scenario_result.completion_time, labels=scenario_result.labels, attacks=attacks, ) diff --git a/pyrit/backend/services/scenario_service.py b/pyrit/backend/services/scenario_service.py index 52df32fe61..dc7594f4f9 100644 --- a/pyrit/backend/services/scenario_service.py +++ b/pyrit/backend/services/scenario_service.py @@ -12,11 +12,11 @@ from typing import Optional from pyrit.backend.models.common import PaginationInfo -from pyrit.backend.models.scenarios import ScenarioListResponse, ScenarioSummary +from pyrit.backend.models.scenarios import ListRegisteredScenarioResponse, RegisteredScenario from pyrit.registry import ScenarioMetadata, ScenarioRegistry -def _metadata_to_summary(metadata: ScenarioMetadata) -> ScenarioSummary: +def _metadata_to_summary(metadata: ScenarioMetadata) -> RegisteredScenario: """ Convert a ScenarioMetadata dataclass to a ScenarioSummary Pydantic model. @@ -26,7 +26,7 @@ def _metadata_to_summary(metadata: ScenarioMetadata) -> ScenarioSummary: Returns: ScenarioSummary Pydantic model. """ - return ScenarioSummary( + return RegisteredScenario( scenario_name=metadata.registry_name, scenario_type=metadata.class_name, description=metadata.class_description, @@ -54,7 +54,7 @@ async def list_scenarios_async( *, limit: int = 50, cursor: Optional[str] = None, - ) -> ScenarioListResponse: + ) -> ListRegisteredScenarioResponse: """ List all available scenarios with pagination. @@ -71,12 +71,12 @@ async def list_scenarios_async( page, has_more = self._paginate(items=all_summaries, cursor=cursor, limit=limit) next_cursor = page[-1].scenario_name if has_more and page else None - return ScenarioListResponse( + return ListRegisteredScenarioResponse( items=page, pagination=PaginationInfo(limit=limit, has_more=has_more, next_cursor=next_cursor, prev_cursor=cursor), ) - async def get_scenario_async(self, *, scenario_name: str) -> Optional[ScenarioSummary]: + async def get_scenario_async(self, *, scenario_name: str) -> Optional[RegisteredScenario]: """ Get a single scenario by registry name. @@ -95,10 +95,10 @@ async def get_scenario_async(self, *, scenario_name: str) -> Optional[ScenarioSu @staticmethod def _paginate( *, - items: list[ScenarioSummary], + items: list[RegisteredScenario], cursor: Optional[str], limit: int, - ) -> tuple[list[ScenarioSummary], bool]: + ) -> tuple[list[RegisteredScenario], bool]: """ Apply cursor-based pagination. 
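For reference, the cursor scheme that _paginate implements (the cursor names the item to start after; has_more flags a following page) fits in a few lines. A standalone sketch under those assumptions, not the method body itself:

# Illustrative sketch of cursor-based pagination as described above.
def paginate(items: list[str], cursor: str | None, limit: int) -> tuple[list[str], bool]:
    start = 0
    if cursor is not None:
        for i, name in enumerate(items):
            if name == cursor:
                start = i + 1  # resume after the cursor item
                break
    page = items[start : start + limit]
    return page, start + limit < len(items)


page, has_more = paginate(["test.s1", "test.s2", "test.s3"], cursor="test.s1", limit=1)
assert page == ["test.s2"] and has_more is True
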
diff --git a/tests/unit/backend/test_scenario_run_routes.py b/tests/unit/backend/test_scenario_run_routes.py index bf6664c7a5..dd5a78be89 100644 --- a/tests/unit/backend/test_scenario_run_routes.py +++ b/tests/unit/backend/test_scenario_run_routes.py @@ -17,10 +17,10 @@ from pyrit.backend.models.attacks import AttackResultDetail from pyrit.backend.models.scenarios import ( AtomicAttackResults, - ScenarioResultDetailResponse, + ScenarioRunDetail, ScenarioRunListResponse, - ScenarioRunResponse, ScenarioRunStatus, + ScenarioRunSummary, ) @@ -43,16 +43,15 @@ def _mock_run_response( run_id: str = "test-run-id", scenario_name: str = "foundry.red_team_agent", run_status: ScenarioRunStatus = ScenarioRunStatus.PENDING, -) -> ScenarioRunResponse: +) -> ScenarioRunSummary: """Create a mock ScenarioRunResponse.""" - return ScenarioRunResponse( - run_id=run_id, + return ScenarioRunSummary( + scenario_result_id=run_id, scenario_name=scenario_name, status=run_status, created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), updated_at=datetime(2025, 1, 1, tzinfo=timezone.utc), error=None, - result=None, ) @@ -75,7 +74,7 @@ def test_start_run_returns_202(self, client: TestClient) -> None: assert response.status_code == status.HTTP_202_ACCEPTED data = response.json() - assert data["run_id"] == "test-run-id" + assert data["scenario_result_id"] == "test-run-id" assert data["status"] == "pending" def test_start_run_invalid_scenario_returns_400(self, client: TestClient) -> None: @@ -190,7 +189,7 @@ def test_get_run_not_found_returns_404(self, client: TestClient) -> None: class TestCancelScenarioRunRoute: - """Tests for DELETE /api/scenarios/runs/{run_id}.""" + """Tests for POST /api/scenarios/runs/{run_id}/cancel.""" def test_cancel_run_returns_200(self, client: TestClient) -> None: """Test that cancelling a running scenario returns 200.""" @@ -201,7 +200,7 @@ def test_cancel_run_returns_200(self, client: TestClient) -> None: mock_service.cancel_run_async = AsyncMock(return_value=mock_response) mock_get.return_value = mock_service - response = client.delete("/api/scenarios/runs/test-run-id") + response = client.post("/api/scenarios/runs/test-run-id/cancel") assert response.status_code == status.HTTP_200_OK assert response.json()["status"] == "cancelled" @@ -213,7 +212,7 @@ def test_cancel_run_not_found_returns_404(self, client: TestClient) -> None: mock_service.cancel_run_async = AsyncMock(return_value=None) mock_get.return_value = mock_service - response = client.delete("/api/scenarios/runs/nonexistent") + response = client.post("/api/scenarios/runs/nonexistent/cancel") assert response.status_code == status.HTTP_404_NOT_FOUND @@ -224,7 +223,7 @@ def test_cancel_completed_run_returns_409(self, client: TestClient) -> None: mock_service.cancel_run_async = AsyncMock(side_effect=ValueError("Cannot cancel run in 'completed' state.")) mock_get.return_value = mock_service - response = client.delete("/api/scenarios/runs/test-run-id") + response = client.post("/api/scenarios/runs/test-run-id/cancel") assert response.status_code == status.HTTP_409_CONFLICT assert "Cannot cancel" in response.json()["detail"] @@ -235,14 +234,17 @@ class TestGetScenarioRunResultsRoute: def test_get_results_returns_200(self, client: TestClient) -> None: """Test that getting results of a completed run returns 200.""" - mock_result = ScenarioResultDetailResponse( - scenario_result_id="result-uuid", - scenario_name="foundry.red_team_agent", + mock_result = ScenarioRunDetail( + run=ScenarioRunSummary( + scenario_result_id="result-uuid", + 
scenario_name="foundry.red_team_agent", + status=ScenarioRunStatus.COMPLETED, + created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + updated_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + completed_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + ), scenario_version=1, - run_state="COMPLETED", objective_achieved_rate=50, - number_tries=1, - completion_time=datetime(2025, 1, 1, tzinfo=timezone.utc), labels={"team": "red"}, attacks=[ AtomicAttackResults( @@ -278,7 +280,7 @@ def test_get_results_returns_200(self, client: TestClient) -> None: assert response.status_code == status.HTTP_200_OK data = response.json() - assert data["scenario_result_id"] == "result-uuid" + assert data["run"]["scenario_result_id"] == "result-uuid" assert data["objective_achieved_rate"] == 50 assert len(data["attacks"]) == 1 assert data["attacks"][0]["atomic_attack_name"] == "base64_attack" diff --git a/tests/unit/backend/test_scenario_run_service.py b/tests/unit/backend/test_scenario_run_service.py index b7c852f30e..09b7127030 100644 --- a/tests/unit/backend/test_scenario_run_service.py +++ b/tests/unit/backend/test_scenario_run_service.py @@ -134,7 +134,7 @@ async def test_start_run_returns_running_status(self, mock_all_registries) -> No service = ScenarioRunService() response = await service.start_run_async(request=_make_request()) - assert response.run_id == "sr-uuid-1" + assert response.scenario_result_id == "sr-uuid-1" assert response.status == ScenarioRunStatus.RUNNING assert response.scenario_name == "foundry.red_team_agent" assert response.error is None @@ -290,7 +290,7 @@ def test_get_run_returns_existing_run(self, mock_memory) -> None: fetched = service.get_run(run_id="sr-123") assert fetched is not None - assert fetched.run_id == "sr-123" + assert fetched.scenario_result_id == "sr-123" assert fetched.scenario_name == "foundry.red_team_agent" assert fetched.status == ScenarioRunStatus.RUNNING @@ -345,13 +345,13 @@ async def test_cancel_run_sets_cancelled_status(self, mock_all_registries) -> No # After update_scenario_run_state, the next DB query should return CANCELLED running_result = mock_all_registries["db_result"] - cancelled_result = _make_db_scenario_result(result_id=response.run_id, run_state="CANCELLED") + cancelled_result = _make_db_scenario_result(result_id=response.scenario_result_id, run_state="CANCELLED") mock_memory.get_scenario_results.side_effect = [[running_result], [cancelled_result]] - result = await service.cancel_run_async(run_id=response.run_id) + result = await service.cancel_run_async(run_id=response.scenario_result_id) mock_memory.update_scenario_run_state.assert_called_once_with( - scenario_result_id=response.run_id, scenario_run_state="CANCELLED" + scenario_result_id=response.scenario_result_id, scenario_run_state="CANCELLED" ) assert result is not None assert result.status == ScenarioRunStatus.CANCELLED @@ -398,16 +398,16 @@ async def test_execute_run_completes_successfully(self, mock_all_registries) -> response = await service.start_run_async(request=_make_request()) # Wait for the background task to complete - active = service._active_tasks.get(response.run_id) + active = service._active_tasks.get(response.scenario_result_id) assert active is not None assert active.task is not None await active.task # Active task is cleaned up on next get_run (deferred cleanup) - assert response.run_id in service._active_tasks - fetched = service.get_run(run_id=response.run_id) + assert response.scenario_result_id in service._active_tasks + fetched = 
service.get_run(run_id=response.scenario_result_id) assert fetched is not None - assert response.run_id not in service._active_tasks + assert response.scenario_result_id not in service._active_tasks async def test_execute_run_fails_with_error(self, mock_all_registries) -> None: """Test that a run_async failure stores error and surfaces it via get_run.""" @@ -419,20 +419,20 @@ async def test_execute_run_fails_with_error(self, mock_all_registries) -> None: response = await service.start_run_async(request=_make_request()) # Wait for the background task - active = service._active_tasks.get(response.run_id) + active = service._active_tasks.get(response.scenario_result_id) assert active is not None assert active.task is not None await active.task # Error is stored on the active task until get_run reads it assert active.error == "scenario exploded" - assert response.run_id in service._active_tasks + assert response.scenario_result_id in service._active_tasks # get_run should surface the error and clean up - fetched = service.get_run(run_id=response.run_id) + fetched = service.get_run(run_id=response.scenario_result_id) assert fetched is not None assert fetched.error == "scenario exploded" - assert response.run_id not in service._active_tasks + assert response.scenario_result_id not in service._active_tasks class TestScenarioRunServiceGetResults: @@ -483,7 +483,7 @@ def test_get_results_returns_details_for_completed_run(self, mock_memory) -> Non detail = service.get_run_results(run_id="sr-123") assert detail is not None - assert detail.scenario_result_id == "sr-123" + assert detail.run.scenario_result_id == "sr-123" assert detail.objective_achieved_rate == 100 assert len(detail.attacks) == 1 assert detail.attacks[0].atomic_attack_name == "base64_attack" From 9ad59380565e459d27dd9789138707227868f380 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 7 May 2026 12:17:05 -0700 Subject: [PATCH 08/10] moving other scenario data to summary --- pyrit/backend/models/scenarios.py | 6 +++--- pyrit/backend/services/scenario_run_service.py | 6 +++--- tests/unit/backend/test_scenario_run_routes.py | 8 ++++---- tests/unit/backend/test_scenario_run_service.py | 2 +- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pyrit/backend/models/scenarios.py b/pyrit/backend/models/scenarios.py index d76cb66111..7a9c8dcc85 100644 --- a/pyrit/backend/models/scenarios.py +++ b/pyrit/backend/models/scenarios.py @@ -93,6 +93,7 @@ class ScenarioRunSummary(BaseModel): scenario_result_id: str = Field(..., description="UUID of the ScenarioResult in memory") scenario_name: str = Field(..., description="Registry key of the scenario being run") + scenario_version: int = Field(0, ge=0, description="Version of the scenario") status: ScenarioRunStatus = Field(..., description="Current run status") created_at: datetime = Field(..., description="When the run was created") updated_at: datetime = Field(..., description="When the run status last changed") @@ -100,6 +101,8 @@ class ScenarioRunSummary(BaseModel): strategies_used: list[str] = Field(default_factory=list, description="Strategy names that were executed") total_attacks: int = Field(0, ge=0, description="Total number of atomic attacks") completed_attacks: int = Field(0, ge=0, description="Number of attacks that completed") + objective_achieved_rate: int = Field(0, ge=0, le=100, description="Success rate as percentage (0-100)") + labels: dict[str, str] = Field(default_factory=dict, description="Labels attached to this run") completed_at: datetime | None = 
Field(None, description="When the scenario finished") @@ -129,7 +132,4 @@ class ScenarioRunDetail(BaseModel): """Full detailed results of a scenario run.""" run: ScenarioRunSummary = Field(..., description="The scenario run summary") - scenario_version: int = Field(..., description="Version of the scenario") - objective_achieved_rate: int = Field(..., ge=0, le=100, description="Success rate as percentage (0-100)") - labels: dict[str, str] = Field(default_factory=dict, description="Labels attached to this run") attacks: list[AtomicAttackResults] = Field(..., description="Results grouped by atomic attack") diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py index daff5702ea..9325c06029 100644 --- a/pyrit/backend/services/scenario_run_service.py +++ b/pyrit/backend/services/scenario_run_service.py @@ -360,6 +360,7 @@ def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> Scenari return ScenarioRunSummary( scenario_result_id=scenario_result_id, scenario_name=scenario_result.scenario_identifier.name, + scenario_version=scenario_result.scenario_identifier.version, status=status, created_at=scenario_result.created_at, updated_at=scenario_result.completion_time, @@ -367,6 +368,8 @@ def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> Scenari strategies_used=strategies_used, total_attacks=total_attacks, completed_attacks=completed_attacks, + objective_achieved_rate=scenario_result.objective_achieved_rate(), + labels=scenario_result.labels, completed_at=scenario_result.completion_time, ) @@ -447,9 +450,6 @@ def get_run_results(self, *, run_id: str) -> ScenarioRunDetail | None: return ScenarioRunDetail( run=run_response, - scenario_version=scenario_result.scenario_identifier.version, - objective_achieved_rate=scenario_result.objective_achieved_rate(), - labels=scenario_result.labels, attacks=attacks, ) diff --git a/tests/unit/backend/test_scenario_run_routes.py b/tests/unit/backend/test_scenario_run_routes.py index dd5a78be89..0d958dc112 100644 --- a/tests/unit/backend/test_scenario_run_routes.py +++ b/tests/unit/backend/test_scenario_run_routes.py @@ -238,14 +238,14 @@ def test_get_results_returns_200(self, client: TestClient) -> None: run=ScenarioRunSummary( scenario_result_id="result-uuid", scenario_name="foundry.red_team_agent", + scenario_version=1, status=ScenarioRunStatus.COMPLETED, created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), updated_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + objective_achieved_rate=50, + labels={"team": "red"}, completed_at=datetime(2025, 1, 1, tzinfo=timezone.utc), ), - scenario_version=1, - objective_achieved_rate=50, - labels={"team": "red"}, attacks=[ AtomicAttackResults( atomic_attack_name="base64_attack", @@ -281,7 +281,7 @@ def test_get_results_returns_200(self, client: TestClient) -> None: assert response.status_code == status.HTTP_200_OK data = response.json() assert data["run"]["scenario_result_id"] == "result-uuid" - assert data["objective_achieved_rate"] == 50 + assert data["run"]["objective_achieved_rate"] == 50 assert len(data["attacks"]) == 1 assert data["attacks"][0]["atomic_attack_name"] == "base64_attack" assert data["attacks"][0]["results"][0]["outcome"] == "success" diff --git a/tests/unit/backend/test_scenario_run_service.py b/tests/unit/backend/test_scenario_run_service.py index 09b7127030..ce95d8ba20 100644 --- a/tests/unit/backend/test_scenario_run_service.py +++ b/tests/unit/backend/test_scenario_run_service.py @@ -484,7 +484,7 @@ def 
test_get_results_returns_details_for_completed_run(self, mock_memory) -> Non assert detail is not None assert detail.run.scenario_result_id == "sr-123" - assert detail.objective_achieved_rate == 100 + assert detail.run.objective_achieved_rate == 100 assert len(detail.attacks) == 1 assert detail.attacks[0].atomic_attack_name == "base64_attack" assert detail.attacks[0].success_count == 1 From cdae91d8726c2e49f8715f9b11b059daab5630f0 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 7 May 2026 15:21:19 -0700 Subject: [PATCH 09/10] pr feedback --- pyrit/backend/models/attacks.py | 23 +++---- pyrit/backend/models/scenarios.py | 4 +- pyrit/backend/routes/scenarios.py | 40 ++++++------ .../backend/services/scenario_run_service.py | 63 +++++++++++-------- pyrit/backend/services/scenario_service.py | 6 +- .../unit/backend/test_scenario_run_routes.py | 13 ++-- .../unit/backend/test_scenario_run_service.py | 28 ++++----- 7 files changed, 91 insertions(+), 86 deletions(-) diff --git a/pyrit/backend/models/attacks.py b/pyrit/backend/models/attacks.py index 24f2855318..64d6269e44 100644 --- a/pyrit/backend/models/attacks.py +++ b/pyrit/backend/models/attacks.py @@ -77,21 +77,6 @@ class Message(BaseModel): created_at: datetime = Field(..., description="Message creation timestamp") -class AttackResultDetail(BaseModel): - """Detailed result of a single attack within a scenario.""" - - attack_result_id: str = Field(..., description="Unique ID of this attack result") - conversation_id: str = Field(..., description="Conversation ID that produced this result") - objective: str = Field(..., description="Natural-language description of the attacker's objective") - outcome: str = Field(..., description="Attack outcome: success, failure, or undetermined") - outcome_reason: str | None = Field(None, description="Reason for the outcome") - last_response: str | None = Field(None, description="Model response from the final turn") - score_value: str | None = Field(None, description="Score value from the objective scorer") - executed_turns: int = Field(0, ge=0, description="Number of turns executed") - execution_time_ms: int = Field(0, ge=0, description="Execution time in milliseconds") - timestamp: datetime | None = Field(None, description="When the result was created") - - # ============================================================================ # Attack Summary (List View) # ============================================================================ @@ -110,18 +95,24 @@ class AttackSummary(BaseModel): attack_result_id: str = Field(..., description="Database-assigned unique ID for this AttackResult") conversation_id: str = Field(..., description="Primary conversation of this attack result") - attack_type: str = Field(..., description="Attack class name (e.g., 'CrescendoAttack', 'ManualAttack')") + attack_type: str = Field("", description="Attack class name (e.g., 'CrescendoAttack', 'ManualAttack')") attack_specific_params: Optional[dict[str, Any]] = Field(None, description="Additional attack-specific parameters") target: Optional[TargetInfo] = Field(None, description="Target information from the stored identifier") converters: list[str] = Field( default_factory=list, description="Request converter class names applied in this attack" ) + objective: str = Field("", description="Natural-language description of the attacker's objective") outcome: Optional[Literal["undetermined", "success", "failure"]] = Field( None, description="Attack outcome (null if not yet determined)" ) + outcome_reason: str | None = 
Field(None, description="Reason for the outcome") + last_response: str | None = Field(None, description="Model response from the final turn") last_message_preview: Optional[str] = Field( None, description="Preview of the last message (truncated to ~100 chars)" ) + score_value: str | None = Field(None, description="Score value from the objective scorer") + executed_turns: int = Field(0, ge=0, description="Number of turns executed") + execution_time_ms: int = Field(0, ge=0, description="Execution time in milliseconds") message_count: int = Field(0, description="Total number of messages in the attack") related_conversation_ids: list[str] = Field( default_factory=list, description="IDs of related conversations within this attack" diff --git a/pyrit/backend/models/scenarios.py b/pyrit/backend/models/scenarios.py index 7a9c8dcc85..5a8fdf252a 100644 --- a/pyrit/backend/models/scenarios.py +++ b/pyrit/backend/models/scenarios.py @@ -14,7 +14,7 @@ from pydantic import BaseModel, Field -from pyrit.backend.models.attacks import AttackResultDetail +from pyrit.backend.models.attacks import AttackSummary from pyrit.backend.models.common import PaginationInfo @@ -122,7 +122,7 @@ class AtomicAttackResults(BaseModel): atomic_attack_name: str = Field(..., description="Name of the atomic attack (strategy)") display_group: str | None = Field(None, description="Display group label for UI grouping") - results: list[AttackResultDetail] = Field(..., description="Individual attack results") + results: list[AttackSummary] = Field(..., description="Individual attack results") success_count: int = Field(0, ge=0, description="Number of successful attacks") failure_count: int = Field(0, ge=0, description="Number of failed attacks") total_count: int = Field(0, ge=0, description="Total number of attack results") diff --git a/pyrit/backend/routes/scenarios.py b/pyrit/backend/routes/scenarios.py index ab825b9f9f..a9e6e885fe 100644 --- a/pyrit/backend/routes/scenarios.py +++ b/pyrit/backend/routes/scenarios.py @@ -103,13 +103,13 @@ async def start_scenario_run(request: RunScenarioRequest) -> ScenarioRunSummary: """ Start a new scenario run as a background task. - Returns immediately with a run_id that can be polled for status. + Returns immediately with a scenario_result_id that can be polled for status. Args: request: Scenario run configuration. Returns: - ScenarioRunResponse: Run metadata with PENDING status. + ScenarioRunSummary: Run metadata with PENDING status. """ service = get_scenario_run_service() try: @@ -137,73 +137,73 @@ async def list_scenario_runs(limit: int = Query(100, ge=1)) -> ScenarioRunListRe @router.get( - "/runs/{run_id}", + "/runs/{scenario_result_id}", response_model=ScenarioRunSummary, responses={ 404: {"model": ProblemDetail, "description": "Run not found"}, }, ) -async def get_scenario_run(run_id: str) -> ScenarioRunSummary: +async def get_scenario_run(scenario_result_id: str) -> ScenarioRunSummary: """ Get the current status and result of a scenario run. Args: - run_id: The unique run identifier returned by POST /runs. + scenario_result_id: The scenario_result_id returned by POST /runs. Returns: - ScenarioRunResponse: Current run status (and result if completed). + ScenarioRunSummary: Current run status (and result if completed). 
""" service = get_scenario_run_service() - run = service.get_run(run_id=run_id) + run = service.get_run(scenario_result_id=scenario_result_id) if run is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Scenario run '{run_id}' not found", + detail=f"Scenario run '{scenario_result_id}' not found", ) return run @router.post( - "/runs/{run_id}/cancel", + "/runs/{scenario_result_id}/cancel", response_model=ScenarioRunSummary, responses={ 404: {"model": ProblemDetail, "description": "Run not found"}, 409: {"model": ProblemDetail, "description": "Run already in terminal state"}, }, ) -async def cancel_scenario_run(run_id: str) -> ScenarioRunSummary: +async def cancel_scenario_run(scenario_result_id: str) -> ScenarioRunSummary: """ Cancel a running scenario. Args: - run_id: The unique run identifier to cancel. + scenario_result_id: The scenario_result_id to cancel. Returns: - ScenarioRunResponse: Updated run with CANCELLED status. + ScenarioRunSummary: Updated run with CANCELLED status. """ service = get_scenario_run_service() try: - result = await service.cancel_run_async(run_id=run_id) + result = await service.cancel_run_async(scenario_result_id=scenario_result_id) except ValueError as e: raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from None if result is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Scenario run '{run_id}' not found", + detail=f"Scenario run '{scenario_result_id}' not found", ) return result @router.get( - "/runs/{run_id}/results", + "/runs/{scenario_result_id}/results", response_model=ScenarioRunDetail, responses={ 404: {"model": ProblemDetail, "description": "Run not found"}, 409: {"model": ProblemDetail, "description": "Run not yet completed"}, }, ) -async def get_scenario_run_results(run_id: str) -> ScenarioRunDetail: +async def get_scenario_run_results(scenario_result_id: str) -> ScenarioRunDetail: """ Get detailed results for a completed scenario run. @@ -211,20 +211,20 @@ async def get_scenario_run_results(run_id: str) -> ScenarioRunDetail: and success/failure counts. Args: - run_id: The unique run identifier. + scenario_result_id: The scenario_result_id. Returns: - ScenarioResultDetailResponse: Full attack-level results. + ScenarioRunDetail: Full attack-level results. 
""" service = get_scenario_run_service() try: - result = service.get_run_results(run_id=run_id) + result = service.get_run_results(scenario_result_id=scenario_result_id) except ValueError as e: raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from None if result is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Scenario run '{run_id}' not found", + detail=f"Scenario run '{scenario_result_id}' not found", ) return result diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py index 9325c06029..d00f9b0caa 100644 --- a/pyrit/backend/services/scenario_run_service.py +++ b/pyrit/backend/services/scenario_run_service.py @@ -9,13 +9,15 @@ """ import asyncio +import contextlib import logging from dataclasses import dataclass +from datetime import datetime, timezone from typing import Any from pyrit.backend.models.scenarios import ( AtomicAttackResults, - AttackResultDetail, + AttackSummary, RunScenarioRequest, ScenarioRunDetail, ScenarioRunListResponse, @@ -64,6 +66,7 @@ def __init__(self, *, max_concurrent_runs: int = _DEFAULT_MAX_CONCURRENT_RUNS) - """Initialize the scenario run service.""" self._max_concurrent_runs = max_concurrent_runs self._active_tasks: dict[str, _ActiveTask] = {} + self._run_semaphore = asyncio.Semaphore(max_concurrent_runs) async def start_run_async(self, *, request: RunScenarioRequest) -> ScenarioRunSummary: """ @@ -84,17 +87,20 @@ async def start_run_async(self, *, request: RunScenarioRequest) -> ScenarioRunSu ValueError: If scenario, target, initializer, or strategy cannot be found, or concurrent limit exceeded. """ - if ( - sum(1 for a in self._active_tasks.values() if a.task is not None and not a.task.done()) - >= self._max_concurrent_runs - ): + if self._run_semaphore.locked(): raise ValueError( f"Maximum concurrent runs ({self._max_concurrent_runs}) reached. " "Wait for an existing run to complete or cancel one." ) + await self._run_semaphore.acquire() + # Perform all initialization eagerly — errors propagate to caller - scenario = await self._initialize_run_async(request=request) + try: + scenario = await self._initialize_run_async(request=request) + except Exception: + self._run_semaphore.release() + raise # scenario_result_id is set during initialize_async scenario_result_id = scenario._scenario_result_id @@ -113,17 +119,17 @@ async def start_run_async(self, *, request: RunScenarioRequest) -> ScenarioRunSu assert response is not None # guaranteed: we just inserted into DB via initialize_async return response - def get_run(self, *, run_id: str) -> ScenarioRunSummary | None: + def get_run(self, *, scenario_result_id: str) -> ScenarioRunSummary | None: """ Get the current status of a scenario run by querying the database. Args: - run_id: The scenario result ID (run identifier). + scenario_result_id: The scenario result ID. Returns: - ScenarioRunResponse if found, None otherwise. + ScenarioRunSummary if found, None otherwise. 
""" - return self._build_response(scenario_result_id=run_id) + return self._build_response(scenario_result_id=scenario_result_id) def list_runs(self, *, limit: int = 100) -> ScenarioRunListResponse: """ @@ -143,22 +149,22 @@ def list_runs(self, *, limit: int = 100) -> ScenarioRunListResponse: items = [self._build_response_from_db(scenario_result=sr) for sr in results] return ScenarioRunListResponse(items=items) - async def cancel_run_async(self, *, run_id: str) -> ScenarioRunSummary | None: + async def cancel_run_async(self, *, scenario_result_id: str) -> ScenarioRunSummary | None: """ Cancel a running scenario. Args: - run_id: The scenario result ID (run identifier). + scenario_result_id: The scenario result ID. Returns: - Updated ScenarioRunResponse if found, None if run_id not found. + Updated ScenarioRunSummary if found, None if not found. Raises: ValueError: If the run is already in a terminal state or not active. """ # Verify run exists in DB memory = CentralMemory.get_memory_instance() - results = memory.get_scenario_results(scenario_result_ids=[run_id]) + results = memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) if not results: return None @@ -168,15 +174,17 @@ async def cancel_run_async(self, *, run_id: str) -> ScenarioRunSummary | None: if db_status in (ScenarioRunStatus.COMPLETED, ScenarioRunStatus.FAILED, ScenarioRunStatus.CANCELLED): raise ValueError(f"Cannot cancel run in '{db_status}' state.") - # Cancel the asyncio task if active - active = self._active_tasks.get(run_id) + # Cancel the asyncio task if active and wait for it to finish + active = self._active_tasks.get(scenario_result_id) if active is not None and active.task is not None and not active.task.done(): active.task.cancel() + with contextlib.suppress(asyncio.CancelledError, asyncio.TimeoutError): + await asyncio.wait_for(active.task, timeout=5.0) # Persist cancelled state to DB - memory.update_scenario_run_state(scenario_result_id=run_id, scenario_run_state="CANCELLED") + memory.update_scenario_run_state(scenario_result_id=scenario_result_id, scenario_run_state="CANCELLED") - return self._build_response(scenario_result_id=run_id) + return self._build_response(scenario_result_id=scenario_result_id) async def _initialize_run_async(self, *, request: RunScenarioRequest) -> Scenario: """ @@ -305,6 +313,9 @@ async def _execute_run_async(self, *, scenario_result_id: str) -> None: active.error = str(e) logger.exception(f"Scenario run {scenario_result_id} failed: {e}") + finally: + self._run_semaphore.release() + def _build_response(self, *, scenario_result_id: str) -> ScenarioRunSummary | None: """ Build a ScenarioRunResponse by querying the database and merging active task state. @@ -373,7 +384,7 @@ def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> Scenari completed_at=scenario_result.completion_time, ) - def get_run_results(self, *, run_id: str) -> ScenarioRunDetail | None: + def get_run_results(self, *, scenario_result_id: str) -> ScenarioRunDetail | None: """ Get detailed results for a completed scenario run. @@ -381,16 +392,16 @@ def get_run_results(self, *, run_id: str) -> ScenarioRunDetail | None: to a detailed response model with per-attack outcomes. Args: - run_id: The scenario result ID (run identifier). + scenario_result_id: The scenario result ID. Returns: - ScenarioResultDetailResponse if the run is completed and results exist, None if run not found. + ScenarioRunDetail if the run is completed and results exist, None if not found. 
Raises: ValueError: If the run is not in a completed state. """ memory = CentralMemory.get_memory_instance() - results = memory.get_scenario_results(scenario_result_ids=[run_id]) + results = memory.get_scenario_results(scenario_result_ids=[scenario_result_id]) if not results: return None @@ -404,7 +415,7 @@ def get_run_results(self, *, run_id: str) -> ScenarioRunDetail | None: attacks: list[AtomicAttackResults] = [] display_group_map = scenario_result.display_group_map for attack_name, attack_results in scenario_result.attack_results.items(): - details: list[AttackResultDetail] = [] + details: list[AttackSummary] = [] success_count = 0 failure_count = 0 @@ -417,8 +428,9 @@ def get_run_results(self, *, run_id: str) -> ScenarioRunDetail | None: if ar.last_response is not None: last_response_text = str(ar.last_response) + timestamp = ar.timestamp or datetime.now(timezone.utc) details.append( - AttackResultDetail( + AttackSummary( attack_result_id=ar.attack_result_id, conversation_id=ar.conversation_id, objective=ar.objective, @@ -428,7 +440,8 @@ def get_run_results(self, *, run_id: str) -> ScenarioRunDetail | None: score_value=score_value, executed_turns=ar.executed_turns, execution_time_ms=ar.execution_time_ms, - timestamp=ar.timestamp, + created_at=timestamp, + updated_at=timestamp, ) ) diff --git a/pyrit/backend/services/scenario_service.py b/pyrit/backend/services/scenario_service.py index dc7594f4f9..77b4739020 100644 --- a/pyrit/backend/services/scenario_service.py +++ b/pyrit/backend/services/scenario_service.py @@ -16,7 +16,7 @@ from pyrit.registry import ScenarioMetadata, ScenarioRegistry -def _metadata_to_summary(metadata: ScenarioMetadata) -> RegisteredScenario: +def _metadata_to_registered_scenario(metadata: ScenarioMetadata) -> RegisteredScenario: """ Convert a ScenarioMetadata dataclass to a ScenarioSummary Pydantic model. @@ -66,7 +66,7 @@ async def list_scenarios_async( ScenarioListResponse with paginated scenario summaries. 
""" all_metadata = self._registry.list_metadata() - all_summaries = [_metadata_to_summary(m) for m in all_metadata] + all_summaries = [_metadata_to_registered_scenario(m) for m in all_metadata] page, has_more = self._paginate(items=all_summaries, cursor=cursor, limit=limit) next_cursor = page[-1].scenario_name if has_more and page else None @@ -89,7 +89,7 @@ async def get_scenario_async(self, *, scenario_name: str) -> Optional[Registered all_metadata = self._registry.list_metadata() for metadata in all_metadata: if metadata.registry_name == scenario_name: - return _metadata_to_summary(metadata) + return _metadata_to_registered_scenario(metadata) return None @staticmethod diff --git a/tests/unit/backend/test_scenario_run_routes.py b/tests/unit/backend/test_scenario_run_routes.py index 0d958dc112..9eceff77e7 100644 --- a/tests/unit/backend/test_scenario_run_routes.py +++ b/tests/unit/backend/test_scenario_run_routes.py @@ -14,7 +14,7 @@ import pyrit.backend.services.scenario_run_service as _svc_mod from pyrit.backend.main import app -from pyrit.backend.models.attacks import AttackResultDetail +from pyrit.backend.models.attacks import AttackSummary from pyrit.backend.models.scenarios import ( AtomicAttackResults, ScenarioRunDetail, @@ -160,7 +160,7 @@ def test_list_runs_returns_multiple_runs(self, client: TestClient) -> None: class TestGetScenarioRunRoute: - """Tests for GET /api/scenarios/runs/{run_id}.""" + """Tests for GET /api/scenarios/runs/{id}.""" def test_get_run_returns_200(self, client: TestClient) -> None: """Test that getting an existing run returns 200.""" @@ -189,7 +189,7 @@ def test_get_run_not_found_returns_404(self, client: TestClient) -> None: class TestCancelScenarioRunRoute: - """Tests for POST /api/scenarios/runs/{run_id}/cancel.""" + """Tests for POST /api/scenarios/runs/{id}/cancel.""" def test_cancel_run_returns_200(self, client: TestClient) -> None: """Test that cancelling a running scenario returns 200.""" @@ -230,7 +230,7 @@ def test_cancel_completed_run_returns_409(self, client: TestClient) -> None: class TestGetScenarioRunResultsRoute: - """Tests for GET /api/scenarios/runs/{run_id}/results.""" + """Tests for GET /api/scenarios/runs/{id}/results.""" def test_get_results_returns_200(self, client: TestClient) -> None: """Test that getting results of a completed run returns 200.""" @@ -251,7 +251,7 @@ def test_get_results_returns_200(self, client: TestClient) -> None: atomic_attack_name="base64_attack", display_group="encoding", results=[ - AttackResultDetail( + AttackSummary( attack_result_id="ar-1", conversation_id="conv-1", objective="Extract sensitive info", @@ -261,7 +261,8 @@ def test_get_results_returns_200(self, client: TestClient) -> None: score_value="1.0", executed_turns=3, execution_time_ms=1500, - timestamp=datetime(2025, 1, 1, tzinfo=timezone.utc), + created_at=datetime(2025, 1, 1, tzinfo=timezone.utc), + updated_at=datetime(2025, 1, 1, tzinfo=timezone.utc), ), ], success_count=1, diff --git a/tests/unit/backend/test_scenario_run_service.py b/tests/unit/backend/test_scenario_run_service.py index ce95d8ba20..524e63918c 100644 --- a/tests/unit/backend/test_scenario_run_service.py +++ b/tests/unit/backend/test_scenario_run_service.py @@ -275,10 +275,10 @@ class TestScenarioRunServiceGetRun: """Tests for ScenarioRunService.get_run.""" def test_get_run_returns_none_for_unknown_id(self, mock_memory) -> None: - """Test that get_run returns None for non-existent run_id.""" + """Test that get_run returns None for non-existent run.""" 
mock_memory.get_scenario_results.return_value = [] service = ScenarioRunService() - result = service.get_run(run_id="nonexistent-id") + result = service.get_run(scenario_result_id="nonexistent-id") assert result is None def test_get_run_returns_existing_run(self, mock_memory) -> None: @@ -287,7 +287,7 @@ def test_get_run_returns_existing_run(self, mock_memory) -> None: mock_memory.get_scenario_results.return_value = [db_result] service = ScenarioRunService() - fetched = service.get_run(run_id="sr-123") + fetched = service.get_run(scenario_result_id="sr-123") assert fetched is not None assert fetched.scenario_result_id == "sr-123" @@ -331,10 +331,10 @@ class TestScenarioRunServiceCancelRun: """Tests for ScenarioRunService.cancel_run_async.""" async def test_cancel_run_returns_none_for_unknown_id(self, mock_memory) -> None: - """Test that cancel returns None for non-existent run_id.""" + """Test that cancel returns None for non-existent run.""" mock_memory.get_scenario_results.return_value = [] service = ScenarioRunService() - result = await service.cancel_run_async(run_id="nonexistent-id") + result = await service.cancel_run_async(scenario_result_id="nonexistent-id") assert result is None async def test_cancel_run_sets_cancelled_status(self, mock_all_registries) -> None: @@ -348,7 +348,7 @@ async def test_cancel_run_sets_cancelled_status(self, mock_all_registries) -> No cancelled_result = _make_db_scenario_result(result_id=response.scenario_result_id, run_state="CANCELLED") mock_memory.get_scenario_results.side_effect = [[running_result], [cancelled_result]] - result = await service.cancel_run_async(run_id=response.scenario_result_id) + result = await service.cancel_run_async(scenario_result_id=response.scenario_result_id) mock_memory.update_scenario_run_state.assert_called_once_with( scenario_result_id=response.scenario_result_id, scenario_run_state="CANCELLED" @@ -363,7 +363,7 @@ async def test_cancel_completed_run_raises_value_error(self, mock_memory) -> Non service = ScenarioRunService() with pytest.raises(ValueError, match="Cannot cancel run"): - await service.cancel_run_async(run_id="sr-done") + await service.cancel_run_async(scenario_result_id="sr-done") async def test_cancel_already_cancelled_run_raises_value_error(self, mock_memory) -> None: """Test that cancelling an already-cancelled run raises ValueError.""" @@ -372,7 +372,7 @@ async def test_cancel_already_cancelled_run_raises_value_error(self, mock_memory service = ScenarioRunService() with pytest.raises(ValueError, match="Cannot cancel run"): - await service.cancel_run_async(run_id="sr-cancelled") + await service.cancel_run_async(scenario_result_id="sr-cancelled") class TestScenarioRunServiceExecution: @@ -405,7 +405,7 @@ async def test_execute_run_completes_successfully(self, mock_all_registries) -> # Active task is cleaned up on next get_run (deferred cleanup) assert response.scenario_result_id in service._active_tasks - fetched = service.get_run(run_id=response.scenario_result_id) + fetched = service.get_run(scenario_result_id=response.scenario_result_id) assert fetched is not None assert response.scenario_result_id not in service._active_tasks @@ -429,7 +429,7 @@ async def test_execute_run_fails_with_error(self, mock_all_registries) -> None: assert response.scenario_result_id in service._active_tasks # get_run should surface the error and clean up - fetched = service.get_run(run_id=response.scenario_result_id) + fetched = service.get_run(scenario_result_id=response.scenario_result_id) assert fetched is not None assert 
fetched.error == "scenario exploded" assert response.scenario_result_id not in service._active_tasks @@ -439,10 +439,10 @@ class TestScenarioRunServiceGetResults: """Tests for ScenarioRunService.get_run_results.""" def test_get_results_returns_none_for_unknown_id(self, mock_memory) -> None: - """Test that get_run_results returns None for non-existent run_id.""" + """Test that get_run_results returns None for non-existent run.""" mock_memory.get_scenario_results.return_value = [] service = ScenarioRunService() - result = service.get_run_results(run_id="nonexistent-id") + result = service.get_run_results(scenario_result_id="nonexistent-id") assert result is None def test_get_results_raises_if_not_completed(self, mock_memory) -> None: @@ -452,7 +452,7 @@ def test_get_results_raises_if_not_completed(self, mock_memory) -> None: service = ScenarioRunService() with pytest.raises(ValueError, match="only available for completed runs"): - service.get_run_results(run_id="sr-running") + service.get_run_results(scenario_result_id="sr-running") def test_get_results_returns_details_for_completed_run(self, mock_memory) -> None: """Test that get_run_results returns full details for a completed run.""" @@ -480,7 +480,7 @@ def test_get_results_returns_details_for_completed_run(self, mock_memory) -> Non mock_memory.get_scenario_results.return_value = [db_result] service = ScenarioRunService() - detail = service.get_run_results(run_id="sr-123") + detail = service.get_run_results(scenario_result_id="sr-123") assert detail is not None assert detail.run.scenario_result_id == "sr-123" From 6173e85377a82b93233aced1b6053ee609196203 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 7 May 2026 15:31:08 -0700 Subject: [PATCH 10/10] fixing test --- tests/unit/backend/test_scenario_service.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/unit/backend/test_scenario_service.py b/tests/unit/backend/test_scenario_service.py index 04d61e622a..cffff85583 100644 --- a/tests/unit/backend/test_scenario_service.py +++ b/tests/unit/backend/test_scenario_service.py @@ -13,7 +13,7 @@ from pyrit.backend.main import app from pyrit.backend.models.common import PaginationInfo -from pyrit.backend.models.scenarios import ScenarioListResponse, ScenarioSummary +from pyrit.backend.models.scenarios import ListRegisteredScenarioResponse, RegisteredScenario from pyrit.backend.services.scenario_service import ScenarioService, get_scenario_service from pyrit.registry import ScenarioMetadata @@ -206,7 +206,7 @@ def test_list_scenarios_returns_200(self, client: TestClient) -> None: with patch("pyrit.backend.routes.scenarios.get_scenario_service") as mock_get_service: mock_service = MagicMock() mock_service.list_scenarios_async = AsyncMock( - return_value=ScenarioListResponse( + return_value=ListRegisteredScenarioResponse( items=[], pagination=PaginationInfo(limit=50, has_more=False, next_cursor=None, prev_cursor=None), ) @@ -222,7 +222,7 @@ def test_list_scenarios_returns_200(self, client: TestClient) -> None: def test_list_scenarios_with_items(self, client: TestClient) -> None: """Test that GET /api/scenarios/catalog returns scenario data.""" - summary = ScenarioSummary( + summary = RegisteredScenario( scenario_name="foundry.red_team_agent", scenario_type="RedTeamAgentScenario", description="Red team agent testing", @@ -236,7 +236,7 @@ def test_list_scenarios_with_items(self, client: TestClient) -> None: with patch("pyrit.backend.routes.scenarios.get_scenario_service") as mock_get_service: 
mock_service = MagicMock() mock_service.list_scenarios_async = AsyncMock( - return_value=ScenarioListResponse( + return_value=ListRegisteredScenarioResponse( items=[summary], pagination=PaginationInfo(limit=50, has_more=False, next_cursor=None, prev_cursor=None), ) @@ -262,7 +262,7 @@ def test_list_scenarios_passes_pagination_params(self, client: TestClient) -> No with patch("pyrit.backend.routes.scenarios.get_scenario_service") as mock_get_service: mock_service = MagicMock() mock_service.list_scenarios_async = AsyncMock( - return_value=ScenarioListResponse( + return_value=ListRegisteredScenarioResponse( items=[], pagination=PaginationInfo(limit=10, has_more=False, next_cursor=None, prev_cursor=None), ) @@ -276,7 +276,7 @@ def test_list_scenarios_passes_pagination_params(self, client: TestClient) -> No def test_get_scenario_returns_200(self, client: TestClient) -> None: """Test that GET /api/scenarios/catalog/{name} returns 200 when found.""" - summary = ScenarioSummary( + summary = RegisteredScenario( scenario_name="foundry.red_team_agent", scenario_type="RedTeamAgentScenario", description="Red team agent testing", @@ -311,7 +311,7 @@ def test_get_scenario_returns_404_when_not_found(self, client: TestClient) -> No def test_get_scenario_with_dotted_name(self, client: TestClient) -> None: """Test that dotted scenario names (e.g., 'foundry.red_team_agent') work in path.""" - summary = ScenarioSummary( + summary = RegisteredScenario( scenario_name="garak.encoding", scenario_type="EncodingScenario", description="Encoding scenario",