Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 113 additions & 7 deletions src/google/adk/flows/llm_flows/_code_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,91 @@ class DataFileUtil:
),
}

# Prompt text passed to `llm_request.append_instructions` by the code-execution
# pre-processor. It tells the model to emit Python inside a `tool_code`
# markdown fence instead of issuing a native code-execution call.
# NOTE: everything between the triple quotes is sent to the model verbatim;
# the "# CRITICAL" line is a markdown heading inside the prompt, not a Python
# comment. The trailing backslash after the opening quotes suppresses the
# leading newline of the string.
_NON_BUILTIN_EXECUTOR_INSTRUCTION = """\
# CRITICAL: Code execution format

You have access to an external Python sandbox managed by the host
application. To run Python code, output it inside a fenced markdown
block exactly like this:

```tool_code
print("hello")
```

DO NOT emit native executable_code parts.
DO NOT attempt to call a code_execution tool — no such tool is
registered for this request and the API will reject the response with
UNEXPECTED_TOOL_CALL or MALFORMED_FUNCTION_CALL.

Always wrap Python code in the tool_code markdown fence shown
above.
"""

# Recoverable API rejection codes for Gemini 2.x emitting code as a
# native tool call when no code_execution tool was declared.
_RECOVERABLE_API_ERRORS = frozenset(
{'UNEXPECTED_TOOL_CALL', 'MALFORMED_FUNCTION_CALL'}
)

_UNEXPECTED_TOOL_CALL_RE = re.compile(
r'^\s*Unexpected tool call:\s*(?P<code>.+?)\s*$', re.DOTALL
)


def _extract_code_from_error_message(
error_message: Optional[str],
) -> Optional[str]:
"""Best-effort extraction of code from a Gemini API rejection error message."""
if not error_message:
return None
m = _UNEXPECTED_TOOL_CALL_RE.match(error_message)
if m:
return m.group('code').strip() or None
return None


def _maybe_recover_from_api_rejection(llm_response: LlmResponse) -> bool:
  """Rebuilds the executable_code part lost to a Gemini 2.x API rejection.

  With a non-built-in code executor (e.g., AgentEngineSandboxCodeExecutor)
  Gemini 2.x may still emit a native code_execution tool call. No such tool
  is declared on the request, so the server rejects the response with
  UNEXPECTED_TOOL_CALL (or MALFORMED_FUNCTION_CALL when other tools are
  present) and llm_response.content ends up empty.

  When the rejection is one of the recoverable codes and code can be pulled
  out of the error message, this function replaces the response content with
  a single executable_code part and clears the error fields, so the existing
  post-processor pipeline can run the code through the configured executor.

  Args:
    llm_response: The rejected response; mutated in place on success.

  Returns:
    True if recovery occurred and llm_response was mutated, False otherwise
    (in which case llm_response is left untouched).
  """
  rejection = llm_response.error_code
  if rejection is None:
    return False

  # Use the `name` attribute when the code object has one (e.g. an enum
  # member); otherwise fall back to its string form.
  rejection_name = getattr(rejection, 'name', str(rejection))
  if rejection_name not in _RECOVERABLE_API_ERRORS:
    return False

  recovered_code = _extract_code_from_error_message(llm_response.error_message)
  if not recovered_code:
    return False

  # Swap the empty content for the code the model meant to run.
  code_part = CodeExecutionUtils.build_executable_code_part(recovered_code)
  llm_response.content = types.Content(role='model', parts=[code_part])

  # Clear the rejection markers so downstream processing treats this as a
  # normal model response.
  llm_response.error_code = None
  llm_response.error_message = None
  llm_response.finish_reason = None

  logger.info(
      'Recovered code from API %s rejection; routing to configured code'
      ' executor.',
      rejection_name,
  )
  return True


_DATA_FILE_HELPER_LIB = '''
import pandas as pd

Expand Down Expand Up @@ -114,7 +199,7 @@ def explore_df(df: pd.DataFrame) -> None:
'''


class _CodeExecutionRequestProcessor(BaseLlmRequestProcessor):
class _CodeExecutionRequestProcessor(BaseLlmRequestProcessor): # type: ignore[misc]
"""Processes code execution requests."""

@override
Expand Down Expand Up @@ -148,7 +233,7 @@ async def run_async(
request_processor = _CodeExecutionRequestProcessor()


class _CodeExecutionResponseProcessor(BaseLlmResponseProcessor):
class _CodeExecutionResponseProcessor(BaseLlmResponseProcessor): # type: ignore[misc]
"""Processes code execution responses."""

@override
Expand Down Expand Up @@ -187,6 +272,13 @@ async def _run_pre_processor(
code_executor.process_llm_request(llm_request)
return

# Steer Gemini 2.x (and other modern models) away from emitting a
# native `executable_code` / code_execution tool call. When the
# configured executor is *not* the built-in one, no `code_execution`
# tool is declared on the request, and a native emission would be
# rejected by the API as UNEXPECTED_TOOL_CALL / MALFORMED_FUNCTION_CALL.
llm_request.append_instructions([_NON_BUILTIN_EXECUTOR_INSTRUCTION])

if not code_executor.optimize_data_file:
return

Expand Down Expand Up @@ -265,15 +357,29 @@ async def _run_pre_processor(

async def _run_post_processor(
invocation_context: InvocationContext,
llm_response,
llm_response: LlmResponse,
) -> AsyncGenerator[Event, None]:
"""Post-process the model response by extracting and executing the first code block."""
agent = invocation_context.agent
code_executor = agent.code_executor

if not code_executor or not isinstance(code_executor, BaseCodeExecutor):
return
if not llm_response or not llm_response.content:
if not llm_response:
return

# When the API rejected the response because the model emitted a native
# code_execution tool call (UNEXPECTED_TOOL_CALL / MALFORMED_FUNCTION_CALL),
# llm_response.content is empty. For non-built-in executors, try to
# recover the intended code from the error message so we can still run
# it in the configured sandbox.
if not llm_response.content and not isinstance(
code_executor, BuiltInCodeExecutor
):
if not _maybe_recover_from_api_rejection(llm_response):
return

if not llm_response.content:
return

if isinstance(code_executor, BuiltInCodeExecutor):
Expand Down Expand Up @@ -376,7 +482,7 @@ def _extract_and_replace_inline_files(
llm_request: LlmRequest,
) -> list[File]:
"""Extracts and replaces inline files with file names in the LLM request."""
all_input_files = code_executor_context.get_input_files()
all_input_files: list[File] = code_executor_context.get_input_files()
saved_file_names = set(f.name for f in all_input_files)

# [Step 1] Process input files from LlmRequest and cache them in CodeExecutor.
Expand Down Expand Up @@ -425,7 +531,7 @@ def _get_or_set_execution_id(
if not invocation_context.agent.code_executor.stateful:
return None

execution_id = code_executor_context.get_execution_id()
execution_id: Optional[str] = code_executor_context.get_execution_id()
if not execution_id:
execution_id = invocation_context.session.id
code_executor_context.set_execution_id(execution_id)
Expand Down Expand Up @@ -517,7 +623,7 @@ def _get_normalized_file_name(file_name: str) -> str:
return var_name

if file.mime_type not in _DATA_FILE_UTIL_MAP:
return
return None

var_name = _get_normalized_file_name(file.name)
loader_code = _DATA_FILE_UTIL_MAP[file.mime_type].loader_code_template.format(
Expand Down
Loading