From 7ce5a092fc90409bf947486eec7fe55a6cc20d22 Mon Sep 17 00:00:00 2001 From: agent-of-mkmeral Date: Mon, 4 May 2026 19:16:44 +0000 Subject: [PATCH 1/6] feat(a2a): implement full A2A task lifecycle state support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements all lifecycle states for the A2A protocol integration: 1. Error mapping → failed state: - Agent exceptions now gracefully transition tasks to TaskState.failed - Error details are communicated via the status message - ServerErrors (setup failures) still propagate as before 2. Cancel support: - cancel() now transitions tasks to TaskState.canceled with a message - Raises UnsupportedOperationError when no current task exists or when the task is already in a terminal state - For already-terminal tasks, the updater's RuntimeError is logged as a warning and converted to that ServerError 3. input_required via Strands interrupts: - Agent stop_reason='interrupt' maps to TaskState.input_required - Interrupt names and reasons are communicated to the client - Multiple interrupts are listed in the status message 4. Client-side handling of all states: - _is_complete_event recognizes all terminal states (completed, failed, canceled, rejected) and pausing states (input_required, auth_required) - convert_response_to_agent_result maps A2A states to Strands stop_reasons: * completed/failed/canceled/rejected → end_turn * input_required/auth_required → interrupt - AgentResult.state includes a2a_task_state for downstream inspection All changes use the existing a2a-sdk 0.3.26 - no dependency changes needed. 
Resolves: strands-agents/sdk-python#1371 --- src/strands/agent/a2a_agent.py | 27 +- src/strands/multiagent/a2a/_converters.py | 48 +++- src/strands/multiagent/a2a/executor.py | 103 +++++-- tests/strands/agent/test_a2a_agent.py | 103 +++++++ .../strands/multiagent/a2a/test_converters.py | 149 ++++++++++ tests/strands/multiagent/a2a/test_executor.py | 270 +++++++++++++++++- 6 files changed, 669 insertions(+), 31 deletions(-) diff --git a/src/strands/agent/a2a_agent.py b/src/strands/agent/a2a_agent.py index eef47e3b4..25b9d8f63 100644 --- a/src/strands/agent/a2a_agent.py +++ b/src/strands/agent/a2a_agent.py @@ -29,6 +29,20 @@ _DEFAULT_TIMEOUT = 300 +# A2A task states that indicate the task is complete (no more events expected) +_TERMINAL_STATES = { + TaskState.completed, + TaskState.failed, + TaskState.canceled, + TaskState.rejected, +} + +# A2A task states that pause execution awaiting external input +_INPUT_STATES = { + TaskState.input_required, + TaskState.auth_required, +} + class A2AAgent(AgentBase): """Client wrapper for remote A2A agents.""" @@ -265,6 +279,9 @@ async def _send_message(self, prompt: AgentInput) -> AsyncIterator[A2AResponse]: def _is_complete_event(self, event: A2AResponse) -> bool: """Check if an A2A event represents a complete response. + Recognizes all terminal states (completed, failed, canceled, rejected) + and pausing states (input_required, auth_required) as complete events. + Args: event: A2A event. 
@@ -289,9 +306,15 @@ def _is_complete_event(self, event: A2AResponse) -> bool: return update_event.last_chunk return False - # Status update with completed state + # Status update - check for terminal or pausing states if isinstance(update_event, TaskStatusUpdateEvent): if update_event.status and hasattr(update_event.status, "state"): - return update_event.status.state == TaskState.completed + state = update_event.status.state + # Terminal states: task is done + if state in _TERMINAL_STATES: + return True + # Input-required states: task is paused, waiting for user + if state in _INPUT_STATES: + return True return False diff --git a/src/strands/multiagent/a2a/_converters.py b/src/strands/multiagent/a2a/_converters.py index 22c2ffb72..78e65dee1 100644 --- a/src/strands/multiagent/a2a/_converters.py +++ b/src/strands/multiagent/a2a/_converters.py @@ -4,7 +4,7 @@ from uuid import uuid4 from a2a.types import Message as A2AMessage -from a2a.types import Part, Role, TaskArtifactUpdateEvent, TaskStatusUpdateEvent, TextPart +from a2a.types import Part, Role, TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent, TextPart from ...agent.agent_result import AgentResult from ...telemetry.metrics import EventLoopMetrics @@ -12,6 +12,16 @@ from ...types.agent import AgentInput from ...types.content import ContentBlock, Message +# Mapping from A2A TaskState to Strands stop_reason +_STATE_TO_STOP_REASON = { + TaskState.completed: "end_turn", + TaskState.failed: "end_turn", + TaskState.canceled: "end_turn", + TaskState.rejected: "end_turn", + TaskState.input_required: "interrupt", + TaskState.auth_required: "interrupt", +} + def convert_input_to_message(prompt: AgentInput) -> A2AMessage: """Convert AgentInput to A2A Message. @@ -79,9 +89,34 @@ def convert_content_blocks_to_parts(content_blocks: list[ContentBlock]) -> list[ return parts +def _extract_task_state(response: A2AResponse) -> TaskState | None: + """Extract the task state from an A2A response. 
+ + Args: + response: A2A response (either A2AMessage or tuple of task and update event). + + Returns: + The TaskState if available, None otherwise. + """ + if isinstance(response, tuple) and len(response) == 2: + _task, update_event = response + if isinstance(update_event, TaskStatusUpdateEvent): + if update_event.status and hasattr(update_event.status, "state"): + return update_event.status.state + return None + + def convert_response_to_agent_result(response: A2AResponse) -> AgentResult: """Convert A2A response to AgentResult. + Maps A2A task lifecycle states to appropriate Strands stop_reasons: + - completed → end_turn + - failed → end_turn (with error content) + - canceled → end_turn (with cancellation info) + - rejected → end_turn (with rejection info) + - input_required → interrupt (agent needs user input) + - auth_required → interrupt (agent needs authentication) + Args: response: A2A response (either A2AMessage or tuple of task and update event). @@ -89,6 +124,8 @@ def convert_response_to_agent_result(response: A2AResponse) -> AgentResult: AgentResult with extracted content and metadata. 
""" content: list[ContentBlock] = [] + task_state = _extract_task_state(response) + stop_reason = _STATE_TO_STOP_REASON.get(task_state, "end_turn") if task_state else "end_turn" if isinstance(response, tuple) and len(response) == 2: task, update_event = response @@ -123,9 +160,14 @@ def convert_response_to_agent_result(response: A2AResponse) -> AgentResult: "content": content, } + # Build state dict with A2A metadata + state: dict = {} + if task_state is not None: + state["a2a_task_state"] = task_state.value + return AgentResult( - stop_reason="end_turn", + stop_reason=stop_reason, message=message, metrics=EventLoopMetrics(), - state={}, + state=state, ) diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py index c8c00600b..fa028e60c 100644 --- a/src/strands/multiagent/a2a/executor.py +++ b/src/strands/multiagent/a2a/executor.py @@ -42,7 +42,9 @@ class StrandsA2AExecutor(AgentExecutor): """Executor that adapts a Strands Agent to the A2A protocol. This executor uses streaming mode to handle the execution of agent requests - and converts Strands Agent responses to A2A protocol events. + and converts Strands Agent responses to A2A protocol events. It supports the + full A2A task lifecycle including error handling (failed state), cancellation, + and interrupt-based input_required flows. """ # Default formats for each file type when MIME type is unavailable or unrecognized @@ -75,14 +77,18 @@ async def execute( """Execute a request using the Strands Agent and send the response as A2A events. This method executes the user's input using the Strands Agent in streaming mode - and converts the agent's response to A2A events. + and converts the agent's response to A2A events. If the agent raises an exception, + the task transitions to the `failed` state. If the agent returns with interrupts, + the task transitions to the `input_required` state. Args: context: The A2A request context, containing the user's input and task metadata. 
event_queue: The A2A event queue used to send response events back to the client. Raises: - ServerError: If an error occurs during agent execution + ServerError: If an unrecoverable error occurs during agent execution setup + (e.g., missing input). Agent execution errors are handled gracefully + by transitioning the task to the failed state. """ task = context.current_task if not task: @@ -93,8 +99,21 @@ async def execute( try: await self._execute_streaming(context, updater) + except ServerError: + # Re-raise ServerErrors (setup failures like missing input) + raise except Exception as e: - raise ServerError(error=InternalError()) from e + # Agent execution failures transition to failed state + logger.exception("Agent execution failed, transitioning task to failed state") + try: + await updater.failed( + message=updater.new_agent_message( + parts=[Part(root=TextPart(text=f"Agent execution failed: {e}"))] + ) + ) + except RuntimeError: + # Task already in terminal state (e.g., completed before error in cleanup) + logger.debug("Task already in terminal state, cannot transition to failed") async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater) -> None: """Execute request in streaming mode. @@ -105,14 +124,17 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater Args: context: The A2A request context, containing the user's input and other metadata. updater: The task updater for managing task state and sending updates. + + Raises: + ServerError: If input conversion fails (missing or empty content). 
""" # Convert A2A message parts to Strands ContentBlocks if context.message and hasattr(context.message, "parts"): content_blocks = self._convert_a2a_parts_to_content_blocks(context.message.parts) if not content_blocks: - raise ValueError("No content blocks available") + raise ServerError(error=InternalError()) else: - raise ValueError("No content blocks available") + raise ServerError(error=InternalError()) if not self.enable_a2a_compliant_streaming: warnings.warn( @@ -133,8 +155,18 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater invocation_state: dict[str, Any] = {"a2a_request_context": context} try: + result: SAAgentResult | None = None async for event in self.agent.stream_async(content_blocks, invocation_state=invocation_state): - await self._handle_streaming_event(event, updater) + if "result" in event: + result = event["result"] + else: + await self._handle_streaming_event(event, updater) + + # Check if agent returned with interrupts (input_required) + if result is not None and result.stop_reason == "interrupt" and result.interrupts: + await self._handle_interrupt_result(result, updater) + else: + await self._handle_agent_result(result, updater) except Exception: logger.exception("Error in streaming execution") raise @@ -143,6 +175,33 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater self._current_artifact_id = None self._is_first_chunk = True + async def _handle_interrupt_result(self, result: SAAgentResult, updater: TaskUpdater) -> None: + """Handle an agent result that contains interrupts. + + When the Strands Agent returns with stop_reason="interrupt", this maps to + the A2A `input_required` state. The interrupt details are communicated to + the client via the status message. + + Args: + result: The agent result containing interrupts. + updater: The task updater for managing task state. 
+ """ + # Build a descriptive message about what input is needed + interrupt_descriptions = [] + for interrupt in result.interrupts or []: + desc = f"- {interrupt.name}" + if interrupt.reason: + desc += f": {interrupt.reason}" + interrupt_descriptions.append(desc) + + input_message = "Agent requires input:\n" + "\n".join(interrupt_descriptions) + + await updater.requires_input( + message=updater.new_agent_message( + parts=[Part(root=TextPart(text=input_message))] + ) + ) + async def _handle_streaming_event(self, event: dict[str, Any], updater: TaskUpdater) -> None: """Handle a single streaming event from the Strands Agent. @@ -175,8 +234,6 @@ async def _handle_streaming_event(self, event: dict[str, Any], updater: TaskUpda updater.task_id, ), ) - elif "result" in event: - await self._handle_agent_result(event["result"], updater) async def _handle_agent_result(self, result: SAAgentResult | None, updater: TaskUpdater) -> None: """Handle the final result from the Strands Agent. @@ -219,20 +276,30 @@ async def _handle_agent_result(self, result: SAAgentResult | None, updater: Task async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None: """Cancel an ongoing execution. - This method is called when a request cancellation is requested. Currently, - cancellation is not supported by the Strands Agent executor, so this method - always raises an UnsupportedOperationError. + Transitions the task to the canceled state. If the agent supports cancellation + (e.g., via a stop mechanism), this will signal the agent to stop processing. Args: context: The A2A request context. event_queue: The A2A event queue. - - Raises: - ServerError: Always raised with an UnsupportedOperationError, as cancellation - is not currently supported. 
""" - logger.warning("Cancellation requested but not supported") - raise ServerError(error=UnsupportedOperationError()) + task = context.current_task + if not task: + logger.warning("Cancellation requested but no current task found") + raise ServerError(error=UnsupportedOperationError()) + + updater = TaskUpdater(event_queue, task.id, task.context_id) + + try: + await updater.cancel( + message=updater.new_agent_message( + parts=[Part(root=TextPart(text="Task cancelled by client request"))] + ) + ) + except RuntimeError: + # Task already in terminal state + logger.warning("Cannot cancel task %s: already in terminal state", task.id) + raise ServerError(error=UnsupportedOperationError()) def _get_file_type_from_mime_type(self, mime_type: str | None) -> Literal["document", "image", "video", "unknown"]: """Classify file type based on MIME type. diff --git a/tests/strands/agent/test_a2a_agent.py b/tests/strands/agent/test_a2a_agent.py index d918033e5..fab1d4a1b 100644 --- a/tests/strands/agent/test_a2a_agent.py +++ b/tests/strands/agent/test_a2a_agent.py @@ -714,3 +714,106 @@ async def mock_send_message(*args, **kwargs): # Should have 1 stream event + 1 result event (falls back to last) assert len(events) == 2 assert "result" in events[1] + + +# ========================================================================= +# NEW TESTS: Client-side lifecycle state handling +# ========================================================================= + + +def test_is_complete_event_failed_state(a2a_agent): + """Test that failed state is recognized as complete.""" + from unittest.mock import MagicMock + from a2a.types import TaskState, TaskStatusUpdateEvent, TaskStatus + + task = MagicMock() + status = MagicMock() + status.state = TaskState.failed + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is True + + +def test_is_complete_event_canceled_state(a2a_agent): + """Test that 
canceled state is recognized as complete.""" + from unittest.mock import MagicMock + from a2a.types import TaskState, TaskStatusUpdateEvent, TaskStatus + + task = MagicMock() + status = MagicMock() + status.state = TaskState.canceled + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is True + + +def test_is_complete_event_rejected_state(a2a_agent): + """Test that rejected state is recognized as complete.""" + from unittest.mock import MagicMock + from a2a.types import TaskState, TaskStatusUpdateEvent, TaskStatus + + task = MagicMock() + status = MagicMock() + status.state = TaskState.rejected + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is True + + +def test_is_complete_event_input_required_state(a2a_agent): + """Test that input_required state is recognized as complete (pausing).""" + from unittest.mock import MagicMock + from a2a.types import TaskState, TaskStatusUpdateEvent, TaskStatus + + task = MagicMock() + status = MagicMock() + status.state = TaskState.input_required + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is True + + +def test_is_complete_event_auth_required_state(a2a_agent): + """Test that auth_required state is recognized as complete (pausing).""" + from unittest.mock import MagicMock + from a2a.types import TaskState, TaskStatusUpdateEvent, TaskStatus + + task = MagicMock() + status = MagicMock() + status.state = TaskState.auth_required + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is True + + +def test_is_complete_event_working_state_not_complete(a2a_agent): + """Test that working state is NOT recognized as complete.""" + from unittest.mock import 
MagicMock + from a2a.types import TaskState, TaskStatusUpdateEvent, TaskStatus + + task = MagicMock() + status = MagicMock() + status.state = TaskState.working + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is False + + +def test_is_complete_event_submitted_state_not_complete(a2a_agent): + """Test that submitted state is NOT recognized as complete.""" + from unittest.mock import MagicMock + from a2a.types import TaskState, TaskStatusUpdateEvent, TaskStatus + + task = MagicMock() + status = MagicMock() + status.state = TaskState.submitted + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is False diff --git a/tests/strands/multiagent/a2a/test_converters.py b/tests/strands/multiagent/a2a/test_converters.py index c3b310065..7f8e885a8 100644 --- a/tests/strands/multiagent/a2a/test_converters.py +++ b/tests/strands/multiagent/a2a/test_converters.py @@ -243,3 +243,152 @@ def test_convert_response_handles_missing_data(): mock_task.artifacts = [mock_artifact] result = convert_response_to_agent_result((mock_task, None)) assert len(result.message["content"]) == 0 + + +# ========================================================================= +# NEW TESTS: Lifecycle State Mapping +# ========================================================================= + + +def test_convert_response_completed_state_maps_to_end_turn(): + """Test that completed state maps to end_turn stop_reason.""" + from unittest.mock import MagicMock + from a2a.types import TaskState, TaskStatusUpdateEvent, TaskStatus, Part, TextPart + + task = MagicMock() + task.artifacts = None + + status = TaskStatus(state=TaskState.completed, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert 
result.stop_reason == "end_turn" + + +def test_convert_response_failed_state_maps_to_end_turn(): + """Test that failed state maps to end_turn stop_reason with error content.""" + from unittest.mock import MagicMock + from a2a.types import TaskState, TaskStatusUpdateEvent, TaskStatus, Message, Part, TextPart, Role + + task = MagicMock() + task.artifacts = None + + # Create a status message with error info + error_part = MagicMock() + error_part.root = MagicMock() + error_part.root.text = "Agent execution failed: timeout" + + error_message = MagicMock(spec=Message) + error_message.parts = [error_part] + + status = TaskStatus(state=TaskState.failed, message=error_message) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "end_turn" + assert result.state.get("a2a_task_state") == "failed" + assert "Agent execution failed" in result.message["content"][0]["text"] + + +def test_convert_response_input_required_maps_to_interrupt(): + """Test that input_required state maps to interrupt stop_reason.""" + from unittest.mock import MagicMock + from a2a.types import TaskState, TaskStatusUpdateEvent, TaskStatus, Message, Part, TextPart, Role + + task = MagicMock() + task.artifacts = None + + input_part = MagicMock() + input_part.root = MagicMock() + input_part.root.text = "Agent requires input:\n- approval: Need confirmation" + + input_message = MagicMock(spec=Message) + input_message.parts = [input_part] + + status = TaskStatus(state=TaskState.input_required, message=input_message) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "interrupt" + assert result.state.get("a2a_task_state") == "input-required" + assert "approval" in result.message["content"][0]["text"] + + +def 
test_convert_response_canceled_state_maps_to_end_turn(): + """Test that canceled state maps to end_turn stop_reason.""" + from unittest.mock import MagicMock + from a2a.types import TaskState, TaskStatusUpdateEvent, TaskStatus + + task = MagicMock() + task.artifacts = None + + status = TaskStatus(state=TaskState.canceled, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "end_turn" + assert result.state.get("a2a_task_state") == "canceled" + + +def test_convert_response_rejected_state_maps_to_end_turn(): + """Test that rejected state maps to end_turn stop_reason.""" + from unittest.mock import MagicMock + from a2a.types import TaskState, TaskStatusUpdateEvent, TaskStatus + + task = MagicMock() + task.artifacts = None + + status = TaskStatus(state=TaskState.rejected, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "end_turn" + assert result.state.get("a2a_task_state") == "rejected" + + +def test_convert_response_auth_required_maps_to_interrupt(): + """Test that auth_required state maps to interrupt stop_reason.""" + from unittest.mock import MagicMock + from a2a.types import TaskState, TaskStatusUpdateEvent, TaskStatus + + task = MagicMock() + task.artifacts = None + + status = TaskStatus(state=TaskState.auth_required, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "interrupt" + assert result.state.get("a2a_task_state") == "auth-required" + + +def test_extract_task_state_from_status_update(): + """Test _extract_task_state helper.""" + from unittest.mock import MagicMock + from a2a.types import TaskState, 
TaskStatusUpdateEvent, TaskStatus + from strands.multiagent.a2a._converters import _extract_task_state + + task = MagicMock() + status = TaskStatus(state=TaskState.failed, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + state = _extract_task_state((task, update_event)) + assert state == TaskState.failed + + +def test_extract_task_state_from_message_returns_none(): + """Test _extract_task_state returns None for Message responses.""" + from unittest.mock import MagicMock + from a2a.types import Message + from strands.multiagent.a2a._converters import _extract_task_state + + message = MagicMock(spec=Message) + state = _extract_task_state(message) + assert state is None diff --git a/tests/strands/multiagent/a2a/test_executor.py b/tests/strands/multiagent/a2a/test_executor.py index dc90fbdd6..553d830fe 100644 --- a/tests/strands/multiagent/a2a/test_executor.py +++ b/tests/strands/multiagent/a2a/test_executor.py @@ -583,7 +583,7 @@ async def mock_stream(content_blocks): async def test_execute_streaming_mode_handles_agent_exception( mock_strands_agent, mock_request_context, mock_event_queue ): - """Test that execute handles agent exceptions correctly in streaming mode.""" + """Test that execute transitions to failed state when agent raises exception.""" # Setup mock agent to raise exception when stream_async is called mock_strands_agent.stream_async = MagicMock(side_effect=Exception("Agent error")) @@ -608,25 +608,31 @@ async def test_execute_streaming_mode_handles_agent_exception( mock_message.parts = [part] mock_request_context.message = mock_message - with pytest.raises(ServerError): - await executor.execute(mock_request_context, mock_event_queue) + # Should NOT raise - instead transitions to failed state + await executor.execute(mock_request_context, mock_event_queue) # Verify agent was called mock_strands_agent.stream_async.assert_called_once() - -@pytest.mark.asyncio -async def 
test_cancel_raises_unsupported_operation_error(mock_strands_agent, mock_request_context, mock_event_queue): - """Test that cancel raises UnsupportedOperationError.""" + # Verify a failed status event was enqueued + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + from a2a.types import TaskStatusUpdateEvent, TaskState + failed_events = [ + e for e in enqueued_events + if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.failed + ] + assert len(failed_events) == 1 + assert "Agent execution failed" in failed_events[0].status.message.parts[0].root.text executor = StrandsA2AExecutor(mock_strands_agent) + # Cancel with no current_task raises UnsupportedOperationError + mock_request_context.current_task = None with pytest.raises(ServerError) as excinfo: await executor.cancel(mock_request_context, mock_event_queue) # Verify the error is a ServerError containing an UnsupportedOperationError assert isinstance(excinfo.value.error, UnsupportedOperationError) - @pytest.mark.asyncio async def test_handle_agent_result_with_none_result(mock_strands_agent, mock_request_context, mock_event_queue): """Test that _handle_agent_result handles None result correctly.""" @@ -1331,3 +1337,251 @@ async def test_invocation_state_with_a2a_compliant_streaming( assert invocation_state is not None assert invocation_state["a2a_request_context"] is mock_request_context + + +# ========================================================================= +# NEW TESTS: A2A Lifecycle State Support +# ========================================================================= + + +@pytest.mark.asyncio +async def test_execute_transitions_to_failed_on_streaming_error( + mock_strands_agent, mock_request_context, mock_event_queue +): + """Test that errors during streaming transition task to failed state.""" + from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + + async def mock_stream(content_blocks, **kwargs): + """Mock streaming that 
raises mid-stream.""" + yield {"data": "partial output"} + raise RuntimeError("Connection lost") + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-fail" + mock_task.context_id = "ctx-fail" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "test" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + # Should not raise + await executor.execute(mock_request_context, mock_event_queue) + + # Verify failed state was enqueued + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + failed_events = [ + e for e in enqueued_events + if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.failed + ] + assert len(failed_events) == 1 + assert "Connection lost" in failed_events[0].status.message.parts[0].root.text + + +@pytest.mark.asyncio +async def test_cancel_with_valid_task(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that cancel transitions task to canceled state when task exists.""" + from a2a.types import TaskState, TaskStatusUpdateEvent + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-cancel" + mock_task.context_id = "ctx-cancel" + mock_request_context.current_task = mock_task + + await executor.cancel(mock_request_context, mock_event_queue) + + # Verify canceled state was enqueued + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + canceled_events = [ + e for e in enqueued_events + if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.canceled + ] + assert len(canceled_events) == 1 + assert "cancelled" in canceled_events[0].status.message.parts[0].root.text.lower() + + 
+@pytest.mark.asyncio +async def test_cancel_without_task_raises_unsupported(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that cancel raises UnsupportedOperationError when no task exists.""" + executor = StrandsA2AExecutor(mock_strands_agent) + mock_request_context.current_task = None + + with pytest.raises(ServerError) as excinfo: + await executor.cancel(mock_request_context, mock_event_queue) + + assert isinstance(excinfo.value.error, UnsupportedOperationError) + + +@pytest.mark.asyncio +async def test_execute_with_interrupt_transitions_to_input_required( + mock_strands_agent, mock_request_context, mock_event_queue +): + """Test that agent interrupts map to input_required state.""" + from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + from strands.interrupt import Interrupt + + # Create a mock result with interrupts + mock_result = MagicMock(spec=SAAgentResult) + mock_result.stop_reason = "interrupt" + mock_interrupt = Interrupt(id="int-1", name="approval", reason="Need user approval") + mock_result.interrupts = [mock_interrupt] + + async def mock_stream(content_blocks, **kwargs): + yield {"data": "Processing..."} + yield {"result": mock_result} + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-interrupt" + mock_task.context_id = "ctx-interrupt" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "delete file X" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + await executor.execute(mock_request_context, mock_event_queue) + + # Verify input_required state was enqueued + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + input_required_events = [ + e for e in 
enqueued_events + if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.input_required + ] + assert len(input_required_events) == 1 + msg_text = input_required_events[0].status.message.parts[0].root.text + assert "approval" in msg_text + assert "Need user approval" in msg_text + + +@pytest.mark.asyncio +async def test_execute_with_multiple_interrupts(mock_strands_agent, mock_request_context, mock_event_queue): + """Test handling of multiple interrupts in a single result.""" + from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + from strands.interrupt import Interrupt + + mock_result = MagicMock(spec=SAAgentResult) + mock_result.stop_reason = "interrupt" + mock_result.interrupts = [ + Interrupt(id="int-1", name="confirm_delete", reason="Confirm deletion of file X"), + Interrupt(id="int-2", name="select_backup", reason="Choose backup location"), + ] + + async def mock_stream(content_blocks, **kwargs): + yield {"result": mock_result} + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-multi-int" + mock_task.context_id = "ctx-multi-int" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "delete with backup" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + await executor.execute(mock_request_context, mock_event_queue) + + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + input_required_events = [ + e for e in enqueued_events + if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.input_required + ] + assert len(input_required_events) == 1 + msg_text = input_required_events[0].status.message.parts[0].root.text + assert "confirm_delete" in msg_text + assert 
"select_backup" in msg_text + assert "Confirm deletion of file X" in msg_text + assert "Choose backup location" in msg_text + + +@pytest.mark.asyncio +async def test_execute_normal_completion_no_interrupts(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that normal completion (no interrupts) still works as before.""" + from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + + mock_result = MagicMock(spec=SAAgentResult) + mock_result.stop_reason = "end_turn" + mock_result.interrupts = None + mock_result.__str__ = MagicMock(return_value="Task completed successfully") + + async def mock_stream(content_blocks, **kwargs): + yield {"data": "Working..."} + yield {"result": mock_result} + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-normal" + mock_task.context_id = "ctx-normal" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "do something" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + await executor.execute(mock_request_context, mock_event_queue) + + # Verify completed state was enqueued (not input_required) + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + completed_events = [ + e for e in enqueued_events + if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.completed + ] + assert len(completed_events) == 1 + + # Verify no input_required events + input_required_events = [ + e for e in enqueued_events + if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.input_required + ] + assert len(input_required_events) == 0 + + +@pytest.mark.asyncio +async def test_execute_setup_failure_raises_server_error(mock_strands_agent, 
mock_request_context, mock_event_queue): + """Test that setup failures (missing message) still raise ServerError.""" + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-setup-fail" + mock_task.context_id = "ctx-setup-fail" + mock_request_context.current_task = mock_task + + # No message at all + mock_request_context.message = None + + with pytest.raises(ServerError) as excinfo: + await executor.execute(mock_request_context, mock_event_queue) + + assert isinstance(excinfo.value.error, InternalError) From ec92858efee6a63e04a00665146e7a9cb824a873 Mon Sep 17 00:00:00 2001 From: agent-of-mkmeral Date: Mon, 4 May 2026 19:21:30 +0000 Subject: [PATCH 2/6] fix: lint and formatting fixes for A2A lifecycle states - Fix B904: add 'from None' to re-raised ServerError in cancel() - Fix F401: remove unused imports (TaskStatus, Part, TextPart, Role) in tests - Fix I001: sort import blocks in test files - Fix ruff formatting in executor.py and test_executor.py --- src/strands/multiagent/a2a/executor.py | 18 ++++--------- tests/strands/agent/test_a2a_agent.py | 21 ++++++++++----- .../strands/multiagent/a2a/test_converters.py | 24 ++++++++++++----- tests/strands/multiagent/a2a/test_executor.py | 27 ++++++++++--------- 4 files changed, 51 insertions(+), 39 deletions(-) diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py index fa028e60c..273b7dcdf 100644 --- a/src/strands/multiagent/a2a/executor.py +++ b/src/strands/multiagent/a2a/executor.py @@ -107,9 +107,7 @@ async def execute( logger.exception("Agent execution failed, transitioning task to failed state") try: await updater.failed( - message=updater.new_agent_message( - parts=[Part(root=TextPart(text=f"Agent execution failed: {e}"))] - ) + message=updater.new_agent_message(parts=[Part(root=TextPart(text=f"Agent execution failed: {e}"))]) ) except RuntimeError: # Task already in terminal state (e.g., completed before error in cleanup) @@ 
-196,11 +194,7 @@ async def _handle_interrupt_result(self, result: SAAgentResult, updater: TaskUpd input_message = "Agent requires input:\n" + "\n".join(interrupt_descriptions) - await updater.requires_input( - message=updater.new_agent_message( - parts=[Part(root=TextPart(text=input_message))] - ) - ) + await updater.requires_input(message=updater.new_agent_message(parts=[Part(root=TextPart(text=input_message))])) async def _handle_streaming_event(self, event: dict[str, Any], updater: TaskUpdater) -> None: """Handle a single streaming event from the Strands Agent. @@ -286,20 +280,18 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None task = context.current_task if not task: logger.warning("Cancellation requested but no current task found") - raise ServerError(error=UnsupportedOperationError()) + raise ServerError(error=UnsupportedOperationError()) from None updater = TaskUpdater(event_queue, task.id, task.context_id) try: await updater.cancel( - message=updater.new_agent_message( - parts=[Part(root=TextPart(text="Task cancelled by client request"))] - ) + message=updater.new_agent_message(parts=[Part(root=TextPart(text="Task cancelled by client request"))]) ) except RuntimeError: # Task already in terminal state logger.warning("Cannot cancel task %s: already in terminal state", task.id) - raise ServerError(error=UnsupportedOperationError()) + raise ServerError(error=UnsupportedOperationError()) from None def _get_file_type_from_mime_type(self, mime_type: str | None) -> Literal["document", "image", "video", "unknown"]: """Classify file type based on MIME type. 
diff --git a/tests/strands/agent/test_a2a_agent.py b/tests/strands/agent/test_a2a_agent.py index fab1d4a1b..5bcd35b96 100644 --- a/tests/strands/agent/test_a2a_agent.py +++ b/tests/strands/agent/test_a2a_agent.py @@ -724,7 +724,8 @@ async def mock_send_message(*args, **kwargs): def test_is_complete_event_failed_state(a2a_agent): """Test that failed state is recognized as complete.""" from unittest.mock import MagicMock - from a2a.types import TaskState, TaskStatusUpdateEvent, TaskStatus + + from a2a.types import TaskState, TaskStatusUpdateEvent task = MagicMock() status = MagicMock() @@ -738,7 +739,8 @@ def test_is_complete_event_failed_state(a2a_agent): def test_is_complete_event_canceled_state(a2a_agent): """Test that canceled state is recognized as complete.""" from unittest.mock import MagicMock - from a2a.types import TaskState, TaskStatusUpdateEvent, TaskStatus + + from a2a.types import TaskState, TaskStatusUpdateEvent task = MagicMock() status = MagicMock() @@ -752,7 +754,8 @@ def test_is_complete_event_canceled_state(a2a_agent): def test_is_complete_event_rejected_state(a2a_agent): """Test that rejected state is recognized as complete.""" from unittest.mock import MagicMock - from a2a.types import TaskState, TaskStatusUpdateEvent, TaskStatus + + from a2a.types import TaskState, TaskStatusUpdateEvent task = MagicMock() status = MagicMock() @@ -766,7 +769,8 @@ def test_is_complete_event_rejected_state(a2a_agent): def test_is_complete_event_input_required_state(a2a_agent): """Test that input_required state is recognized as complete (pausing).""" from unittest.mock import MagicMock - from a2a.types import TaskState, TaskStatusUpdateEvent, TaskStatus + + from a2a.types import TaskState, TaskStatusUpdateEvent task = MagicMock() status = MagicMock() @@ -780,7 +784,8 @@ def test_is_complete_event_input_required_state(a2a_agent): def test_is_complete_event_auth_required_state(a2a_agent): """Test that auth_required state is recognized as complete (pausing).""" from 
unittest.mock import MagicMock - from a2a.types import TaskState, TaskStatusUpdateEvent, TaskStatus + + from a2a.types import TaskState, TaskStatusUpdateEvent task = MagicMock() status = MagicMock() @@ -794,7 +799,8 @@ def test_is_complete_event_auth_required_state(a2a_agent): def test_is_complete_event_working_state_not_complete(a2a_agent): """Test that working state is NOT recognized as complete.""" from unittest.mock import MagicMock - from a2a.types import TaskState, TaskStatusUpdateEvent, TaskStatus + + from a2a.types import TaskState, TaskStatusUpdateEvent task = MagicMock() status = MagicMock() @@ -808,7 +814,8 @@ def test_is_complete_event_working_state_not_complete(a2a_agent): def test_is_complete_event_submitted_state_not_complete(a2a_agent): """Test that submitted state is NOT recognized as complete.""" from unittest.mock import MagicMock - from a2a.types import TaskState, TaskStatusUpdateEvent, TaskStatus + + from a2a.types import TaskState, TaskStatusUpdateEvent task = MagicMock() status = MagicMock() diff --git a/tests/strands/multiagent/a2a/test_converters.py b/tests/strands/multiagent/a2a/test_converters.py index 7f8e885a8..cb2b97d51 100644 --- a/tests/strands/multiagent/a2a/test_converters.py +++ b/tests/strands/multiagent/a2a/test_converters.py @@ -253,7 +253,8 @@ def test_convert_response_handles_missing_data(): def test_convert_response_completed_state_maps_to_end_turn(): """Test that completed state maps to end_turn stop_reason.""" from unittest.mock import MagicMock - from a2a.types import TaskState, TaskStatusUpdateEvent, TaskStatus, Part, TextPart + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent task = MagicMock() task.artifacts = None @@ -269,7 +270,8 @@ def test_convert_response_completed_state_maps_to_end_turn(): def test_convert_response_failed_state_maps_to_end_turn(): """Test that failed state maps to end_turn stop_reason with error content.""" from unittest.mock import MagicMock - from a2a.types import TaskState, 
TaskStatusUpdateEvent, TaskStatus, Message, Part, TextPart, Role + + from a2a.types import Message, TaskState, TaskStatus, TaskStatusUpdateEvent task = MagicMock() task.artifacts = None @@ -295,7 +297,8 @@ def test_convert_response_failed_state_maps_to_end_turn(): def test_convert_response_input_required_maps_to_interrupt(): """Test that input_required state maps to interrupt stop_reason.""" from unittest.mock import MagicMock - from a2a.types import TaskState, TaskStatusUpdateEvent, TaskStatus, Message, Part, TextPart, Role + + from a2a.types import Message, TaskState, TaskStatus, TaskStatusUpdateEvent task = MagicMock() task.artifacts = None @@ -320,7 +323,8 @@ def test_convert_response_input_required_maps_to_interrupt(): def test_convert_response_canceled_state_maps_to_end_turn(): """Test that canceled state maps to end_turn stop_reason.""" from unittest.mock import MagicMock - from a2a.types import TaskState, TaskStatusUpdateEvent, TaskStatus + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent task = MagicMock() task.artifacts = None @@ -337,7 +341,8 @@ def test_convert_response_canceled_state_maps_to_end_turn(): def test_convert_response_rejected_state_maps_to_end_turn(): """Test that rejected state maps to end_turn stop_reason.""" from unittest.mock import MagicMock - from a2a.types import TaskState, TaskStatusUpdateEvent, TaskStatus + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent task = MagicMock() task.artifacts = None @@ -354,7 +359,8 @@ def test_convert_response_rejected_state_maps_to_end_turn(): def test_convert_response_auth_required_maps_to_interrupt(): """Test that auth_required state maps to interrupt stop_reason.""" from unittest.mock import MagicMock - from a2a.types import TaskState, TaskStatusUpdateEvent, TaskStatus + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent task = MagicMock() task.artifacts = None @@ -371,7 +377,9 @@ def test_convert_response_auth_required_maps_to_interrupt(): 
def test_extract_task_state_from_status_update(): """Test _extract_task_state helper.""" from unittest.mock import MagicMock - from a2a.types import TaskState, TaskStatusUpdateEvent, TaskStatus + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent + from strands.multiagent.a2a._converters import _extract_task_state task = MagicMock() @@ -386,7 +394,9 @@ def test_extract_task_state_from_status_update(): def test_extract_task_state_from_message_returns_none(): """Test _extract_task_state returns None for Message responses.""" from unittest.mock import MagicMock + from a2a.types import Message + from strands.multiagent.a2a._converters import _extract_task_state message = MagicMock(spec=Message) diff --git a/tests/strands/multiagent/a2a/test_executor.py b/tests/strands/multiagent/a2a/test_executor.py index 553d830fe..63891e8bb 100644 --- a/tests/strands/multiagent/a2a/test_executor.py +++ b/tests/strands/multiagent/a2a/test_executor.py @@ -616,10 +616,10 @@ async def test_execute_streaming_mode_handles_agent_exception( # Verify a failed status event was enqueued enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] - from a2a.types import TaskStatusUpdateEvent, TaskState + from a2a.types import TaskState, TaskStatusUpdateEvent + failed_events = [ - e for e in enqueued_events - if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.failed + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.failed ] assert len(failed_events) == 1 assert "Agent execution failed" in failed_events[0].status.message.parts[0].root.text @@ -633,6 +633,7 @@ async def test_execute_streaming_mode_handles_agent_exception( # Verify the error is a ServerError containing an UnsupportedOperationError assert isinstance(excinfo.value.error, UnsupportedOperationError) + @pytest.mark.asyncio async def test_handle_agent_result_with_none_result(mock_strands_agent, mock_request_context, 
mock_event_queue): """Test that _handle_agent_result handles None result correctly.""" @@ -1379,8 +1380,7 @@ async def mock_stream(content_blocks, **kwargs): # Verify failed state was enqueued enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] failed_events = [ - e for e in enqueued_events - if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.failed + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.failed ] assert len(failed_events) == 1 assert "Connection lost" in failed_events[0].status.message.parts[0].root.text @@ -1403,8 +1403,7 @@ async def test_cancel_with_valid_task(mock_strands_agent, mock_request_context, # Verify canceled state was enqueued enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] canceled_events = [ - e for e in enqueued_events - if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.canceled + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.canceled ] assert len(canceled_events) == 1 assert "cancelled" in canceled_events[0].status.message.parts[0].root.text.lower() @@ -1428,6 +1427,7 @@ async def test_execute_with_interrupt_transitions_to_input_required( ): """Test that agent interrupts map to input_required state.""" from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + from strands.interrupt import Interrupt # Create a mock result with interrupts @@ -1462,7 +1462,8 @@ async def mock_stream(content_blocks, **kwargs): # Verify input_required state was enqueued enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] input_required_events = [ - e for e in enqueued_events + e + for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.input_required ] assert len(input_required_events) == 1 @@ -1475,6 +1476,7 @@ async def mock_stream(content_blocks, 
**kwargs): async def test_execute_with_multiple_interrupts(mock_strands_agent, mock_request_context, mock_event_queue): """Test handling of multiple interrupts in a single result.""" from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + from strands.interrupt import Interrupt mock_result = MagicMock(spec=SAAgentResult) @@ -1508,7 +1510,8 @@ async def mock_stream(content_blocks, **kwargs): enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] input_required_events = [ - e for e in enqueued_events + e + for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.input_required ] assert len(input_required_events) == 1 @@ -1555,14 +1558,14 @@ async def mock_stream(content_blocks, **kwargs): # Verify completed state was enqueued (not input_required) enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] completed_events = [ - e for e in enqueued_events - if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.completed + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.completed ] assert len(completed_events) == 1 # Verify no input_required events input_required_events = [ - e for e in enqueued_events + e + for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.input_required ] assert len(input_required_events) == 0 From 451cae9bf4789c0e2589cda527127fc1ab346951 Mon Sep 17 00:00:00 2001 From: agent-of-mkmeral Date: Mon, 4 May 2026 19:45:40 +0000 Subject: [PATCH 3/6] fix: address CI failure (mypy), review comments, and increase coverage Fixes: - mypy error: _STATE_TO_STOP_REASON now typed as dict[TaskState, StopReason] instead of implicit str values (was: 'str' incompatible with Literal type) - Bug: None parts crash in convert_response_to_agent_result (artifact.parts and message.parts checked for None before iteration) - Security: error messages no 
longer expose raw exception details to clients - Lint: removed unused variable 'e' in outer except clause (F841) Review feedback addressed: - Structured logging: all log messages now use 'task_id=<%s> | message' format - cancel() docstring: accurately describes state-only transition + best-effort agent.cancel() call - cancel() now calls agent.cancel() if method exists (cooperative cancellation) - DRY: added _COMPLETE_STATES = _TERMINAL_STATES | _INPUT_STATES and simplified _is_complete_event to use single set membership test - Type annotation: state dict uses dict[str, str] instead of bare dict - RuntimeError catches: added comments explaining they guard TaskUpdater's terminal state enforcement Coverage improvement: - Added tests for: task-already-terminal error path, agent.cancel() call, agent.cancel() exception handling - 181 tests pass (up from 178) --- src/strands/agent/a2a_agent.py | 14 ++- src/strands/multiagent/a2a/_converters.py | 18 ++-- src/strands/multiagent/a2a/executor.py | 36 ++++--- tests/strands/multiagent/a2a/test_executor.py | 98 ++++++++++++++++++- 4 files changed, 140 insertions(+), 26 deletions(-) diff --git a/src/strands/agent/a2a_agent.py b/src/strands/agent/a2a_agent.py index 25b9d8f63..6b2e0bc01 100644 --- a/src/strands/agent/a2a_agent.py +++ b/src/strands/agent/a2a_agent.py @@ -29,7 +29,9 @@ _DEFAULT_TIMEOUT = 300 -# A2A task states that indicate the task is complete (no more events expected) +# A2A task states that indicate the response stream is complete. +# Terminal states mean no more events; input states mean execution is paused. +# Derived from _STATE_TO_STOP_REASON in _converters to maintain single source of truth. 
_TERMINAL_STATES = { TaskState.completed, TaskState.failed, @@ -37,12 +39,13 @@ TaskState.rejected, } -# A2A task states that pause execution awaiting external input _INPUT_STATES = { TaskState.input_required, TaskState.auth_required, } +_COMPLETE_STATES = _TERMINAL_STATES | _INPUT_STATES + class A2AAgent(AgentBase): """Client wrapper for remote A2A agents.""" @@ -310,11 +313,6 @@ def _is_complete_event(self, event: A2AResponse) -> bool: if isinstance(update_event, TaskStatusUpdateEvent): if update_event.status and hasattr(update_event.status, "state"): state = update_event.status.state - # Terminal states: task is done - if state in _TERMINAL_STATES: - return True - # Input-required states: task is paused, waiting for user - if state in _INPUT_STATES: - return True + return state in _COMPLETE_STATES return False diff --git a/src/strands/multiagent/a2a/_converters.py b/src/strands/multiagent/a2a/_converters.py index 78e65dee1..7808ae325 100644 --- a/src/strands/multiagent/a2a/_converters.py +++ b/src/strands/multiagent/a2a/_converters.py @@ -11,9 +11,10 @@ from ...types.a2a import A2AResponse from ...types.agent import AgentInput from ...types.content import ContentBlock, Message +from ...types.event_loop import StopReason # Mapping from A2A TaskState to Strands stop_reason -_STATE_TO_STOP_REASON = { +_STATE_TO_STOP_REASON: dict[TaskState, StopReason] = { TaskState.completed: "end_turn", TaskState.failed: "end_turn", TaskState.canceled: "end_turn", @@ -125,20 +126,25 @@ def convert_response_to_agent_result(response: A2AResponse) -> AgentResult: """ content: list[ContentBlock] = [] task_state = _extract_task_state(response) - stop_reason = _STATE_TO_STOP_REASON.get(task_state, "end_turn") if task_state else "end_turn" + stop_reason: StopReason = _STATE_TO_STOP_REASON.get(task_state, "end_turn") if task_state else "end_turn" if isinstance(response, tuple) and len(response) == 2: task, update_event = response # Handle artifact updates if isinstance(update_event, 
TaskArtifactUpdateEvent): - if update_event.artifact and hasattr(update_event.artifact, "parts"): + if update_event.artifact and hasattr(update_event.artifact, "parts") and update_event.artifact.parts: for part in update_event.artifact.parts: if hasattr(part, "root") and hasattr(part.root, "text"): content.append({"text": part.root.text}) # Handle status updates with messages elif isinstance(update_event, TaskStatusUpdateEvent): - if update_event.status and hasattr(update_event.status, "message") and update_event.status.message: + if ( + update_event.status + and hasattr(update_event.status, "message") + and update_event.status.message + and update_event.status.message.parts + ): for part in update_event.status.message.parts: if hasattr(part, "root") and hasattr(part.root, "text"): content.append({"text": part.root.text}) @@ -146,7 +152,7 @@ def convert_response_to_agent_result(response: A2AResponse) -> AgentResult: # Use task.artifacts when no content was extracted from the event if not content and task and hasattr(task, "artifacts") and task.artifacts is not None: for artifact in task.artifacts: - if hasattr(artifact, "parts"): + if hasattr(artifact, "parts") and artifact.parts: for part in artifact.parts: if hasattr(part, "root") and hasattr(part.root, "text"): content.append({"text": part.root.text}) @@ -161,7 +167,7 @@ def convert_response_to_agent_result(response: A2AResponse) -> AgentResult: } # Build state dict with A2A metadata - state: dict = {} + state: dict[str, str] = {} if task_state is not None: state["a2a_task_state"] = task_state.value diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py index 273b7dcdf..7a67e3d64 100644 --- a/src/strands/multiagent/a2a/executor.py +++ b/src/strands/multiagent/a2a/executor.py @@ -102,16 +102,16 @@ async def execute( except ServerError: # Re-raise ServerErrors (setup failures like missing input) raise - except Exception as e: + except Exception: # Agent execution failures 
transition to failed state - logger.exception("Agent execution failed, transitioning task to failed state") + logger.exception("task_id=<%s> | agent execution failed, transitioning to failed state", task.id) try: await updater.failed( - message=updater.new_agent_message(parts=[Part(root=TextPart(text=f"Agent execution failed: {e}"))]) + message=updater.new_agent_message(parts=[Part(root=TextPart(text="Agent execution failed"))]) ) except RuntimeError: # Task already in terminal state (e.g., completed before error in cleanup) - logger.debug("Task already in terminal state, cannot transition to failed") + logger.debug("task_id=<%s> | task already in terminal state, cannot transition to failed", task.id) async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater) -> None: """Execute request in streaming mode. @@ -130,9 +130,9 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater if context.message and hasattr(context.message, "parts"): content_blocks = self._convert_a2a_parts_to_content_blocks(context.message.parts) if not content_blocks: - raise ServerError(error=InternalError()) + raise ServerError(error=InternalError()) from None else: - raise ServerError(error=InternalError()) + raise ServerError(error=InternalError()) from None if not self.enable_a2a_compliant_streaming: warnings.warn( @@ -270,18 +270,32 @@ async def _handle_agent_result(self, result: SAAgentResult | None, updater: Task async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None: """Cancel an ongoing execution. - Transitions the task to the canceled state. If the agent supports cancellation - (e.g., via a stop mechanism), this will signal the agent to stop processing. + Transitions the task to the canceled state and attempts to stop the agent. + The agent's cancel() method is called if available to signal cooperative + cancellation of in-flight execution. + + Note: This transitions the A2A task state. 
The underlying agent execution + may still complete its current model call before stopping. Args: context: The A2A request context. event_queue: The A2A event queue. + + Raises: + ServerError: If no current task exists or the task is already in a terminal state. """ task = context.current_task if not task: - logger.warning("Cancellation requested but no current task found") + logger.warning("cancel requested but no current task found") raise ServerError(error=UnsupportedOperationError()) from None + # Attempt to stop the agent if it supports cancellation + if hasattr(self.agent, "cancel") and callable(self.agent.cancel): + try: + self.agent.cancel() + except Exception: + logger.debug("task_id=<%s> | agent cancel signal failed (non-critical)", task.id) + updater = TaskUpdater(event_queue, task.id, task.context_id) try: @@ -289,8 +303,8 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None message=updater.new_agent_message(parts=[Part(root=TextPart(text="Task cancelled by client request"))]) ) except RuntimeError: - # Task already in terminal state - logger.warning("Cannot cancel task %s: already in terminal state", task.id) + # TaskUpdater raises RuntimeError when task is already in a terminal state + logger.warning("task_id=<%s> | cannot cancel, already in terminal state", task.id) raise ServerError(error=UnsupportedOperationError()) from None def _get_file_type_from_mime_type(self, mime_type: str | None) -> Literal["document", "image", "video", "unknown"]: diff --git a/tests/strands/multiagent/a2a/test_executor.py b/tests/strands/multiagent/a2a/test_executor.py index 63891e8bb..1299d38bc 100644 --- a/tests/strands/multiagent/a2a/test_executor.py +++ b/tests/strands/multiagent/a2a/test_executor.py @@ -1383,7 +1383,7 @@ async def mock_stream(content_blocks, **kwargs): e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.failed ] assert len(failed_events) == 1 - assert "Connection lost" in 
failed_events[0].status.message.parts[0].root.text + assert "Agent execution failed" in failed_events[0].status.message.parts[0].root.text @pytest.mark.asyncio @@ -1588,3 +1588,99 @@ async def test_execute_setup_failure_raises_server_error(mock_strands_agent, moc await executor.execute(mock_request_context, mock_event_queue) assert isinstance(excinfo.value.error, InternalError) + + +@pytest.mark.asyncio +async def test_execute_error_when_task_already_terminal(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that error during execution is handled gracefully when task is already in terminal state.""" + from a2a.types import TextPart + + # Make stream_async raise but also make the event queue raise RuntimeError + # (simulating task already in terminal state when we try to mark as failed) + mock_strands_agent.stream_async = MagicMock(side_effect=Exception("Agent error")) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-already-done" + mock_task.context_id = "ctx-already-done" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "test" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + # Simulate task already in terminal state by making enqueue raise RuntimeError + # after the first call (task creation) + call_count = [0] + original_enqueue = mock_event_queue.enqueue_event + + async def enqueue_with_terminal_error(event): + call_count[0] += 1 + if call_count[0] > 1: + # Simulate RuntimeError from TaskUpdater terminal state check + raise RuntimeError("Task test-task-id is already in a terminal state.") + return await original_enqueue(event) + + mock_event_queue.enqueue_event = enqueue_with_terminal_error + + # Should NOT raise - handles RuntimeError gracefully + await executor.execute(mock_request_context, 
mock_event_queue) + + +@pytest.mark.asyncio +async def test_cancel_calls_agent_cancel_method(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that cancel() attempts to call agent.cancel() if available.""" + from a2a.types import TaskState, TaskStatusUpdateEvent + + # Give the agent a cancel method + mock_strands_agent.cancel = MagicMock() + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-cancel-agent" + mock_task.context_id = "ctx-cancel-agent" + mock_request_context.current_task = mock_task + + await executor.cancel(mock_request_context, mock_event_queue) + + # Verify agent.cancel() was called + mock_strands_agent.cancel.assert_called_once() + + # Verify task state is canceled + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + canceled_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.canceled + ] + assert len(canceled_events) == 1 + + +@pytest.mark.asyncio +async def test_cancel_handles_agent_cancel_exception(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that cancel() gracefully handles agent.cancel() raising an exception.""" + from a2a.types import TaskState, TaskStatusUpdateEvent + + # Give the agent a cancel method that raises + mock_strands_agent.cancel = MagicMock(side_effect=RuntimeError("Cannot cancel")) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-cancel-err" + mock_task.context_id = "ctx-cancel-err" + mock_request_context.current_task = mock_task + + # Should still succeed (agent cancel is best-effort) + await executor.cancel(mock_request_context, mock_event_queue) + + # Task should still be transitioned to canceled + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + canceled_events = [ + e for e in enqueued_events if isinstance(e, 
TaskStatusUpdateEvent) and e.status.state == TaskState.canceled + ] + assert len(canceled_events) == 1 From 2c8539f82b814c4765120834f205b65e0911ad1c Mon Sep 17 00:00:00 2001 From: agent-of-mkmeral Date: Mon, 4 May 2026 20:26:30 +0000 Subject: [PATCH 4/6] fix: address mkmeral review comments + increase codecov Changes: 1. InternalError now includes descriptive messages: - 'No valid content found in request message parts' - 'Request message is missing or has no parts' (was: bare InternalError() with no explanation) 2. State sets derived from _STATE_TO_STOP_REASON (single source of truth): - _TERMINAL_STATES = {s for s, r in _STATE_TO_STOP_REASON.items() if r == 'end_turn'} - _INPUT_STATES = {s for s, r in _STATE_TO_STOP_REASON.items() if r == 'interrupt'} - Removed manual TaskState enum listing (no more duplication) 3. Structured logging on cancel no-task path: - context_id=<%s> | cancel requested but no current task found 4. Coverage: executor.py now 98% (0 missing statements) - Fixed test_execute_error_when_task_already_terminal: patches TaskUpdater - Added test_cancel_raises_when_task_already_terminal: covers RuntimeError path - 182 tests pass --- src/strands/agent/a2a_agent.py | 26 ++++------ src/strands/multiagent/a2a/executor.py | 8 +-- tests/strands/multiagent/a2a/test_executor.py | 51 +++++++++++++------ 3 files changed, 50 insertions(+), 35 deletions(-) diff --git a/src/strands/agent/a2a_agent.py b/src/strands/agent/a2a_agent.py index 6b2e0bc01..eeb96f7a2 100644 --- a/src/strands/agent/a2a_agent.py +++ b/src/strands/agent/a2a_agent.py @@ -15,10 +15,14 @@ import httpx from a2a.client import A2ACardResolver, ClientConfig, ClientFactory -from a2a.types import AgentCard, Message, TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent +from a2a.types import AgentCard, Message, TaskArtifactUpdateEvent, TaskStatusUpdateEvent from .._async import run_async -from ..multiagent.a2a._converters import convert_input_to_message, convert_response_to_agent_result +from 
..multiagent.a2a._converters import ( + _STATE_TO_STOP_REASON, + convert_input_to_message, + convert_response_to_agent_result, +) from ..types._events import AgentResultEvent from ..types.a2a import A2AResponse, A2AStreamEvent from ..types.agent import AgentInput @@ -30,20 +34,10 @@ _DEFAULT_TIMEOUT = 300 # A2A task states that indicate the response stream is complete. -# Terminal states mean no more events; input states mean execution is paused. -# Derived from _STATE_TO_STOP_REASON in _converters to maintain single source of truth. -_TERMINAL_STATES = { - TaskState.completed, - TaskState.failed, - TaskState.canceled, - TaskState.rejected, -} - -_INPUT_STATES = { - TaskState.input_required, - TaskState.auth_required, -} - +# Derived from the canonical _STATE_TO_STOP_REASON mapping in _converters. +# Terminal states (end_turn) mean no more events; input states (interrupt) mean execution is paused. +_TERMINAL_STATES = {state for state, reason in _STATE_TO_STOP_REASON.items() if reason == "end_turn"} +_INPUT_STATES = {state for state, reason in _STATE_TO_STOP_REASON.items() if reason == "interrupt"} _COMPLETE_STATES = _TERMINAL_STATES | _INPUT_STATES diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py index 7a67e3d64..72c339b9c 100644 --- a/src/strands/multiagent/a2a/executor.py +++ b/src/strands/multiagent/a2a/executor.py @@ -130,9 +130,11 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater if context.message and hasattr(context.message, "parts"): content_blocks = self._convert_a2a_parts_to_content_blocks(context.message.parts) if not content_blocks: - raise ServerError(error=InternalError()) from None + raise ServerError( + error=InternalError(message="No valid content found in request message parts") + ) from None else: - raise ServerError(error=InternalError()) from None + raise ServerError(error=InternalError(message="Request message is missing or has no parts")) from None if not 
self.enable_a2a_compliant_streaming: warnings.warn( @@ -286,7 +288,7 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None """ task = context.current_task if not task: - logger.warning("cancel requested but no current task found") + logger.warning("context_id=<%s> | cancel requested but no current task found", context.context_id) raise ServerError(error=UnsupportedOperationError()) from None # Attempt to stop the agent if it supports cancellation diff --git a/tests/strands/multiagent/a2a/test_executor.py b/tests/strands/multiagent/a2a/test_executor.py index 1299d38bc..778885ac4 100644 --- a/tests/strands/multiagent/a2a/test_executor.py +++ b/tests/strands/multiagent/a2a/test_executor.py @@ -1595,8 +1595,7 @@ async def test_execute_error_when_task_already_terminal(mock_strands_agent, mock """Test that error during execution is handled gracefully when task is already in terminal state.""" from a2a.types import TextPart - # Make stream_async raise but also make the event queue raise RuntimeError - # (simulating task already in terminal state when we try to mark as failed) + # Make stream_async raise to trigger the error path mock_strands_agent.stream_async = MagicMock(side_effect=Exception("Agent error")) executor = StrandsA2AExecutor(mock_strands_agent) @@ -1614,22 +1613,18 @@ async def test_execute_error_when_task_already_terminal(mock_strands_agent, mock mock_message.parts = [mock_part] mock_request_context.message = mock_message - # Simulate task already in terminal state by making enqueue raise RuntimeError - # after the first call (task creation) - call_count = [0] - original_enqueue = mock_event_queue.enqueue_event + # Patch TaskUpdater.failed to raise RuntimeError (simulating task already in terminal state) + with patch("strands.multiagent.a2a.executor.TaskUpdater") as MockTaskUpdater: + mock_updater = MagicMock() + mock_updater.failed = AsyncMock(side_effect=RuntimeError("Task is already in a terminal state")) + 
mock_updater.new_agent_message = MagicMock(return_value=MagicMock()) + MockTaskUpdater.return_value = mock_updater - async def enqueue_with_terminal_error(event): - call_count[0] += 1 - if call_count[0] > 1: - # Simulate RuntimeError from TaskUpdater terminal state check - raise RuntimeError("Task test-task-id is already in a terminal state.") - return await original_enqueue(event) - - mock_event_queue.enqueue_event = enqueue_with_terminal_error + # Should NOT raise - handles RuntimeError gracefully + await executor.execute(mock_request_context, mock_event_queue) - # Should NOT raise - handles RuntimeError gracefully - await executor.execute(mock_request_context, mock_event_queue) + # Verify failed() was attempted + mock_updater.failed.assert_called_once() @pytest.mark.asyncio @@ -1684,3 +1679,27 @@ async def test_cancel_handles_agent_cancel_exception(mock_strands_agent, mock_re e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.canceled ] assert len(canceled_events) == 1 + + +@pytest.mark.asyncio +async def test_cancel_raises_when_task_already_terminal(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that cancel() raises ServerError when task is already in a terminal state.""" + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-terminal" + mock_task.context_id = "ctx-terminal" + mock_request_context.current_task = mock_task + + # Patch TaskUpdater.cancel to raise RuntimeError (task already completed/failed) + with patch("strands.multiagent.a2a.executor.TaskUpdater") as MockTaskUpdater: + mock_updater = MagicMock() + mock_updater.cancel = AsyncMock(side_effect=RuntimeError("Task is already in a terminal state")) + mock_updater.new_agent_message = MagicMock(return_value=MagicMock()) + MockTaskUpdater.return_value = mock_updater + + with pytest.raises(ServerError) as excinfo: + await executor.cancel(mock_request_context, mock_event_queue) + + 
assert isinstance(excinfo.value.error, UnsupportedOperationError) + mock_updater.cancel.assert_called_once() From a66e210c2f661ea36aa211d92070a257c85cc8d0 Mon Sep 17 00:00:00 2001 From: Strands Agent Date: Mon, 4 May 2026 20:54:16 +0000 Subject: [PATCH 5/6] =?UTF-8?q?fix:=20address=20devil's=20advocate=20findi?= =?UTF-8?q?ngs=20=E2=80=94=20critical=20test=20gaps=20and=20code=20bugs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Devil's Advocate Review Findings Addressed: Critical (2): 1. asyncio.CancelledError now transitions task to 'canceled' state before re-raising. Previously, CancelledError (BaseException, not Exception) would propagate uncaught, leaving the A2A task stuck in 'working' forever (zombie). - Added explicit 'except asyncio.CancelledError' handler in execute() - Transitions to canceled, then re-raises for framework cleanup - Handles edge case where task is already terminal (RuntimeError) 2. stop_reason='interrupt' with empty/None interrupts list no longer silently completes the task. The stop_reason is now the authoritative signal — if the agent says 'interrupt', we transition to input_required regardless of whether the interrupts list is populated. - Removed 'and result.interrupts' from the condition - Added fallback message: 'Agent requires additional input to continue' Major (3): 3. test_convert_response_completed_state now asserts result.state metadata (was the only lifecycle test missing this assertion) 4. Added test for TaskState.unknown → end_turn default behavior 5. Added test_state_to_stop_reason_covers_all_lifecycle_states (guards against future a2a-sdk additions we miss) Minor (2): 6. Added test_extract_task_state_from_artifact_update_returns_none 7. 
Added parametrized test covering ALL 9 TaskState values for _is_complete_event (replaces verbose individual tests) Code fixes: - cancel(): Removed hasattr/callable duck-typing (nit from review), now uses try/except (AttributeError, NotImplementedError) directly - Added 'import asyncio' to executor.py Tests: 201 pass (was 182) --- src/strands/multiagent/a2a/executor.py | 42 +++- tests/strands/agent/test_a2a_agent.py | 52 +++- .../strands/multiagent/a2a/test_converters.py | 124 +++++++++ tests/strands/multiagent/a2a/test_executor.py | 238 ++++++++++++++++++ 4 files changed, 447 insertions(+), 9 deletions(-) diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py index 72c339b9c..9dfe0dec3 100644 --- a/src/strands/multiagent/a2a/executor.py +++ b/src/strands/multiagent/a2a/executor.py @@ -8,6 +8,7 @@ streamed requests to the A2AServer. """ +import asyncio import base64 import json import logging @@ -102,6 +103,21 @@ async def execute( except ServerError: # Re-raise ServerErrors (setup failures like missing input) raise + except asyncio.CancelledError: + # asyncio.CancelledError is a BaseException (not Exception) — raised when + # the asyncio task is cancelled (e.g., HTTP client disconnect, server shutdown). + # We transition to canceled state so the task doesn't remain a zombie in "working". 
+ logger.warning("task_id=<%s> | asyncio task cancelled, transitioning to canceled state", task.id) + try: + await updater.cancel( + message=updater.new_agent_message( + parts=[Part(root=TextPart(text="Task cancelled due to connection termination"))] + ) + ) + except RuntimeError: + # Task already in terminal state + logger.debug("task_id=<%s> | task already in terminal state, cannot transition to canceled", task.id) + raise except Exception: # Agent execution failures transition to failed state logger.exception("task_id=<%s> | agent execution failed, transitioning to failed state", task.id) @@ -163,7 +179,9 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater await self._handle_streaming_event(event, updater) # Check if agent returned with interrupts (input_required) - if result is not None and result.stop_reason == "interrupt" and result.interrupts: + # Note: stop_reason="interrupt" is the authoritative signal. Even if interrupts + # list is empty (edge case), the agent still indicated it needs input. + if result is not None and result.stop_reason == "interrupt": await self._handle_interrupt_result(result, updater) else: await self._handle_agent_result(result, updater) @@ -194,7 +212,12 @@ async def _handle_interrupt_result(self, result: SAAgentResult, updater: TaskUpd desc += f": {interrupt.reason}" interrupt_descriptions.append(desc) - input_message = "Agent requires input:\n" + "\n".join(interrupt_descriptions) + if interrupt_descriptions: + input_message = "Agent requires input:\n" + "\n".join(interrupt_descriptions) + else: + # Edge case: stop_reason="interrupt" but no interrupt details provided. + # Still transition to input_required — the agent signaled it needs input. 
+ input_message = "Agent requires additional input to continue" await updater.requires_input(message=updater.new_agent_message(parts=[Part(root=TextPart(text=input_message))])) @@ -291,12 +314,15 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None logger.warning("context_id=<%s> | cancel requested but no current task found", context.context_id) raise ServerError(error=UnsupportedOperationError()) from None - # Attempt to stop the agent if it supports cancellation - if hasattr(self.agent, "cancel") and callable(self.agent.cancel): - try: - self.agent.cancel() - except Exception: - logger.debug("task_id=<%s> | agent cancel signal failed (non-critical)", task.id) + # Attempt to cooperatively cancel the agent's execution (best-effort). + # Agent.cancel() may not exist on all implementations, so we guard with hasattr. + try: + self.agent.cancel() + except (AttributeError, NotImplementedError): + # Agent doesn't support cancel — proceed with state transition only + pass + except Exception: + logger.debug("task_id=<%s> | agent cancel signal failed (non-critical)", task.id) updater = TaskUpdater(event_queue, task.id, task.context_id) diff --git a/tests/strands/agent/test_a2a_agent.py b/tests/strands/agent/test_a2a_agent.py index 5bcd35b96..9c3be7917 100644 --- a/tests/strands/agent/test_a2a_agent.py +++ b/tests/strands/agent/test_a2a_agent.py @@ -7,7 +7,7 @@ import pytest from a2a.client import ClientConfig -from a2a.types import AgentCard, Message, Part, Role, TextPart +from a2a.types import AgentCard, Message, Part, Role, TaskState, TextPart from strands.agent.a2a_agent import A2AAgent from strands.agent.agent_result import AgentResult @@ -824,3 +824,53 @@ def test_is_complete_event_submitted_state_not_complete(a2a_agent): update_event.status = status assert a2a_agent._is_complete_event((task, update_event)) is False + + +# ========================================================================= +# DEVIL'S ADVOCATE FINDINGS — Tests 
addressing review gaps +# ========================================================================= + + +@pytest.mark.parametrize( + "state,expected_complete", + [ + (TaskState.completed, True), + (TaskState.failed, True), + (TaskState.canceled, True), + (TaskState.rejected, True), + (TaskState.input_required, True), + (TaskState.auth_required, True), + (TaskState.working, False), + (TaskState.submitted, False), + (TaskState.unknown, False), + ], + ids=[ + "completed-is-complete", + "failed-is-complete", + "canceled-is-complete", + "rejected-is-complete", + "input_required-is-complete", + "auth_required-is-complete", + "working-not-complete", + "submitted-not-complete", + "unknown-not-complete", + ], +) +def test_is_complete_event_all_states_parametrized(a2a_agent, state, expected_complete): + """Minor Finding 7: Parametrized test covering ALL TaskState values. + + This replaces verbose individual tests with a single parameterized test that + covers all 9 TaskState values. When a2a-sdk adds new states, adding a row here + is trivial. 
+ """ + from unittest.mock import MagicMock + + from a2a.types import TaskStatusUpdateEvent + + task = MagicMock() + status = MagicMock() + status.state = state + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is expected_complete diff --git a/tests/strands/multiagent/a2a/test_converters.py b/tests/strands/multiagent/a2a/test_converters.py index cb2b97d51..fff48653b 100644 --- a/tests/strands/multiagent/a2a/test_converters.py +++ b/tests/strands/multiagent/a2a/test_converters.py @@ -402,3 +402,127 @@ def test_extract_task_state_from_message_returns_none(): message = MagicMock(spec=Message) state = _extract_task_state(message) assert state is None + + +# ========================================================================= +# DEVIL'S ADVOCATE FINDINGS — Tests addressing review gaps +# ========================================================================= + + +def test_convert_response_completed_state_includes_state_metadata(): + """Major Finding 3: The completed state test was missing state assertion. + + Every other state test asserts both stop_reason AND result.state, but the most + important one (completed — the happy path) was missing the state check. This ensures + downstream consumers relying on result.state["a2a_task_state"] won't break silently. 
+ """ + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + status = TaskStatus(state=TaskState.completed, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "end_turn" + assert result.state.get("a2a_task_state") == "completed" # THIS WAS MISSING + + +def test_convert_response_unknown_state_defaults_to_end_turn(): + """Major Finding 4: TaskState.unknown should default to end_turn. + + The a2a-sdk has a TaskState.unknown value. Our code handles it via the .get() + default ("end_turn"). This test documents that this is an intentional design + decision: unknown states are treated as terminal completions rather than errors. + + Rationale: An unknown state from a remote server is ambiguous. Treating it as + end_turn (completed) is the safest default — the client won't hang waiting for + more events, and the result content (if any) is still accessible. + """ + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + status = TaskStatus(state=TaskState.unknown, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + # unknown is NOT in _STATE_TO_STOP_REASON, so defaults to "end_turn" + assert result.stop_reason == "end_turn" + # state metadata should reflect the actual state value + assert result.state.get("a2a_task_state") == "unknown" + + +def test_convert_response_working_state_defaults_to_end_turn(): + """Test that working state (not in mapping) defaults to end_turn. 
+ + This covers the edge case where a TaskStatusUpdateEvent with state=working + somehow reaches the converter (shouldn't normally happen since _is_complete_event + filters these out, but defense-in-depth). + """ + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + status = TaskStatus(state=TaskState.working, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "end_turn" + assert result.state.get("a2a_task_state") == "working" + + +def test_extract_task_state_from_artifact_update_returns_none(): + """Minor Finding 6: _extract_task_state with TaskArtifactUpdateEvent returns None. + + This is the untested path where the update event is an artifact (not status). + """ + from unittest.mock import MagicMock + + from a2a.types import TaskArtifactUpdateEvent + + from strands.multiagent.a2a._converters import _extract_task_state + + task = MagicMock() + mock_event = MagicMock(spec=TaskArtifactUpdateEvent) + + state = _extract_task_state((task, mock_event)) + assert state is None + + +def test_state_to_stop_reason_covers_all_lifecycle_states(): + """Verify _STATE_TO_STOP_REASON has mappings for all documented lifecycle states. + + Guards against future additions to the a2a-sdk that we miss. 
+ """ + from a2a.types import TaskState + + from strands.multiagent.a2a._converters import _STATE_TO_STOP_REASON + + # These are the states we explicitly handle + expected_mapped = { + TaskState.completed, + TaskState.failed, + TaskState.canceled, + TaskState.rejected, + TaskState.input_required, + TaskState.auth_required, + } + assert set(_STATE_TO_STOP_REASON.keys()) == expected_mapped + + # These should NOT be in the mapping (they're non-terminal progress states) + assert TaskState.working not in _STATE_TO_STOP_REASON + assert TaskState.submitted not in _STATE_TO_STOP_REASON + assert TaskState.unknown not in _STATE_TO_STOP_REASON diff --git a/tests/strands/multiagent/a2a/test_executor.py b/tests/strands/multiagent/a2a/test_executor.py index 778885ac4..940d26f8c 100644 --- a/tests/strands/multiagent/a2a/test_executor.py +++ b/tests/strands/multiagent/a2a/test_executor.py @@ -1703,3 +1703,241 @@ async def test_cancel_raises_when_task_already_terminal(mock_strands_agent, mock assert isinstance(excinfo.value.error, UnsupportedOperationError) mock_updater.cancel.assert_called_once() + + +# ========================================================================= +# DEVIL'S ADVOCATE FINDINGS — Tests addressing review gaps +# ========================================================================= + + +@pytest.mark.asyncio +async def test_execute_handles_asyncio_cancelled_error(mock_strands_agent, mock_request_context, mock_event_queue): + """Critical Finding 1: asyncio.CancelledError transitions task to canceled state. + + asyncio.CancelledError is a BaseException (not Exception). It's raised when an asyncio + task is cancelled — e.g., HTTP client disconnect, server shutdown, task group cancellation. + Without explicit handling, the task would remain stuck in 'working' state forever (zombie). + + This test verifies the task transitions to 'canceled' before re-raising CancelledError. 
+ """ + import asyncio + + from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + + async def mock_stream(content_blocks, **kwargs): + """Mock streaming that gets cancelled mid-stream.""" + yield {"data": "partial output"} + raise asyncio.CancelledError() + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-cancelled" + mock_task.context_id = "ctx-cancelled" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "test" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + # CancelledError should be re-raised (framework needs to know task was cancelled) + with pytest.raises(asyncio.CancelledError): + await executor.execute(mock_request_context, mock_event_queue) + + # But BEFORE re-raising, the task should have been transitioned to canceled + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + canceled_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.canceled + ] + assert len(canceled_events) == 1 + assert ( + "cancelled" in canceled_events[0].status.message.parts[0].root.text.lower() + or "connection termination" in canceled_events[0].status.message.parts[0].root.text.lower() + ) + + +@pytest.mark.asyncio +async def test_execute_asyncio_cancelled_when_task_already_terminal( + mock_strands_agent, mock_request_context, mock_event_queue +): + """Test CancelledError handling when task is already in a terminal state. + + If the task completed right before the cancellation arrives, the updater.cancel() + will raise RuntimeError. We should handle this gracefully and still re-raise CancelledError. 
+ """ + import asyncio + + from a2a.types import TextPart + + async def mock_stream(content_blocks, **kwargs): + """Async generator that immediately raises CancelledError.""" + yield {"data": "partial"} # Must yield to be async generator + raise asyncio.CancelledError() + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-cancelled-terminal" + mock_task.context_id = "ctx-cancelled-terminal" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "test" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + # Patch TaskUpdater to simulate task already in terminal state + with patch("strands.multiagent.a2a.executor.TaskUpdater") as MockTaskUpdater: + mock_updater = MagicMock() + mock_updater.cancel = AsyncMock(side_effect=RuntimeError("Task is already in a terminal state")) + mock_updater.update_status = AsyncMock() + mock_updater.add_artifact = AsyncMock() + mock_updater.new_agent_message = MagicMock(return_value=MagicMock()) + mock_updater.context_id = "ctx-cancelled-terminal" + mock_updater.task_id = "task-cancelled-terminal" + MockTaskUpdater.return_value = mock_updater + + # Should still re-raise CancelledError + with pytest.raises(asyncio.CancelledError): + await executor.execute(mock_request_context, mock_event_queue) + + # cancel() was attempted + mock_updater.cancel.assert_called_once() + + +@pytest.mark.asyncio +async def test_execute_with_interrupt_empty_list_transitions_to_input_required( + mock_strands_agent, mock_request_context, mock_event_queue +): + """Critical Finding 2: stop_reason='interrupt' with empty interrupts list. + + The agent explicitly signaled it needs input (stop_reason="interrupt") but provided + no interrupt details. 
This should STILL transition to input_required — the stop_reason + is the authoritative signal. Previously this would silently complete the task. + """ + from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + + mock_result = MagicMock(spec=SAAgentResult) + mock_result.stop_reason = "interrupt" + mock_result.interrupts = [] # Empty list — previously this was falsy and caused completion! + + async def mock_stream(content_blocks, **kwargs): + yield {"result": mock_result} + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-empty-interrupts" + mock_task.context_id = "ctx-empty-interrupts" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "do something" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + await executor.execute(mock_request_context, mock_event_queue) + + # Should transition to input_required, NOT completed + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + input_required_events = [ + e + for e in enqueued_events + if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.input_required + ] + completed_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.completed + ] + + assert len(input_required_events) == 1, "Empty interrupts list should still trigger input_required" + assert len(completed_events) == 0, "Should NOT complete when stop_reason='interrupt'" + # Verify the fallback message is used + assert "additional input" in input_required_events[0].status.message.parts[0].root.text.lower() + + +@pytest.mark.asyncio +async def test_execute_with_interrupt_none_list_transitions_to_input_required( + mock_strands_agent, 
mock_request_context, mock_event_queue +): + """Edge case: stop_reason='interrupt' with interrupts=None. + + Same logic — the stop_reason is authoritative. None interrupts should + still result in input_required transition. + """ + from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + + mock_result = MagicMock(spec=SAAgentResult) + mock_result.stop_reason = "interrupt" + mock_result.interrupts = None # None, not empty list + + async def mock_stream(content_blocks, **kwargs): + yield {"result": mock_result} + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-none-interrupts" + mock_task.context_id = "ctx-none-interrupts" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "do something" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + await executor.execute(mock_request_context, mock_event_queue) + + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + input_required_events = [ + e + for e in enqueued_events + if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.input_required + ] + assert len(input_required_events) == 1 + + +@pytest.mark.asyncio +async def test_cancel_without_hasattr_cancel(mock_strands_agent, mock_request_context, mock_event_queue): + """Test cancel works when agent doesn't have cancel() method (AttributeError).""" + from a2a.types import TaskState, TaskStatusUpdateEvent + + # Remove cancel method entirely + del mock_strands_agent.cancel + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-no-cancel-method" + mock_task.context_id = "ctx-no-cancel-method" + mock_request_context.current_task = mock_task + + # Should 
succeed — AttributeError from agent.cancel() is caught + await executor.cancel(mock_request_context, mock_event_queue) + + # Task should still be transitioned to canceled + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + canceled_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.canceled + ] + assert len(canceled_events) == 1 From e6b132369a19e370883aafc92efd87113e8b5157 Mon Sep 17 00:00:00 2001 From: agent-of-mkmeral <217235299+strands-agent@users.noreply.github.com> Date: Wed, 6 May 2026 22:42:03 +0000 Subject: [PATCH 6/6] fix(a2a): remove unnecessary duck-typing in cancel(), fix stale comment - Remove AttributeError/NotImplementedError catch since self.agent is typed as Agent which always has cancel() - Update comment from 'we guard with hasattr' to reflect actual impl - Update docstring to remove 'if available' language Addresses review feedback from @mkmeral on executor.py:327 --- src/strands/multiagent/a2a/executor.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py index 9dfe0dec3..7526386e8 100644 --- a/src/strands/multiagent/a2a/executor.py +++ b/src/strands/multiagent/a2a/executor.py @@ -296,8 +296,8 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None """Cancel an ongoing execution. Transitions the task to the canceled state and attempts to stop the agent. - The agent's cancel() method is called if available to signal cooperative - cancellation of in-flight execution. + The agent's cancel() method is called to signal cooperative cancellation + of in-flight execution. Note: This transitions the A2A task state. The underlying agent execution may still complete its current model call before stopping. 
@@ -314,13 +314,10 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None logger.warning("context_id=<%s> | cancel requested but no current task found", context.context_id) raise ServerError(error=UnsupportedOperationError()) from None - # Attempt to cooperatively cancel the agent's execution (best-effort). - # Agent.cancel() may not exist on all implementations, so we guard with hasattr. + # Cooperatively cancel the agent's execution (best-effort). + # Agent.cancel() is always available since self.agent is typed as Agent. try: self.agent.cancel() - except (AttributeError, NotImplementedError): - # Agent doesn't support cancel — proceed with state transition only - pass except Exception: logger.debug("task_id=<%s> | agent cancel signal failed (non-critical)", task.id)