diff --git a/src/strands/agent/a2a_agent.py b/src/strands/agent/a2a_agent.py index eef47e3b4..eeb96f7a2 100644 --- a/src/strands/agent/a2a_agent.py +++ b/src/strands/agent/a2a_agent.py @@ -15,10 +15,14 @@ import httpx from a2a.client import A2ACardResolver, ClientConfig, ClientFactory -from a2a.types import AgentCard, Message, TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent +from a2a.types import AgentCard, Message, TaskArtifactUpdateEvent, TaskStatusUpdateEvent from .._async import run_async -from ..multiagent.a2a._converters import convert_input_to_message, convert_response_to_agent_result +from ..multiagent.a2a._converters import ( + _STATE_TO_STOP_REASON, + convert_input_to_message, + convert_response_to_agent_result, +) from ..types._events import AgentResultEvent from ..types.a2a import A2AResponse, A2AStreamEvent from ..types.agent import AgentInput @@ -29,6 +33,13 @@ _DEFAULT_TIMEOUT = 300 +# A2A task states that indicate the response stream is complete. +# Derived from the canonical _STATE_TO_STOP_REASON mapping in _converters. +# Terminal states (end_turn) mean no more events; input states (interrupt) mean execution is paused. +_TERMINAL_STATES = {state for state, reason in _STATE_TO_STOP_REASON.items() if reason == "end_turn"} +_INPUT_STATES = {state for state, reason in _STATE_TO_STOP_REASON.items() if reason == "interrupt"} +_COMPLETE_STATES = _TERMINAL_STATES | _INPUT_STATES + class A2AAgent(AgentBase): """Client wrapper for remote A2A agents.""" @@ -265,6 +276,9 @@ async def _send_message(self, prompt: AgentInput) -> AsyncIterator[A2AResponse]: def _is_complete_event(self, event: A2AResponse) -> bool: """Check if an A2A event represents a complete response. + Recognizes all terminal states (completed, failed, canceled, rejected) + and pausing states (input_required, auth_required) as complete events. + Args: event: A2A event. @@ -289,9 +303,10 @@ def _is_complete_event(self, event: A2AResponse) -> bool: return update_event.last_chunk return False - # Status update with completed state + # Status update - check for terminal or pausing states if isinstance(update_event, TaskStatusUpdateEvent): if update_event.status and hasattr(update_event.status, "state"): - return update_event.status.state == TaskState.completed + state = update_event.status.state + return state in _COMPLETE_STATES return False diff --git a/src/strands/multiagent/a2a/_converters.py b/src/strands/multiagent/a2a/_converters.py index 22c2ffb72..7808ae325 100644 --- a/src/strands/multiagent/a2a/_converters.py +++ b/src/strands/multiagent/a2a/_converters.py @@ -4,13 +4,24 @@ from uuid import uuid4 from a2a.types import Message as A2AMessage -from a2a.types import Part, Role, TaskArtifactUpdateEvent, TaskStatusUpdateEvent, TextPart +from a2a.types import Part, Role, TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent, TextPart from ...agent.agent_result import AgentResult from ...telemetry.metrics import EventLoopMetrics from ...types.a2a import A2AResponse from ...types.agent import AgentInput from ...types.content import ContentBlock, Message +from ...types.event_loop import StopReason + +# Mapping from A2A TaskState to Strands stop_reason +_STATE_TO_STOP_REASON: dict[TaskState, StopReason] = { + TaskState.completed: "end_turn", + TaskState.failed: "end_turn", + TaskState.canceled: "end_turn", + TaskState.rejected: "end_turn", + TaskState.input_required: "interrupt", + TaskState.auth_required: "interrupt", +} def convert_input_to_message(prompt: AgentInput) -> A2AMessage: @@ -79,9 +90,34 @@ def convert_content_blocks_to_parts(content_blocks: list[ContentBlock]) -> list[ return parts +def _extract_task_state(response: A2AResponse) -> TaskState | None: + """Extract the task state from an A2A response. + + Args: + response: A2A response (either A2AMessage or tuple of task and update event). + + Returns: + The TaskState if available, None otherwise. + """ + if isinstance(response, tuple) and len(response) == 2: + _task, update_event = response + if isinstance(update_event, TaskStatusUpdateEvent): + if update_event.status and hasattr(update_event.status, "state"): + return update_event.status.state + return None + + def convert_response_to_agent_result(response: A2AResponse) -> AgentResult: """Convert A2A response to AgentResult. + Maps A2A task lifecycle states to appropriate Strands stop_reasons: + - completed → end_turn + - failed → end_turn (with error content) + - canceled → end_turn (with cancellation info) + - rejected → end_turn (with rejection info) + - input_required → interrupt (agent needs user input) + - auth_required → interrupt (agent needs authentication) + Args: response: A2A response (either A2AMessage or tuple of task and update event). @@ -89,19 +125,26 @@ def convert_response_to_agent_result(response: A2AResponse) -> AgentResult: AgentResult with extracted content and metadata. """ content: list[ContentBlock] = [] + task_state = _extract_task_state(response) + stop_reason: StopReason = _STATE_TO_STOP_REASON.get(task_state, "end_turn") if task_state else "end_turn" if isinstance(response, tuple) and len(response) == 2: task, update_event = response # Handle artifact updates if isinstance(update_event, TaskArtifactUpdateEvent): - if update_event.artifact and hasattr(update_event.artifact, "parts"): + if update_event.artifact and hasattr(update_event.artifact, "parts") and update_event.artifact.parts: for part in update_event.artifact.parts: if hasattr(part, "root") and hasattr(part.root, "text"): content.append({"text": part.root.text}) # Handle status updates with messages elif isinstance(update_event, TaskStatusUpdateEvent): - if update_event.status and hasattr(update_event.status, "message") and update_event.status.message: + if ( + update_event.status + and hasattr(update_event.status, "message") + and update_event.status.message + and update_event.status.message.parts + ): for part in update_event.status.message.parts: if hasattr(part, "root") and hasattr(part.root, "text"): content.append({"text": part.root.text}) @@ -109,7 +152,7 @@ def convert_response_to_agent_result(response: A2AResponse) -> AgentResult: # Use task.artifacts when no content was extracted from the event if not content and task and hasattr(task, "artifacts") and task.artifacts is not None: for artifact in task.artifacts: - if hasattr(artifact, "parts"): + if hasattr(artifact, "parts") and artifact.parts: for part in artifact.parts: if hasattr(part, "root") and hasattr(part.root, "text"): content.append({"text": part.root.text}) @@ -123,9 +166,14 @@ def convert_response_to_agent_result(response: A2AResponse) -> AgentResult: "content": content, } + # Build state dict with A2A metadata + state: dict[str, str] = {} + if task_state is not None: + state["a2a_task_state"] = task_state.value + return AgentResult( - stop_reason="end_turn", + stop_reason=stop_reason, message=message, metrics=EventLoopMetrics(), - state={}, + state=state, ) diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py index c8c00600b..7526386e8 100644 --- a/src/strands/multiagent/a2a/executor.py +++ b/src/strands/multiagent/a2a/executor.py @@ -8,6 +8,7 @@ streamed requests to the A2AServer. """ +import asyncio import base64 import json import logging @@ -42,7 +43,9 @@ class StrandsA2AExecutor(AgentExecutor): """Executor that adapts a Strands Agent to the A2A protocol. This executor uses streaming mode to handle the execution of agent requests - and converts Strands Agent responses to A2A protocol events. + and converts Strands Agent responses to A2A protocol events. It supports the + full A2A task lifecycle including error handling (failed state), cancellation, + and interrupt-based input_required flows. """ # Default formats for each file type when MIME type is unavailable or unrecognized @@ -75,14 +78,18 @@ async def execute( """Execute a request using the Strands Agent and send the response as A2A events. This method executes the user's input using the Strands Agent in streaming mode - and converts the agent's response to A2A events. + and converts the agent's response to A2A events. If the agent raises an exception, + the task transitions to the `failed` state. If the agent returns with interrupts, + the task transitions to the `input_required` state. Args: context: The A2A request context, containing the user's input and task metadata. event_queue: The A2A event queue used to send response events back to the client. Raises: - ServerError: If an error occurs during agent execution + ServerError: If an unrecoverable error occurs during agent execution setup + (e.g., missing input). Agent execution errors are handled gracefully + by transitioning the task to the failed state. """ task = context.current_task if not task: @@ -93,8 +100,34 @@ async def execute( try: await self._execute_streaming(context, updater) - except Exception as e: - raise ServerError(error=InternalError()) from e + except ServerError: + # Re-raise ServerErrors (setup failures like missing input) + raise + except asyncio.CancelledError: + # asyncio.CancelledError is a BaseException (not Exception) — raised when + # the asyncio task is cancelled (e.g., HTTP client disconnect, server shutdown). + # We transition to canceled state so the task doesn't remain a zombie in "working". + logger.warning("task_id=<%s> | asyncio task cancelled, transitioning to canceled state", task.id) + try: + await updater.cancel( + message=updater.new_agent_message( + parts=[Part(root=TextPart(text="Task cancelled due to connection termination"))] + ) + ) + except RuntimeError: + # Task already in terminal state + logger.debug("task_id=<%s> | task already in terminal state, cannot transition to canceled", task.id) + raise + except Exception: + # Agent execution failures transition to failed state + logger.exception("task_id=<%s> | agent execution failed, transitioning to failed state", task.id) + try: + await updater.failed( + message=updater.new_agent_message(parts=[Part(root=TextPart(text="Agent execution failed"))]) + ) + except RuntimeError: + # Task already in terminal state (e.g., completed before error in cleanup) + logger.debug("task_id=<%s> | task already in terminal state, cannot transition to failed", task.id) async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater) -> None: """Execute request in streaming mode. @@ -105,14 +138,19 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater Args: context: The A2A request context, containing the user's input and other metadata. updater: The task updater for managing task state and sending updates. + + Raises: + ServerError: If input conversion fails (missing or empty content). """ # Convert A2A message parts to Strands ContentBlocks if context.message and hasattr(context.message, "parts"): content_blocks = self._convert_a2a_parts_to_content_blocks(context.message.parts) if not content_blocks: - raise ValueError("No content blocks available") + raise ServerError( + error=InternalError(message="No valid content found in request message parts") + ) from None else: - raise ValueError("No content blocks available") + raise ServerError(error=InternalError(message="Request message is missing or has no parts")) from None if not self.enable_a2a_compliant_streaming: warnings.warn( @@ -133,8 +171,20 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater invocation_state: dict[str, Any] = {"a2a_request_context": context} try: + result: SAAgentResult | None = None async for event in self.agent.stream_async(content_blocks, invocation_state=invocation_state): - await self._handle_streaming_event(event, updater) + if "result" in event: + result = event["result"] + else: + await self._handle_streaming_event(event, updater) + + # Check if agent returned with interrupts (input_required) + # Note: stop_reason="interrupt" is the authoritative signal. Even if interrupts + # list is empty (edge case), the agent still indicated it needs input. + if result is not None and result.stop_reason == "interrupt": + await self._handle_interrupt_result(result, updater) + else: + await self._handle_agent_result(result, updater) except Exception: logger.exception("Error in streaming execution") raise @@ -143,6 +193,34 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater self._current_artifact_id = None self._is_first_chunk = True + async def _handle_interrupt_result(self, result: SAAgentResult, updater: TaskUpdater) -> None: + """Handle an agent result that contains interrupts. + + When the Strands Agent returns with stop_reason="interrupt", this maps to + the A2A `input_required` state. The interrupt details are communicated to + the client via the status message. + + Args: + result: The agent result containing interrupts. + updater: The task updater for managing task state. + """ + # Build a descriptive message about what input is needed + interrupt_descriptions = [] + for interrupt in result.interrupts or []: + desc = f"- {interrupt.name}" + if interrupt.reason: + desc += f": {interrupt.reason}" + interrupt_descriptions.append(desc) + + if interrupt_descriptions: + input_message = "Agent requires input:\n" + "\n".join(interrupt_descriptions) + else: + # Edge case: stop_reason="interrupt" but no interrupt details provided. + # Still transition to input_required — the agent signaled it needs input. + input_message = "Agent requires additional input to continue" + + await updater.requires_input(message=updater.new_agent_message(parts=[Part(root=TextPart(text=input_message))])) + async def _handle_streaming_event(self, event: dict[str, Any], updater: TaskUpdater) -> None: """Handle a single streaming event from the Strands Agent. @@ -175,8 +253,6 @@ async def _handle_streaming_event(self, event: dict[str, Any], updater: TaskUpda updater.task_id, ), ) - elif "result" in event: - await self._handle_agent_result(event["result"], updater) async def _handle_agent_result(self, result: SAAgentResult | None, updater: TaskUpdater) -> None: """Handle the final result from the Strands Agent. @@ -219,20 +295,42 @@ async def _handle_agent_result(self, result: SAAgentResult | None, updater: Task async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None: """Cancel an ongoing execution. - This method is called when a request cancellation is requested. Currently, - cancellation is not supported by the Strands Agent executor, so this method - always raises an UnsupportedOperationError. + Transitions the task to the canceled state and attempts to stop the agent. + The agent's cancel() method is called to signal cooperative cancellation + of in-flight execution. + + Note: This transitions the A2A task state. The underlying agent execution + may still complete its current model call before stopping. Args: context: The A2A request context. event_queue: The A2A event queue. Raises: - ServerError: Always raised with an UnsupportedOperationError, as cancellation - is not currently supported. + ServerError: If no current task exists or the task is already in a terminal state. """ - logger.warning("Cancellation requested but not supported") - raise ServerError(error=UnsupportedOperationError()) + task = context.current_task + if not task: + logger.warning("context_id=<%s> | cancel requested but no current task found", context.context_id) + raise ServerError(error=UnsupportedOperationError()) from None + + # Cooperatively cancel the agent's execution (best-effort). + # Agent.cancel() is always available since self.agent is typed as Agent. + try: + self.agent.cancel() + except Exception: + logger.debug("task_id=<%s> | agent cancel signal failed (non-critical)", task.id) + + updater = TaskUpdater(event_queue, task.id, task.context_id) + + try: + await updater.cancel( + message=updater.new_agent_message(parts=[Part(root=TextPart(text="Task cancelled by client request"))]) + ) + except RuntimeError: + # TaskUpdater raises RuntimeError when task is already in a terminal state + logger.warning("task_id=<%s> | cannot cancel, already in terminal state", task.id) + raise ServerError(error=UnsupportedOperationError()) from None def _get_file_type_from_mime_type(self, mime_type: str | None) -> Literal["document", "image", "video", "unknown"]: """Classify file type based on MIME type. diff --git a/tests/strands/agent/test_a2a_agent.py b/tests/strands/agent/test_a2a_agent.py index d918033e5..9c3be7917 100644 --- a/tests/strands/agent/test_a2a_agent.py +++ b/tests/strands/agent/test_a2a_agent.py @@ -7,7 +7,7 @@ import pytest from a2a.client import ClientConfig -from a2a.types import AgentCard, Message, Part, Role, TextPart +from a2a.types import AgentCard, Message, Part, Role, TaskState, TextPart from strands.agent.a2a_agent import A2AAgent from strands.agent.agent_result import AgentResult @@ -714,3 +714,163 @@ async def mock_send_message(*args, **kwargs): # Should have 1 stream event + 1 result event (falls back to last) assert len(events) == 2 assert "result" in events[1] + + +# ========================================================================= +# NEW TESTS: Client-side lifecycle state handling +# ========================================================================= + + +def test_is_complete_event_failed_state(a2a_agent): + """Test that failed state is recognized as complete.""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatusUpdateEvent + + task = MagicMock() + status = MagicMock() + status.state = TaskState.failed + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is True + + +def test_is_complete_event_canceled_state(a2a_agent): + """Test that canceled state is recognized as complete.""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatusUpdateEvent + + task = MagicMock() + status = MagicMock() + status.state = TaskState.canceled + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is True + + +def test_is_complete_event_rejected_state(a2a_agent): + """Test that rejected state is recognized as complete.""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatusUpdateEvent + + task = MagicMock() + status = MagicMock() + status.state = TaskState.rejected + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is True + + +def test_is_complete_event_input_required_state(a2a_agent): + """Test that input_required state is recognized as complete (pausing).""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatusUpdateEvent + + task = MagicMock() + status = MagicMock() + status.state = TaskState.input_required + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is True + + +def test_is_complete_event_auth_required_state(a2a_agent): + """Test that auth_required state is recognized as complete (pausing).""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatusUpdateEvent + + task = MagicMock() + status = MagicMock() + status.state = TaskState.auth_required + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is True + + +def test_is_complete_event_working_state_not_complete(a2a_agent): + """Test that working state is NOT recognized as complete.""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatusUpdateEvent + + task = MagicMock() + status = MagicMock() + status.state = TaskState.working + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is False + + +def test_is_complete_event_submitted_state_not_complete(a2a_agent): + """Test that submitted state is NOT recognized as complete.""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatusUpdateEvent + + task = MagicMock() + status = MagicMock() + status.state = TaskState.submitted + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is False + + +# ========================================================================= +# DEVIL'S ADVOCATE FINDINGS — Tests addressing review gaps +# ========================================================================= + + +@pytest.mark.parametrize( + "state,expected_complete", + [ + (TaskState.completed, True), + (TaskState.failed, True), + (TaskState.canceled, True), + (TaskState.rejected, True), + (TaskState.input_required, True), + (TaskState.auth_required, True), + (TaskState.working, False), + (TaskState.submitted, False), + (TaskState.unknown, False), + ], + ids=[ + "completed-is-complete", + "failed-is-complete", + "canceled-is-complete", + "rejected-is-complete", + "input_required-is-complete", + "auth_required-is-complete", + "working-not-complete", + "submitted-not-complete", + "unknown-not-complete", + ], +) +def test_is_complete_event_all_states_parametrized(a2a_agent, state, expected_complete): + """Minor Finding 7: Parametrized test covering ALL TaskState values. + + This replaces verbose individual tests with a single parameterized test that + covers all 9 TaskState values. When a2a-sdk adds new states, adding a row here + is trivial. + """ + from unittest.mock import MagicMock + + from a2a.types import TaskStatusUpdateEvent + + task = MagicMock() + status = MagicMock() + status.state = state + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is expected_complete diff --git a/tests/strands/multiagent/a2a/test_converters.py b/tests/strands/multiagent/a2a/test_converters.py index c3b310065..fff48653b 100644 --- a/tests/strands/multiagent/a2a/test_converters.py +++ b/tests/strands/multiagent/a2a/test_converters.py @@ -243,3 +243,286 @@ def test_convert_response_handles_missing_data(): mock_task.artifacts = [mock_artifact] result = convert_response_to_agent_result((mock_task, None)) assert len(result.message["content"]) == 0 + + +# ========================================================================= +# NEW TESTS: Lifecycle State Mapping +# ========================================================================= + + +def test_convert_response_completed_state_maps_to_end_turn(): + """Test that completed state maps to end_turn stop_reason.""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + status = TaskStatus(state=TaskState.completed, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "end_turn" + + +def test_convert_response_failed_state_maps_to_end_turn(): + """Test that failed state maps to end_turn stop_reason with error content.""" + from unittest.mock import MagicMock + + from a2a.types import Message, TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + # Create a status message with error info + error_part = MagicMock() + error_part.root = MagicMock() + error_part.root.text = "Agent execution failed: timeout" + + error_message = MagicMock(spec=Message) + error_message.parts = [error_part] + + status = TaskStatus(state=TaskState.failed, message=error_message) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "end_turn" + assert result.state.get("a2a_task_state") == "failed" + assert "Agent execution failed" in result.message["content"][0]["text"] + + +def test_convert_response_input_required_maps_to_interrupt(): + """Test that input_required state maps to interrupt stop_reason.""" + from unittest.mock import MagicMock + + from a2a.types import Message, TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + input_part = MagicMock() + input_part.root = MagicMock() + input_part.root.text = "Agent requires input:\n- approval: Need confirmation" + + input_message = MagicMock(spec=Message) + input_message.parts = [input_part] + + status = TaskStatus(state=TaskState.input_required, message=input_message) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "interrupt" + assert result.state.get("a2a_task_state") == "input-required" + assert "approval" in result.message["content"][0]["text"] + + +def test_convert_response_canceled_state_maps_to_end_turn(): + """Test that canceled state maps to end_turn stop_reason.""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + status = TaskStatus(state=TaskState.canceled, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "end_turn" + assert result.state.get("a2a_task_state") == "canceled" + + +def test_convert_response_rejected_state_maps_to_end_turn(): + """Test that rejected state maps to end_turn stop_reason.""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + status = TaskStatus(state=TaskState.rejected, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "end_turn" + assert result.state.get("a2a_task_state") == "rejected" + + +def test_convert_response_auth_required_maps_to_interrupt(): + """Test that auth_required state maps to interrupt stop_reason.""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + status = TaskStatus(state=TaskState.auth_required, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "interrupt" + assert result.state.get("a2a_task_state") == "auth-required" + + +def test_extract_task_state_from_status_update(): + """Test _extract_task_state helper.""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent + + from strands.multiagent.a2a._converters import _extract_task_state + + task = MagicMock() + status = TaskStatus(state=TaskState.failed, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + state = _extract_task_state((task, update_event)) + assert state == TaskState.failed + + +def test_extract_task_state_from_message_returns_none(): + """Test _extract_task_state returns None for Message responses.""" + from unittest.mock import MagicMock + + from a2a.types import Message + + from strands.multiagent.a2a._converters import _extract_task_state + + message = MagicMock(spec=Message) + state = _extract_task_state(message) + assert state is None + + +# ========================================================================= +# DEVIL'S ADVOCATE FINDINGS — Tests addressing review gaps +# ========================================================================= + + +def test_convert_response_completed_state_includes_state_metadata(): + """Major Finding 3: The completed state test was missing state assertion. + + Every other state test asserts both stop_reason AND result.state, but the most + important one (completed — the happy path) was missing the state check. This ensures + downstream consumers relying on result.state["a2a_task_state"] won't break silently. + """ + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + status = TaskStatus(state=TaskState.completed, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "end_turn" + assert result.state.get("a2a_task_state") == "completed" # THIS WAS MISSING + + +def test_convert_response_unknown_state_defaults_to_end_turn(): + """Major Finding 4: TaskState.unknown should default to end_turn. + + The a2a-sdk has a TaskState.unknown value. Our code handles it via the .get() + default ("end_turn"). This test documents that this is an intentional design + decision: unknown states are treated as terminal completions rather than errors. + + Rationale: An unknown state from a remote server is ambiguous. Treating it as + end_turn (completed) is the safest default — the client won't hang waiting for + more events, and the result content (if any) is still accessible. + """ + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + status = TaskStatus(state=TaskState.unknown, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + # unknown is NOT in _STATE_TO_STOP_REASON, so defaults to "end_turn" + assert result.stop_reason == "end_turn" + # state metadata should reflect the actual state value + assert result.state.get("a2a_task_state") == "unknown" + + +def test_convert_response_working_state_defaults_to_end_turn(): + """Test that working state (not in mapping) defaults to end_turn. + + This covers the edge case where a TaskStatusUpdateEvent with state=working + somehow reaches the converter (shouldn't normally happen since _is_complete_event + filters these out, but defense-in-depth). + """ + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + status = TaskStatus(state=TaskState.working, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "end_turn" + assert result.state.get("a2a_task_state") == "working" + + +def test_extract_task_state_from_artifact_update_returns_none(): + """Minor Finding 5: _extract_task_state with TaskArtifactUpdateEvent returns None. + + This is the untested path where the update event is an artifact (not status). + """ + from unittest.mock import MagicMock + + from a2a.types import TaskArtifactUpdateEvent + + from strands.multiagent.a2a._converters import _extract_task_state + + task = MagicMock() + mock_event = MagicMock(spec=TaskArtifactUpdateEvent) + + state = _extract_task_state((task, mock_event)) + assert state is None + + +def test_state_to_stop_reason_covers_all_lifecycle_states(): + """Verify _STATE_TO_STOP_REASON has mappings for all documented lifecycle states. + + Guards against future additions to the a2a-sdk that we miss. + """ + from a2a.types import TaskState + + from strands.multiagent.a2a._converters import _STATE_TO_STOP_REASON + + # These are the states we explicitly handle + expected_mapped = { + TaskState.completed, + TaskState.failed, + TaskState.canceled, + TaskState.rejected, + TaskState.input_required, + TaskState.auth_required, + } + assert set(_STATE_TO_STOP_REASON.keys()) == expected_mapped + + # These should NOT be in the mapping (they're non-terminal progress states) + assert TaskState.working not in _STATE_TO_STOP_REASON + assert TaskState.submitted not in _STATE_TO_STOP_REASON + assert TaskState.unknown not in _STATE_TO_STOP_REASON diff --git a/tests/strands/multiagent/a2a/test_executor.py b/tests/strands/multiagent/a2a/test_executor.py index dc90fbdd6..940d26f8c 100644 --- a/tests/strands/multiagent/a2a/test_executor.py +++ b/tests/strands/multiagent/a2a/test_executor.py @@ -583,7 +583,7 @@ async def mock_stream(content_blocks): async def test_execute_streaming_mode_handles_agent_exception( mock_strands_agent, mock_request_context, mock_event_queue ): - """Test that execute handles agent exceptions correctly in streaming mode.""" + """Test that execute transitions to failed state when agent raises exception.""" # Setup mock agent to raise exception when stream_async is called mock_strands_agent.stream_async = MagicMock(side_effect=Exception("Agent error")) @@ -608,18 +608,25 @@ async def test_execute_streaming_mode_handles_agent_exception( mock_message.parts = [part] mock_request_context.message = mock_message - with pytest.raises(ServerError): - await executor.execute(mock_request_context, mock_event_queue) + # Should NOT raise - instead transitions to failed state + await executor.execute(mock_request_context, mock_event_queue) # Verify agent was called mock_strands_agent.stream_async.assert_called_once() + # Verify a failed status event was enqueued + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + from a2a.types import TaskState, TaskStatusUpdateEvent -@pytest.mark.asyncio -async def test_cancel_raises_unsupported_operation_error(mock_strands_agent, mock_request_context, mock_event_queue): - """Test that cancel raises UnsupportedOperationError.""" + failed_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.failed + ] + assert len(failed_events) == 1 + assert "Agent execution failed" in failed_events[0].status.message.parts[0].root.text executor = StrandsA2AExecutor(mock_strands_agent) + # Cancel with no current_task raises UnsupportedOperationError + mock_request_context.current_task = None with pytest.raises(ServerError) as excinfo: await executor.cancel(mock_request_context, mock_event_queue) @@ -1331,3 +1338,606 @@ async def test_invocation_state_with_a2a_compliant_streaming( assert invocation_state is not None assert invocation_state["a2a_request_context"] is mock_request_context + + +# ========================================================================= +# NEW TESTS: A2A Lifecycle State Support +# ========================================================================= + + +@pytest.mark.asyncio +async def test_execute_transitions_to_failed_on_streaming_error( + mock_strands_agent, mock_request_context, mock_event_queue +): + """Test that errors during streaming transition task to failed state.""" + from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + + async def mock_stream(content_blocks, **kwargs): + """Mock streaming that raises mid-stream.""" + yield {"data": "partial output"} + raise RuntimeError("Connection lost") + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-fail" + mock_task.context_id = "ctx-fail" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "test" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + # Should not raise + await executor.execute(mock_request_context, mock_event_queue) + + # Verify failed state was enqueued + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + failed_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.failed + ] + assert len(failed_events) == 1 + assert "Agent execution failed" in failed_events[0].status.message.parts[0].root.text + + +@pytest.mark.asyncio +async def test_cancel_with_valid_task(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that cancel transitions task to canceled state when task exists.""" + from a2a.types import TaskState, TaskStatusUpdateEvent + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-cancel" + mock_task.context_id = "ctx-cancel" + mock_request_context.current_task = mock_task + + await executor.cancel(mock_request_context, mock_event_queue) + + # Verify canceled state was enqueued + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + canceled_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.canceled + ] + assert len(canceled_events) == 1 + assert "cancelled" in canceled_events[0].status.message.parts[0].root.text.lower() + + +@pytest.mark.asyncio +async def test_cancel_without_task_raises_unsupported(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that cancel raises UnsupportedOperationError when no task exists.""" + executor = StrandsA2AExecutor(mock_strands_agent) + mock_request_context.current_task = None + + with pytest.raises(ServerError) as excinfo: + await executor.cancel(mock_request_context, mock_event_queue) + + assert isinstance(excinfo.value.error, UnsupportedOperationError) + + +@pytest.mark.asyncio +async def test_execute_with_interrupt_transitions_to_input_required( + mock_strands_agent, mock_request_context, mock_event_queue +): + """Test that agent interrupts map to input_required state.""" + from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + + from strands.interrupt import Interrupt + + # Create a mock result with interrupts + mock_result = MagicMock(spec=SAAgentResult) + mock_result.stop_reason = "interrupt" + mock_interrupt = Interrupt(id="int-1", name="approval", reason="Need user approval") + mock_result.interrupts = [mock_interrupt] + + async def mock_stream(content_blocks, **kwargs): + yield {"data": "Processing..."} + yield {"result": mock_result} + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-interrupt" + mock_task.context_id = "ctx-interrupt" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "delete file X" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + await executor.execute(mock_request_context, mock_event_queue) + + # Verify input_required state was enqueued + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + input_required_events = [ + e + for e in enqueued_events + if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.input_required + ] + assert len(input_required_events) == 1 + msg_text = input_required_events[0].status.message.parts[0].root.text + assert "approval" in msg_text + assert "Need user approval" in msg_text + + +@pytest.mark.asyncio +async def test_execute_with_multiple_interrupts(mock_strands_agent, mock_request_context, mock_event_queue): + """Test handling of multiple interrupts in a single result.""" + from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + + from strands.interrupt import Interrupt + + mock_result = MagicMock(spec=SAAgentResult) + mock_result.stop_reason = "interrupt" + mock_result.interrupts = [ + Interrupt(id="int-1", name="confirm_delete", reason="Confirm deletion of file X"), + Interrupt(id="int-2", name="select_backup", reason="Choose backup location"), + ] + + async def mock_stream(content_blocks, **kwargs): + yield {"result": mock_result} + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-multi-int" + mock_task.context_id = "ctx-multi-int" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "delete with backup" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + await executor.execute(mock_request_context, mock_event_queue) + + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + input_required_events = [ + e + for e in enqueued_events + if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.input_required + ] + assert len(input_required_events) == 1 + msg_text = input_required_events[0].status.message.parts[0].root.text + assert "confirm_delete" in msg_text + assert "select_backup" in msg_text + assert "Confirm deletion of file X" in msg_text + assert "Choose backup location" in msg_text + + +@pytest.mark.asyncio +async def test_execute_normal_completion_no_interrupts(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that normal completion (no interrupts) still works as before.""" + from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + + mock_result = MagicMock(spec=SAAgentResult) + mock_result.stop_reason = "end_turn" + mock_result.interrupts = None + mock_result.__str__ = MagicMock(return_value="Task completed successfully") + + async def mock_stream(content_blocks, **kwargs): + yield {"data": "Working..."} + yield {"result": mock_result} + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-normal" + mock_task.context_id = "ctx-normal" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "do something" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + await executor.execute(mock_request_context, mock_event_queue) + + # Verify completed state was enqueued (not input_required) + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + completed_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.completed + ] + assert len(completed_events) == 1 + + # Verify no input_required events + input_required_events = [ + e + for e in enqueued_events + if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.input_required + ] + assert len(input_required_events) == 0 + + +@pytest.mark.asyncio +async def test_execute_setup_failure_raises_server_error(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that setup failures (missing message) still raise ServerError.""" + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-setup-fail" + mock_task.context_id = "ctx-setup-fail" + mock_request_context.current_task = mock_task + + # No message at all + mock_request_context.message = None + + with pytest.raises(ServerError) as excinfo: + await executor.execute(mock_request_context, mock_event_queue) + + assert isinstance(excinfo.value.error, InternalError) + + +@pytest.mark.asyncio +async def test_execute_error_when_task_already_terminal(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that error during execution is handled gracefully when task is already in terminal state.""" + from a2a.types import TextPart + + # Make stream_async raise to trigger the error path + mock_strands_agent.stream_async = MagicMock(side_effect=Exception("Agent error")) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-already-done" + mock_task.context_id = "ctx-already-done" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "test" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + # Patch TaskUpdater.failed to raise RuntimeError (simulating task already in terminal state) + with patch("strands.multiagent.a2a.executor.TaskUpdater") as MockTaskUpdater: + mock_updater = MagicMock() + mock_updater.failed = AsyncMock(side_effect=RuntimeError("Task is already in a terminal state")) + mock_updater.new_agent_message = MagicMock(return_value=MagicMock()) + MockTaskUpdater.return_value = mock_updater + + # Should NOT raise - handles RuntimeError gracefully + await executor.execute(mock_request_context, mock_event_queue) + + # Verify failed() was attempted + mock_updater.failed.assert_called_once() + + +@pytest.mark.asyncio +async def test_cancel_calls_agent_cancel_method(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that cancel() attempts to call agent.cancel() if available.""" + from a2a.types import TaskState, TaskStatusUpdateEvent + + # Give the agent a cancel method + mock_strands_agent.cancel = MagicMock() + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-cancel-agent" + mock_task.context_id = "ctx-cancel-agent" + mock_request_context.current_task = mock_task + + await executor.cancel(mock_request_context, mock_event_queue) + + # Verify agent.cancel() was called + mock_strands_agent.cancel.assert_called_once() + + # Verify task state is canceled + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + canceled_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.canceled + ] + assert len(canceled_events) == 1 + + +@pytest.mark.asyncio +async def test_cancel_handles_agent_cancel_exception(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that cancel() gracefully handles agent.cancel() raising an exception.""" + from a2a.types import TaskState, TaskStatusUpdateEvent + + # Give the agent a cancel method that raises + mock_strands_agent.cancel = MagicMock(side_effect=RuntimeError("Cannot cancel")) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-cancel-err" + mock_task.context_id = "ctx-cancel-err" + mock_request_context.current_task = mock_task + + # Should still succeed (agent cancel is best-effort) + await executor.cancel(mock_request_context, mock_event_queue) + + # Task should still be transitioned to canceled + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + canceled_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.canceled + ] + assert len(canceled_events) == 1 + + +@pytest.mark.asyncio +async def test_cancel_raises_when_task_already_terminal(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that cancel() raises ServerError when task is already in a terminal state.""" + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-terminal" + mock_task.context_id = "ctx-terminal" + mock_request_context.current_task = mock_task + + # Patch TaskUpdater.cancel to raise RuntimeError (task already completed/failed) + with patch("strands.multiagent.a2a.executor.TaskUpdater") as MockTaskUpdater: + mock_updater = MagicMock() + mock_updater.cancel = AsyncMock(side_effect=RuntimeError("Task is already in a terminal state")) + mock_updater.new_agent_message = MagicMock(return_value=MagicMock()) + MockTaskUpdater.return_value = mock_updater + + with pytest.raises(ServerError) as excinfo: + await executor.cancel(mock_request_context, mock_event_queue) + + assert isinstance(excinfo.value.error, UnsupportedOperationError) + mock_updater.cancel.assert_called_once() + + +# ========================================================================= +# DEVIL'S ADVOCATE FINDINGS — Tests addressing review gaps +# ========================================================================= + + +@pytest.mark.asyncio +async def test_execute_handles_asyncio_cancelled_error(mock_strands_agent, mock_request_context, mock_event_queue): + """Critical Finding 1: asyncio.CancelledError transitions task to canceled state. + + asyncio.CancelledError is a BaseException (not Exception). It's raised when an asyncio + task is cancelled — e.g., HTTP client disconnect, server shutdown, task group cancellation. + Without explicit handling, the task would remain stuck in 'working' state forever (zombie). + + This test verifies the task transitions to 'canceled' before re-raising CancelledError. + """ + import asyncio + + from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + + async def mock_stream(content_blocks, **kwargs): + """Mock streaming that gets cancelled mid-stream.""" + yield {"data": "partial output"} + raise asyncio.CancelledError() + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-cancelled" + mock_task.context_id = "ctx-cancelled" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "test" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + # CancelledError should be re-raised (framework needs to know task was cancelled) + with pytest.raises(asyncio.CancelledError): + await executor.execute(mock_request_context, mock_event_queue) + + # But BEFORE re-raising, the task should have been transitioned to canceled + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + canceled_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.canceled + ] + assert len(canceled_events) == 1 + assert ( + "cancelled" in canceled_events[0].status.message.parts[0].root.text.lower() + or "connection termination" in canceled_events[0].status.message.parts[0].root.text.lower() + ) + + +@pytest.mark.asyncio +async def test_execute_asyncio_cancelled_when_task_already_terminal( + mock_strands_agent, mock_request_context, mock_event_queue +): + """Test CancelledError handling when task is already in a terminal state. + + If the task completed right before the cancellation arrives, the updater.cancel() + will raise RuntimeError. We should handle this gracefully and still re-raise CancelledError. + """ + import asyncio + + from a2a.types import TextPart + + async def mock_stream(content_blocks, **kwargs): + """Async generator that immediately raises CancelledError.""" + yield {"data": "partial"} # Must yield to be async generator + raise asyncio.CancelledError() + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-cancelled-terminal" + mock_task.context_id = "ctx-cancelled-terminal" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "test" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + # Patch TaskUpdater to simulate task already in terminal state + with patch("strands.multiagent.a2a.executor.TaskUpdater") as MockTaskUpdater: + mock_updater = MagicMock() + mock_updater.cancel = AsyncMock(side_effect=RuntimeError("Task is already in a terminal state")) + mock_updater.update_status = AsyncMock() + mock_updater.add_artifact = AsyncMock() + mock_updater.new_agent_message = MagicMock(return_value=MagicMock()) + mock_updater.context_id = "ctx-cancelled-terminal" + mock_updater.task_id = "task-cancelled-terminal" + MockTaskUpdater.return_value = mock_updater + + # Should still re-raise CancelledError + with pytest.raises(asyncio.CancelledError): + await executor.execute(mock_request_context, mock_event_queue) + + # cancel() was attempted + mock_updater.cancel.assert_called_once() + + +@pytest.mark.asyncio +async def test_execute_with_interrupt_empty_list_transitions_to_input_required( + mock_strands_agent, mock_request_context, mock_event_queue +): + """Critical Finding 2: stop_reason='interrupt' with empty interrupts list. + + The agent explicitly signaled it needs input (stop_reason="interrupt") but provided + no interrupt details. This should STILL transition to input_required — the stop_reason + is the authoritative signal. Previously this would silently complete the task. + """ + from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + + mock_result = MagicMock(spec=SAAgentResult) + mock_result.stop_reason = "interrupt" + mock_result.interrupts = [] # Empty list — previously this was falsy and caused completion! + + async def mock_stream(content_blocks, **kwargs): + yield {"result": mock_result} + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-empty-interrupts" + mock_task.context_id = "ctx-empty-interrupts" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "do something" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + await executor.execute(mock_request_context, mock_event_queue) + + # Should transition to input_required, NOT completed + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + input_required_events = [ + e + for e in enqueued_events + if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.input_required + ] + completed_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.completed + ] + + assert len(input_required_events) == 1, "Empty interrupts list should still trigger input_required" + assert len(completed_events) == 0, "Should NOT complete when stop_reason='interrupt'" + # Verify the fallback message is used + assert "additional input" in input_required_events[0].status.message.parts[0].root.text.lower() + + +@pytest.mark.asyncio +async def test_execute_with_interrupt_none_list_transitions_to_input_required( + mock_strands_agent, mock_request_context, mock_event_queue +): + """Edge case: stop_reason='interrupt' with interrupts=None. + + Same logic — the stop_reason is authoritative. None interrupts should + still result in input_required transition. + """ + from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + + mock_result = MagicMock(spec=SAAgentResult) + mock_result.stop_reason = "interrupt" + mock_result.interrupts = None # None, not empty list + + async def mock_stream(content_blocks, **kwargs): + yield {"result": mock_result} + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-none-interrupts" + mock_task.context_id = "ctx-none-interrupts" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "do something" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + await executor.execute(mock_request_context, mock_event_queue) + + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + input_required_events = [ + e + for e in enqueued_events + if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.input_required + ] + assert len(input_required_events) == 1 + + +@pytest.mark.asyncio +async def test_cancel_without_hasattr_cancel(mock_strands_agent, mock_request_context, mock_event_queue): + """Test cancel works when agent doesn't have cancel() method (AttributeError).""" + from a2a.types import TaskState, TaskStatusUpdateEvent + + # Remove cancel method entirely + del mock_strands_agent.cancel + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-no-cancel-method" + mock_task.context_id = "ctx-no-cancel-method" + mock_request_context.current_task = mock_task + + # Should succeed — AttributeError from agent.cancel() is caught + await executor.cancel(mock_request_context, mock_event_queue) + + # Task should still be transitioned to canceled + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + canceled_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.canceled + ] + assert len(canceled_events) == 1