diff --git a/frontend/src/components/Chat/ChatInputArea.test.tsx b/frontend/src/components/Chat/ChatInputArea.test.tsx index 570fb93039..9411313151 100644 --- a/frontend/src/components/Chat/ChatInputArea.test.tsx +++ b/frontend/src/components/Chat/ChatInputArea.test.tsx @@ -4,12 +4,27 @@ import userEvent from "@testing-library/user-event"; import { FluentProvider, webLightTheme } from "@fluentui/react-components"; import ChatInputArea from "./ChatInputArea"; import type { ChatInputAreaHandle } from "./ChatInputArea"; +import type { TargetCapabilitiesInfo } from "../../types"; // Wrapper component for Fluent UI context const TestWrapper: React.FC<{ children: React.ReactNode }> = ({ children }) => ( {children} ); +const buildCapabilities = ( + overrides: Partial = {} +): TargetCapabilitiesInfo => ({ + supports_multi_turn: true, + supports_multi_message_pieces: false, + supports_json_schema: false, + supports_json_output: false, + supports_editable_history: false, + supports_system_prompt: false, + supported_input_modalities: [], + supported_output_modalities: [], + ...overrides, +}); + // Helper to get the send button specifically const getSendButton = () => screen.getByRole("button", { name: /send/i }); @@ -367,7 +382,7 @@ describe("ChatInputArea", () => { activeTarget={{ target_registry_name: "test", target_type: "TextTarget", - supports_multi_turn: false, + capabilities: buildCapabilities({ supports_multi_turn: false }), }} /> @@ -388,7 +403,7 @@ describe("ChatInputArea", () => { activeTarget={{ target_registry_name: "test", target_type: "OpenAIChatTarget", - supports_multi_turn: true, + capabilities: buildCapabilities({ supports_multi_turn: true }), }} /> diff --git a/frontend/src/components/Chat/ChatInputArea.tsx b/frontend/src/components/Chat/ChatInputArea.tsx index 615244d0e9..0a3cf6ce0d 100644 --- a/frontend/src/components/Chat/ChatInputArea.tsx +++ b/frontend/src/components/Chat/ChatInputArea.tsx @@ -439,7 +439,7 @@ const ChatInputArea = forwardRef(functi />
- {activeTarget && activeTarget.supports_multi_turn === false && ( + {activeTarget && activeTarget.capabilities?.supports_multi_turn === false && ( = {} +): TargetCapabilitiesInfo => ({ + supports_multi_turn: true, + supports_multi_message_pieces: false, + supports_json_schema: false, + supports_json_output: false, + supports_editable_history: false, + supports_system_prompt: false, + supported_input_modalities: [], + supported_output_modalities: [], + ...overrides, +}); + // Fluent UI Combobox portal interactions are slow in JSDOM under full test load jest.setTimeout(60000); @@ -1227,7 +1242,7 @@ describe("ChatWindow Integration", () => { const singleTurnTarget: TargetInstance = { target_registry_name: "openai_image_1", target_type: "OpenAIImageTarget", - supports_multi_turn: false, + capabilities: buildCapabilities({ supports_multi_turn: false }), }; const messagesWithUser: Message[] = [ @@ -1261,7 +1276,7 @@ describe("ChatWindow Integration", () => { const singleTurnTarget: TargetInstance = { target_registry_name: "openai_image_1", target_type: "OpenAIImageTarget", - supports_multi_turn: false, + capabilities: buildCapabilities({ supports_multi_turn: false }), }; render( @@ -1310,7 +1325,7 @@ describe("ChatWindow Integration", () => { const singleTurnTarget: TargetInstance = { target_registry_name: "openai_tts_1", target_type: "OpenAITTSTarget", - supports_multi_turn: false, + capabilities: buildCapabilities({ supports_multi_turn: false }), }; const messagesWithUser: Message[] = [ @@ -1523,7 +1538,7 @@ describe("ChatWindow Integration", () => { const singleTurnTarget: TargetInstance = { target_registry_name: "openai_image_1", target_type: "OpenAIImageTarget", - supports_multi_turn: false, + capabilities: buildCapabilities({ supports_multi_turn: false }), }; const messagesWithUser: Message[] = [ diff --git a/frontend/src/components/Chat/ChatWindow.tsx b/frontend/src/components/Chat/ChatWindow.tsx index 89d32be4b1..8810e19182 100644 --- a/frontend/src/components/Chat/ChatWindow.tsx +++ b/frontend/src/components/Chat/ChatWindow.tsx @@ -446,7 +446,7 @@ export default function ChatWindow({ } }, [attackResultId]) - const singleTurnLimitReached = activeTarget?.supports_multi_turn === false && messages.some(m => m.role === 'user') + const singleTurnLimitReached = activeTarget?.capabilities?.supports_multi_turn === false && messages.some(m => m.role === 'user') // Operator locking: if the loaded attack's operator differs from the current // user's operator label, the conversation should be read-only. @@ -561,7 +561,7 @@ export default function ChatWindow({ onBranchConversation={attackResultId && activeConversationId ? handleBranchConversation : undefined} onBranchAttack={activeTarget && activeConversationId ? handleBranchAttack : undefined} isLoading={isLoadingAttack || isLoadingMessages || awaitingConversationLoad} - isSingleTurn={activeTarget?.supports_multi_turn === false} + isSingleTurn={activeTarget?.capabilities?.supports_multi_turn === false} isOperatorLocked={isOperatorLocked} isCrossTarget={isCrossTargetLocked} noTargetSelected={!activeTarget} diff --git a/frontend/src/components/Config/TargetConfig.test.tsx b/frontend/src/components/Config/TargetConfig.test.tsx index 690e91672b..c798373087 100644 --- a/frontend/src/components/Config/TargetConfig.test.tsx +++ b/frontend/src/components/Config/TargetConfig.test.tsx @@ -340,7 +340,7 @@ describe("TargetConfig", () => { }); // No reasoning or other special params should be displayed - expect(screen.queryByText(/reasoning_effort/)).not.toBeInTheDocument(); + expect(screen.queryByText(/reasoning_effort:/)).not.toBeInTheDocument(); }); it("should open dialog when Create First Target button is clicked in empty state", async () => { diff --git a/frontend/src/components/Config/TargetTable.styles.ts b/frontend/src/components/Config/TargetTable.styles.ts index 6c02a7077c..906db7578c 100644 --- a/frontend/src/components/Config/TargetTable.styles.ts +++ b/frontend/src/components/Config/TargetTable.styles.ts @@ -9,6 +9,12 @@ export const useTargetTableStyles = makeStyles({ tableLayout: 'fixed', width: '100%', }, + stickyHeader: { + position: 'sticky', + top: 0, + backgroundColor: tokens.colorNeutralBackground1, + zIndex: 1, + }, activeRow: { backgroundColor: tokens.colorBrandBackground2, }, @@ -20,4 +26,50 @@ export const useTargetTableStyles = makeStyles({ whiteSpace: 'pre-line', wordBreak: 'break-word', }, + capabilityCell: { + width: '75px', + textAlign: 'center', + }, + modalityCell: { + width: '110px', + textAlign: 'center', + }, + inputsModalityCell: { + width: '160px', + textAlign: 'center', + }, + modalityRow: { + display: 'inline-flex', + alignItems: 'center', + justifyContent: 'center', + gap: tokens.spacingHorizontalXS, + flexWrap: 'wrap', + }, + modalityIcon: { + fontSize: tokens.fontSizeBase500, + color: tokens.colorNeutralForeground2, + }, + compositeIcon: { + position: 'relative', + display: 'inline-flex', + lineHeight: 0, + }, + compositeBadge: { + position: 'absolute', + top: '-4px', + right: '-6px', + fontSize: tokens.fontSizeBase300, + color: tokens.colorNeutralForeground2, + }, + capabilityIconSupported: { + color: tokens.colorPaletteGreenForeground1, + fontSize: tokens.fontSizeBase500, + }, + capabilityIconUnsupported: { + color: tokens.colorPaletteRedForeground1, + fontSize: tokens.fontSizeBase500, + }, + helpHeader: { + cursor: 'help', + }, }) diff --git a/frontend/src/components/Config/TargetTable.test.tsx b/frontend/src/components/Config/TargetTable.test.tsx index 52618ef2ff..51289614ec 100644 --- a/frontend/src/components/Config/TargetTable.test.tsx +++ b/frontend/src/components/Config/TargetTable.test.tsx @@ -17,12 +17,32 @@ const sampleTargets: TargetInstance[] = [ target_type: 'OpenAIChatTarget', endpoint: 'https://api.openai.com', model_name: 'gpt-4', + capabilities: { + supports_multi_turn: true, + supports_multi_message_pieces: true, + supports_json_schema: true, + supports_json_output: true, + supports_editable_history: true, + supports_system_prompt: true, + supported_input_modalities: ['text', 'image_path'], + supported_output_modalities: ['text'], + }, }, { target_registry_name: 'azure_image_dalle', target_type: 'AzureImageTarget', endpoint: 'https://azure.openai.com', model_name: 'dall-e-3', + capabilities: { + supports_multi_turn: false, + supports_multi_message_pieces: false, + supports_json_schema: false, + supports_json_output: false, + supports_editable_history: false, + supports_system_prompt: false, + supported_input_modalities: ['text'], + supported_output_modalities: ['image_path'], + }, }, { target_registry_name: 'text_target_basic', @@ -58,7 +78,7 @@ describe('TargetTable', () => { expect(screen.getAllByText('TextTarget').length).toBeGreaterThanOrEqual(1) }) - it('should display Type, Model, Endpoint and Parameters columns', () => { + it('should display Type, Model, Endpoint, Inputs, Outputs, capability columns and Parameters columns', () => { render( @@ -68,6 +88,14 @@ describe('TargetTable', () => { expect(screen.getByText('Type')).toBeInTheDocument() expect(screen.getByText('Model')).toBeInTheDocument() expect(screen.getByText('Endpoint')).toBeInTheDocument() + expect(screen.getByText('Inputs')).toBeInTheDocument() + expect(screen.getByText('Outputs')).toBeInTheDocument() + expect(screen.getByText('Multi-turn')).toBeInTheDocument() + expect(screen.getByText('Multi-piece')).toBeInTheDocument() + expect(screen.getByText('JSON Schema')).toBeInTheDocument() + expect(screen.getByText('JSON Output')).toBeInTheDocument() + expect(screen.getByText('Edit History')).toBeInTheDocument() + expect(screen.getByText('System Prompt')).toBeInTheDocument() expect(screen.getByText('Parameters')).toBeInTheDocument() }) @@ -151,8 +179,79 @@ describe('TargetTable', () => { ) + // Dashes for model, endpoint, inputs, outputs, 6 capability columns (all unknown), and params const dashes = screen.getAllByText('—') - expect(dashes.length).toBeGreaterThanOrEqual(2) + expect(dashes).toHaveLength(11) + }) + + it('should show dash for capability columns when capabilities is absent', () => { + render( + + + + ) + + // TextTarget has no capabilities — all 6 should be dashes + const dashes = screen.getAllByText('—') + // model (—) + endpoint (—) + inputs (—) + outputs (—) + 6 capabilities (—) + params (—) = 11 + expect(dashes).toHaveLength(11) + }) + + it('should render modality icons with tooltips for inputs and outputs', () => { + render( + + + + ) + + // Modality tooltips are accessible labels; multiple identical labels can appear + // (e.g. one "Text" for input and one for output). + expect(screen.getAllByLabelText('Text').length).toBeGreaterThanOrEqual(1) + expect(screen.getAllByLabelText('Image').length).toBeGreaterThanOrEqual(1) + }) + + it('should render modality icons in canonical order: text, image, audio, video, reasoning, function_call, tool_call', () => { + const target: TargetInstance = { + target_registry_name: 'multi_modal', + target_type: 'CustomTarget', + endpoint: null, + model_name: null, + capabilities: { + supports_multi_turn: true, + supports_multi_message_pieces: true, + supports_json_schema: false, + supports_json_output: false, + supports_editable_history: false, + supports_system_prompt: false, + // Backend returns alphabetically sorted; UI must reorder. + supported_input_modalities: [ + 'audio_path', + 'function_call', + 'image_path', + 'reasoning', + 'text', + 'tool_call', + 'video_path', + ], + supported_output_modalities: ['text'], + }, + } + render( + + + + ) + + const expectedOrder = ['Text', 'Image', 'Audio', 'Video', 'Reasoning', 'Function call', 'Tool call'] + // The first set of modality icons belongs to the Inputs column. + const labels = expectedOrder.map((label) => screen.getAllByLabelText(label)[0]) + const positions = labels.map((el) => el.compareDocumentPosition(labels[0])) + // Each subsequent label should follow (or be) the first; verify monotonic ordering pairwise. + for (let i = 0; i < labels.length - 1; i += 1) { + const relation = labels[i].compareDocumentPosition(labels[i + 1]) + expect(relation & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy() + } + expect(positions).toBeDefined() }) it('should display target_specific_params when present', () => { diff --git a/frontend/src/components/Config/TargetTable.tsx b/frontend/src/components/Config/TargetTable.tsx index 4b9011e33d..ba7b224d4e 100644 --- a/frontend/src/components/Config/TargetTable.tsx +++ b/frontend/src/components/Config/TargetTable.tsx @@ -1,4 +1,4 @@ -import { useState, useMemo } from 'react' +import { useState, useMemo, forwardRef } from 'react' import { Table, TableHeader, @@ -12,7 +12,21 @@ import { Tooltip, Select, } from '@fluentui/react-components' -import { CheckmarkRegular } from '@fluentui/react-icons' +import { + CheckmarkRegular, + CheckmarkCircleFilled, + DismissCircleFilled, + TextTRegular, + ImageRegular, + MicRegular, + VideoRegular, + DocumentRegular, + LinkRegular, + LightbulbRegular, + MathFormulaRegular, + WrenchRegular, + ArrowHookUpLeftRegular, +} from '@fluentui/react-icons' import type { TargetInstance } from '../../types' import { useTargetTableStyles } from './TargetTable.styles' @@ -39,6 +53,105 @@ function formatParams(params?: Record | null): string { return parts.join('\n') } +/** Capability column definitions with tooltip descriptions. */ +const CAPABILITY_COLUMNS = [ + { key: 'supports_multi_turn', label: 'Multi-turn', tooltip: 'Supports multi-turn conversations' }, + { key: 'supports_multi_message_pieces', label: 'Multi-piece', tooltip: 'Supports multiple message pieces in a single request' }, + { key: 'supports_json_schema', label: 'JSON Schema', tooltip: 'Supports constraining output to a JSON schema' }, + { key: 'supports_json_output', label: 'JSON Output', tooltip: 'Supports JSON output format' }, + { key: 'supports_editable_history', label: 'Edit History', tooltip: 'Allows attack history to be modified' }, + { key: 'supports_system_prompt', label: 'System Prompt', tooltip: 'Supports system prompts' }, +] as const + +const COLUMN_TOOLTIPS = { + type: 'Target class implementation', + model: 'Configured model name. A dotted underline indicates the deployment alias differs from the underlying model — hover the value to see it.', + endpoint: 'API endpoint URL the target sends requests to', + parameters: 'Target-specific configuration parameters (e.g., reasoning_effort, max_output_tokens)', + inputs: 'Modalities the target accepts as input', + outputs: 'Modalities the target can produce as output', +} as const + +/** Composite icon: f(x) with a small return-arrow badge for function call outputs. */ +const FunctionCallOutputIcon = forwardRef & { className?: string }>( + function FunctionCallOutputIcon({ className, ...rest }, ref) { + const styles = useTargetTableStyles() + return ( + + + + + ) + } +) + +/** Modality → (icon, label) for input/output column rendering. The renderer accepts + * arbitrary props so Tooltip can inject event handlers / ARIA attributes. */ +// eslint-disable-next-line @typescript-eslint/no-explicit-any +const MODALITY_RENDERERS: Record; label: string }> = { + text: { Icon: TextTRegular, label: 'Text' }, + image_path: { Icon: ImageRegular, label: 'Image' }, + audio_path: { Icon: MicRegular, label: 'Audio' }, + video_path: { Icon: VideoRegular, label: 'Video' }, + reasoning: { Icon: LightbulbRegular, label: 'Reasoning' }, + function_call: { Icon: MathFormulaRegular, label: 'Function call' }, + function_call_output: { Icon: FunctionCallOutputIcon, label: 'Function call output' }, + tool_call: { Icon: WrenchRegular, label: 'Tool call' }, + binary_path: { Icon: DocumentRegular, label: 'Binary' }, + url: { Icon: LinkRegular, label: 'URL' }, +} + +/** Canonical display order for modality icons; unknown values are appended last. */ +const MODALITY_ORDER: readonly string[] = [ + 'text', + 'image_path', + 'audio_path', + 'video_path', + 'reasoning', + 'function_call', + 'function_call_output', + 'tool_call', + 'binary_path', + 'url', +] + +/** Render a row of modality icons; falls back to "—" when empty. */ +function ModalityCell({ modalities }: { modalities: string[] | undefined }) { + const styles = useTargetTableStyles() + if (!modalities || modalities.length === 0) { + return + } + const ordered = MODALITY_ORDER.filter((m) => modalities.includes(m)) + const extras = modalities.filter((m) => !MODALITY_ORDER.includes(m)) + const sorted = [...ordered, ...extras] + return ( +
+ {sorted.map((modality) => { + const renderer = MODALITY_RENDERERS[modality] + const label = renderer?.label ?? modality + const Icon = renderer?.Icon ?? DocumentRegular + return ( + + + + ) + })} +
+ ) +} + +/** Render a capability indicator: ✓ (green) / ✗ (red) / — (unknown). */ +function CapabilityCell({ value }: { value: boolean | undefined }) { + const styles = useTargetTableStyles() + if (value === undefined) { + return + } + if (value) { + return + } + return +} + /** Render the model cell with a tooltip when underlying model differs. */ function ModelCell({ target }: { target: TargetInstance }) { const displayName = target.model_name || '—' @@ -62,6 +175,22 @@ function ModelCell({ target }: { target: TargetInstance }) { return {displayName} } +/** Render capability cells for a target. */ +function CapabilityCells({ target }: { target: TargetInstance }) { + const styles = useTargetTableStyles() + return ( + <> + {CAPABILITY_COLUMNS.map(({ key }) => ( + + + + ))} + + ) +} + export default function TargetTable({ targets, activeTarget, onSetActiveTarget }: TargetTableProps) { const styles = useTargetTableStyles() const [typeFilter, setTypeFilter] = useState('') @@ -88,18 +217,25 @@ export default function TargetTable({ targets, activeTarget, onSetActiveTarget } }>Active - - {activeTarget.target_type} + + {activeTarget.target_type} - + - + {activeTarget.endpoint || '—'} - + + + + + + + + {formatParams(activeTarget.target_specific_params) || '—'} @@ -126,13 +262,46 @@ export default function TargetTable({ targets, activeTarget, onSetActiveTarget } )} - + - Type - Model - Endpoint - Parameters + + + Type + + + + + Model + + + + + Endpoint + + + + + Inputs + + + + + Outputs + + + {CAPABILITY_COLUMNS.map(({ key, label, tooltip }) => ( + + + {label} + + + ))} + + + Parameters + + @@ -157,7 +326,7 @@ export default function TargetTable({ targets, activeTarget, onSetActiveTarget } )} - {target.target_type} + {target.target_type} @@ -167,6 +336,13 @@ export default function TargetTable({ targets, activeTarget, onSetActiveTarget } {target.endpoint || '—'} + + + + + + + {formatParams(target.target_specific_params) || '—'} diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 07d5865123..f3db442914 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -53,6 +53,17 @@ export interface PaginationInfo { // --- Targets --- +export interface TargetCapabilitiesInfo { + supports_multi_turn: boolean + supports_multi_message_pieces: boolean + supports_json_schema: boolean + supports_json_output: boolean + supports_editable_history: boolean + supports_system_prompt: boolean + supported_input_modalities: string[] + supported_output_modalities: string[] +} + export interface TargetInstance { target_registry_name: string target_type: string @@ -63,6 +74,7 @@ export interface TargetInstance { top_p?: number | null max_requests_per_minute?: number | null supports_multi_turn?: boolean + capabilities?: TargetCapabilitiesInfo | null target_specific_params?: Record | null } diff --git a/pyrit/backend/mappers/target_mappers.py b/pyrit/backend/mappers/target_mappers.py index 1d72822690..e39de13ba8 100644 --- a/pyrit/backend/mappers/target_mappers.py +++ b/pyrit/backend/mappers/target_mappers.py @@ -5,7 +5,7 @@ Target mappers – domain → DTO translation for target-related models. """ -from pyrit.backend.models.targets import TargetInstance +from pyrit.backend.models.targets import TargetCapabilitiesInfo, TargetInstance from pyrit.prompt_target import PromptTarget @@ -27,6 +27,7 @@ def target_object_to_instance(target_registry_name: str, target_obj: PromptTarge params = identifier.params # Keys that are extracted as top-level TargetInstance fields + # or are internal-only (target_configuration is the verbose capabilities blob). extracted_keys = { "endpoint", "model_name", @@ -36,6 +37,7 @@ def target_object_to_instance(target_registry_name: str, target_obj: PromptTarge "max_requests_per_minute", "supports_multi_turn", "target_specific_params", + "target_configuration", } # Collect remaining params as target_specific_params so the frontend can display them @@ -43,6 +45,20 @@ def target_object_to_instance(target_registry_name: str, target_obj: PromptTarge extra = {k: v for k, v in params.items() if k not in extracted_keys and v is not None} combined_specific = {**extra, **explicit_specific} or None + caps = target_obj.capabilities + input_modalities = sorted({modality for combo in caps.input_modalities for modality in combo}) + output_modalities = sorted({modality for combo in caps.output_modalities for modality in combo}) + capabilities = TargetCapabilitiesInfo( + supports_multi_turn=caps.supports_multi_turn, + supports_multi_message_pieces=caps.supports_multi_message_pieces, + supports_json_schema=caps.supports_json_schema, + supports_json_output=caps.supports_json_output, + supports_editable_history=caps.supports_editable_history, + supports_system_prompt=caps.supports_system_prompt, + supported_input_modalities=input_modalities, + supported_output_modalities=output_modalities, + ) + return TargetInstance( target_registry_name=target_registry_name, target_type=identifier.class_name, @@ -52,6 +68,7 @@ def target_object_to_instance(target_registry_name: str, target_obj: PromptTarge temperature=params.get("temperature"), top_p=params.get("top_p"), max_requests_per_minute=params.get("max_requests_per_minute"), - supports_multi_turn=target_obj.capabilities.supports_multi_turn, + supports_multi_turn=caps.supports_multi_turn, + capabilities=capabilities, target_specific_params=combined_specific, ) diff --git a/pyrit/backend/models/targets.py b/pyrit/backend/models/targets.py index fef7cbe41e..e9a28fb89f 100644 --- a/pyrit/backend/models/targets.py +++ b/pyrit/backend/models/targets.py @@ -18,6 +18,31 @@ from pyrit.backend.models.common import PaginationInfo +class TargetCapabilitiesInfo(BaseModel): + """Structured capability flags for a target instance.""" + + supports_multi_turn: bool = Field(..., description="Whether the target supports multi-turn conversation history") + supports_multi_message_pieces: bool = Field( + ..., description="Whether the target supports multiple message pieces in a single request" + ) + supports_json_schema: bool = Field( + ..., description="Whether the target supports constraining output to a JSON schema" + ) + supports_json_output: bool = Field(..., description="Whether the target supports JSON output format") + supports_editable_history: bool = Field( + ..., description="Whether the target allows the attack history to be modified" + ) + supports_system_prompt: bool = Field(..., description="Whether the target supports system prompts") + supported_input_modalities: list[str] = Field( + default_factory=list, + description="Flattened, sorted list of supported input modality data types (e.g., 'text', 'image_path')", + ) + supported_output_modalities: list[str] = Field( + default_factory=list, + description="Flattened, sorted list of supported output modality data types (e.g., 'text', 'audio_path')", + ) + + class TargetInstance(BaseModel): """ A runtime target instance. @@ -37,6 +62,7 @@ class TargetInstance(BaseModel): top_p: Optional[float] = Field(None, description="Top-p parameter for generation") max_requests_per_minute: Optional[int] = Field(None, description="Maximum requests per minute") supports_multi_turn: bool = Field(True, description="Whether the target supports multi-turn conversation history") + capabilities: Optional[TargetCapabilitiesInfo] = Field(None, description="Structured capability flags") target_specific_params: Optional[dict[str, Any]] = Field(None, description="Additional target-specific parameters") diff --git a/tests/unit/backend/test_mappers.py b/tests/unit/backend/test_mappers.py index 9eda90ab5b..c6758bbaf0 100644 --- a/tests/unit/backend/test_mappers.py +++ b/tests/unit/backend/test_mappers.py @@ -1257,6 +1257,114 @@ def test_chat_target_extra_params_preserved(self) -> None: assert result.target_specific_params["seed"] == 42 assert result.target_specific_params["max_completion_tokens"] == 2048 + def test_capabilities_populated_from_target_object(self) -> None: + """Test that all 6 capability fields are populated from target_obj.capabilities.""" + target_obj = MagicMock(spec=PromptTarget) + target_obj.capabilities = TargetCapabilities( + supports_multi_turn=True, + supports_multi_message_pieces=True, + supports_json_schema=False, + supports_json_output=True, + supports_editable_history=False, + supports_system_prompt=True, + ) + mock_identifier = ComponentIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target", + params={"endpoint": "https://api.openai.com", "model_name": "gpt-4"}, + ) + target_obj.get_identifier.return_value = mock_identifier + + result = target_object_to_instance("t-1", target_obj) + + assert result.capabilities is not None + assert result.capabilities.supports_multi_turn is True + assert result.capabilities.supports_multi_message_pieces is True + assert result.capabilities.supports_json_schema is False + assert result.capabilities.supports_json_output is True + assert result.capabilities.supports_editable_history is False + assert result.capabilities.supports_system_prompt is True + + def test_capabilities_modalities_flattened_and_sorted(self) -> None: + """Test that input/output modality combinations are flattened to a sorted list of types.""" + target_obj = MagicMock(spec=PromptTarget) + target_obj.capabilities = TargetCapabilities( + input_modalities=frozenset( + { + frozenset({"text"}), + frozenset({"image_path"}), + frozenset({"text", "image_path"}), + } + ), + output_modalities=frozenset({frozenset({"audio_path", "video_path"})}), + ) + mock_identifier = ComponentIdentifier( + class_name="CustomTarget", + class_module="pyrit.prompt_target", + ) + target_obj.get_identifier.return_value = mock_identifier + + result = target_object_to_instance("t-1", target_obj) + + assert result.capabilities is not None + assert result.capabilities.supported_input_modalities == ["image_path", "text"] + assert result.capabilities.supported_output_modalities == ["audio_path", "video_path"] + + def test_capabilities_default_modalities_are_text(self) -> None: + """Targets that don't override modalities should default to ['text'].""" + target_obj = MagicMock(spec=PromptTarget) + target_obj.capabilities = TargetCapabilities() + mock_identifier = ComponentIdentifier( + class_name="TextTarget", + class_module="pyrit.prompt_target", + ) + target_obj.get_identifier.return_value = mock_identifier + + result = target_object_to_instance("t-1", target_obj) + + assert result.capabilities is not None + assert result.capabilities.supported_input_modalities == ["text"] + assert result.capabilities.supported_output_modalities == ["text"] + + def test_capabilities_matches_legacy_supports_multi_turn(self) -> None: + """Test that legacy supports_multi_turn field matches capabilities.supports_multi_turn.""" + target_obj = MagicMock(spec=PromptTarget) + target_obj.capabilities = TargetCapabilities(supports_multi_turn=False) + mock_identifier = ComponentIdentifier( + class_name="TextTarget", + class_module="pyrit.prompt_target", + ) + target_obj.get_identifier.return_value = mock_identifier + + result = target_object_to_instance("t-1", target_obj) + + assert result.supports_multi_turn is False + assert result.capabilities is not None + assert result.capabilities.supports_multi_turn is False + assert result.supports_multi_turn == result.capabilities.supports_multi_turn + + def test_target_configuration_excluded_from_target_specific_params(self) -> None: + """Test that the verbose target_configuration blob is filtered from target_specific_params.""" + target_obj = MagicMock(spec=PromptTarget) + target_obj.capabilities = TargetCapabilities(supports_multi_turn=True) + mock_identifier = ComponentIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target", + params={ + "endpoint": "https://api.openai.com", + "model_name": "gpt-4", + "target_configuration": {"capabilities": {"supports_multi_turn": True}}, + "reasoning_effort": "high", + }, + ) + target_obj.get_identifier.return_value = mock_identifier + + result = target_object_to_instance("t-1", target_obj) + + assert result.target_specific_params is not None + assert "target_configuration" not in result.target_specific_params + assert result.target_specific_params["reasoning_effort"] == "high" + # ============================================================================ # Converter Mapper Tests