diff --git a/src/google/adk/flows/llm_flows/_code_execution.py b/src/google/adk/flows/llm_flows/_code_execution.py index 19ec4cc219..60fa4ecfe0 100644 --- a/src/google/adk/flows/llm_flows/_code_execution.py +++ b/src/google/adk/flows/llm_flows/_code_execution.py @@ -74,6 +74,91 @@ class DataFileUtil: ), } +_NON_BUILTIN_EXECUTOR_INSTRUCTION = """\ +# CRITICAL: Code execution format + +You have access to an external Python sandbox managed by the host +application. To run Python code, output it inside a fenced markdown +block exactly like this: + +```tool_code +print("hello") +``` + +DO NOT emit native executable_code parts. +DO NOT attempt to call a code_execution tool — no such tool is +registered for this request and the API will reject the response with +UNEXPECTED_TOOL_CALL or MALFORMED_FUNCTION_CALL. + +Always wrap Python code in the tool_code markdown fence shown +above. +""" + +# Recoverable API rejection codes for Gemini 2.x emitting code as a +# native tool call when no code_execution tool was declared. +_RECOVERABLE_API_ERRORS = frozenset( + {'UNEXPECTED_TOOL_CALL', 'MALFORMED_FUNCTION_CALL'} +) + +_UNEXPECTED_TOOL_CALL_RE = re.compile( + r'^\s*Unexpected tool call:\s*(?P<code>.+?)\s*$', re.DOTALL +) + + +def _extract_code_from_error_message( + error_message: Optional[str], +) -> Optional[str]: + """Best-effort extraction of code from a Gemini API rejection error message.""" + if not error_message: + return None + m = _UNEXPECTED_TOOL_CALL_RE.match(error_message) + if m: + return m.group('code').strip() or None + return None + + +def _maybe_recover_from_api_rejection(llm_response: LlmResponse) -> bool: + """Recovers an executable_code part from a Gemini 2.x API rejection. + + When ADK uses a non-built-in code executor (e.g., + AgentEngineSandboxCodeExecutor) with Gemini 2.x, the model may emit a + native code_execution tool call. 
Because no such tool is declared in + the request, the server rejects the response with UNEXPECTED_TOOL_CALL + (or MALFORMED_FUNCTION_CALL when other tools are present), and + llm_response.content ends up empty. + + This function reconstructs the executable_code part the model intended + to emit so the existing post-processor pipeline can run it through the + configured sandbox executor. + + Returns True if recovery occurred and llm_response was mutated. + """ + error_code = llm_response.error_code + if error_code is None: + return False + error_code_name = getattr(error_code, 'name', str(error_code)) + if error_code_name not in _RECOVERABLE_API_ERRORS: + return False + + code_str = _extract_code_from_error_message(llm_response.error_message) + if not code_str: + return False + + llm_response.content = types.Content( + role='model', + parts=[CodeExecutionUtils.build_executable_code_part(code_str)], + ) + llm_response.error_code = None + llm_response.error_message = None + llm_response.finish_reason = None + logger.info( + 'Recovered code from API %s rejection; routing to configured' + ' code executor.', + error_code_name, + ) + return True + + _DATA_FILE_HELPER_LIB = ''' import pandas as pd @@ -114,7 +199,7 @@ def explore_df(df: pd.DataFrame) -> None: ''' -class _CodeExecutionRequestProcessor(BaseLlmRequestProcessor): +class _CodeExecutionRequestProcessor(BaseLlmRequestProcessor): # type: ignore[misc] """Processes code execution requests.""" @override @@ -148,7 +233,7 @@ async def run_async( request_processor = _CodeExecutionRequestProcessor() -class _CodeExecutionResponseProcessor(BaseLlmResponseProcessor): +class _CodeExecutionResponseProcessor(BaseLlmResponseProcessor): # type: ignore[misc] """Processes code execution responses.""" @override @@ -187,6 +272,13 @@ async def _run_pre_processor( code_executor.process_llm_request(llm_request) return + # Steer Gemini 2.x (and other modern models) away from emitting a + # native `executable_code` / code_execution tool 
call. When the + # configured executor is *not* the built-in one, no `code_execution` + # tool is declared on the request, and a native emission would be + # rejected by the API as UNEXPECTED_TOOL_CALL / MALFORMED_FUNCTION_CALL. + llm_request.append_instructions([_NON_BUILTIN_EXECUTOR_INSTRUCTION]) + if not code_executor.optimize_data_file: return @@ -265,7 +357,7 @@ async def _run_pre_processor( async def _run_post_processor( invocation_context: InvocationContext, - llm_response, + llm_response: LlmResponse, ) -> AsyncGenerator[Event, None]: """Post-process the model response by extracting and executing the first code block.""" agent = invocation_context.agent @@ -273,7 +365,21 @@ async def _run_post_processor( if not code_executor or not isinstance(code_executor, BaseCodeExecutor): return - if not llm_response or not llm_response.content: + if not llm_response: + return + + # When the API rejected the response because the model emitted a native + # code_execution tool call (UNEXPECTED_TOOL_CALL / MALFORMED_FUNCTION_CALL), + # llm_response.content is empty. For non-built-in executors, try to + # recover the intended code from the error message so we can still run + # it in the configured sandbox. + if not llm_response.content and not isinstance( + code_executor, BuiltInCodeExecutor + ): + if not _maybe_recover_from_api_rejection(llm_response): + return + + if not llm_response.content: return if isinstance(code_executor, BuiltInCodeExecutor): @@ -376,7 +482,7 @@ def _extract_and_replace_inline_files( llm_request: LlmRequest, ) -> list[File]: """Extracts and replaces inline files with file names in the LLM request.""" - all_input_files = code_executor_context.get_input_files() + all_input_files: list[File] = code_executor_context.get_input_files() saved_file_names = set(f.name for f in all_input_files) # [Step 1] Process input files from LlmRequest and cache them in CodeExecutor. 
@@ -425,7 +531,7 @@ def _get_or_set_execution_id( if not invocation_context.agent.code_executor.stateful: return None - execution_id = code_executor_context.get_execution_id() + execution_id: Optional[str] = code_executor_context.get_execution_id() if not execution_id: execution_id = invocation_context.session.id code_executor_context.set_execution_id(execution_id) @@ -517,7 +623,7 @@ def _get_normalized_file_name(file_name: str) -> str: return var_name if file.mime_type not in _DATA_FILE_UTIL_MAP: - return + return None var_name = _get_normalized_file_name(file.name) loader_code = _DATA_FILE_UTIL_MAP[file.mime_type].loader_code_template.format( diff --git a/tests/unittests/flows/llm_flows/test_code_execution.py b/tests/unittests/flows/llm_flows/test_code_execution.py index 69f2d7832d..dce7cc8f72 100644 --- a/tests/unittests/flows/llm_flows/test_code_execution.py +++ b/tests/unittests/flows/llm_flows/test_code_execution.py @@ -23,13 +23,41 @@ from google.adk.code_executors.base_code_executor import BaseCodeExecutor from google.adk.code_executors.built_in_code_executor import BuiltInCodeExecutor from google.adk.code_executors.code_execution_utils import CodeExecutionResult +from google.adk.flows.llm_flows._code_execution import _extract_code_from_error_message +from google.adk.flows.llm_flows._code_execution import _maybe_recover_from_api_rejection +from google.adk.flows.llm_flows._code_execution import _NON_BUILTIN_EXECUTOR_INSTRUCTION +from google.adk.flows.llm_flows._code_execution import request_processor from google.adk.flows.llm_flows._code_execution import response_processor +from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse from google.genai import types import pytest from ... 
import testing_utils +# --------------------------------------------------------------------------- +# _extract_code_from_error_message +# --------------------------------------------------------------------------- + + +def test_extract_code_from_error_message_valid(): + code = _extract_code_from_error_message('Unexpected tool call: print(1+1)') + assert code == 'print(1+1)' + + +def test_extract_code_from_error_message_multiline(): + msg = 'Unexpected tool call: x = 1\nprint(x)' + code = _extract_code_from_error_message(msg) + assert code == 'x = 1\nprint(x)' + + +def test_extract_code_from_error_message_none(): + assert _extract_code_from_error_message(None) is None + + +def test_extract_code_from_error_message_no_match(): + assert _extract_code_from_error_message('some other error') is None + @pytest.mark.asyncio @patch('google.adk.flows.llm_flows._code_execution.datetime') @@ -150,3 +178,175 @@ async def test_logs_executed_code(mock_logger): mock_logger.debug.assert_called_once_with( 'Executed code:\n```\n%s\n```', 'print("hello")' ) + + +# --------------------------------------------------------------------------- +# _maybe_recover_from_api_rejection +# --------------------------------------------------------------------------- + + +def _make_rejected_response(error_code: str, code_snippet: str) -> LlmResponse: + return LlmResponse( + content=None, + error_code=error_code, + error_message=f'Unexpected tool call: {code_snippet}', + ) + + +def test_maybe_recover_unexpected_tool_call(): + llm_response = _make_rejected_response('UNEXPECTED_TOOL_CALL', 'print(42)') + recovered = _maybe_recover_from_api_rejection(llm_response) + + assert recovered is True + assert llm_response.content is not None + assert len(llm_response.content.parts) == 1 + assert llm_response.content.parts[0].executable_code.code == 'print(42)' + assert llm_response.error_code is None + assert llm_response.error_message is None + assert llm_response.finish_reason is None + + +def 
test_maybe_recover_malformed_function_call(): + llm_response = _make_rejected_response('MALFORMED_FUNCTION_CALL', 'x=1') + assert _maybe_recover_from_api_rejection(llm_response) is True + assert llm_response.content is not None + + +def test_maybe_recover_unrecognised_error_code(): + llm_response = _make_rejected_response('SAFETY', 'print(42)') + assert _maybe_recover_from_api_rejection(llm_response) is False + assert llm_response.content is None + + +def test_maybe_recover_no_error_code(): + llm_response = LlmResponse(content=None, error_code=None, error_message=None) + assert _maybe_recover_from_api_rejection(llm_response) is False + + +def test_maybe_recover_unparseable_message(): + llm_response = LlmResponse( + content=None, + error_code='UNEXPECTED_TOOL_CALL', + error_message='some completely different message', + ) + assert _maybe_recover_from_api_rejection(llm_response) is False + + +# --------------------------------------------------------------------------- +# Pre-processor: instruction injection +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_pre_processor_injects_instruction_for_non_builtin_executor(): + mock_executor = MagicMock(spec=BaseCodeExecutor) + mock_executor.optimize_data_file = False + + agent = Agent(name='test_agent', code_executor=mock_executor) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='run some code' + ) + llm_request = LlmRequest() + + _ = [ + event + async for event in request_processor.run_async( + invocation_context, llm_request + ) + ] + + assert llm_request.config.system_instruction is not None + assert _NON_BUILTIN_EXECUTOR_INSTRUCTION in str( + llm_request.config.system_instruction + ) + + +@pytest.mark.asyncio +async def test_pre_processor_does_not_inject_instruction_for_builtin_executor(): + code_executor = BuiltInCodeExecutor() + agent = Agent(name='test_agent', code_executor=code_executor) + 
invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='run some code' + ) + llm_request = LlmRequest(model='gemini-2.0-flash') + + _ = [ + event + async for event in request_processor.run_async( + invocation_context, llm_request + ) + ] + + system_instruction = str(llm_request.config.system_instruction or '') + assert _NON_BUILTIN_EXECUTOR_INSTRUCTION not in system_instruction + + +# --------------------------------------------------------------------------- +# Post-processor: API rejection recovery path +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@patch('google.adk.flows.llm_flows._code_execution.logger') +async def test_post_processor_recovers_from_unexpected_tool_call(mock_logger): + mock_executor = MagicMock(spec=BaseCodeExecutor) + mock_executor.code_block_delimiters = [('```tool_code\n', '\n```')] + mock_executor.error_retry_attempts = 2 + mock_executor.stateful = False + mock_executor.execute_code.return_value = CodeExecutionResult(stdout='42') + + agent = Agent(name='test_agent', code_executor=mock_executor) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='run some code' + ) + invocation_context.artifact_service = MagicMock() + invocation_context.artifact_service.save_artifact = AsyncMock( + return_value='v1' + ) + + llm_response = LlmResponse( + content=None, + error_code='UNEXPECTED_TOOL_CALL', + error_message='Unexpected tool call: print(6*7)', + ) + + events = [ + event + async for event in response_processor.run_async( + invocation_context, llm_response + ) + ] + + mock_executor.execute_code.assert_called_once() + call_input = mock_executor.execute_code.call_args[0][1] + assert call_input.code == 'print(6*7)' + assert len(events) == 2 + mock_logger.info.assert_called_once() + + +@pytest.mark.asyncio +async def test_post_processor_skips_recovery_for_builtin_executor(): + code_executor = 
BuiltInCodeExecutor() + agent = Agent(name='test_agent', code_executor=code_executor) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='run some code' + ) + invocation_context.artifact_service = MagicMock() + invocation_context.artifact_service.save_artifact = AsyncMock() + + llm_response = LlmResponse( + content=None, + error_code='UNEXPECTED_TOOL_CALL', + error_message='Unexpected tool call: print(1)', + ) + + events = [ + event + async for event in response_processor.run_async( + invocation_context, llm_response + ) + ] + + # BuiltInCodeExecutor path bails out early — no events, no artifact saves. + assert events == [] + invocation_context.artifact_service.save_artifact.assert_not_called()