Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/google/adk/agents/invocation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,14 @@ class InvocationContext(BaseModel):
agent_states: dict[str, dict[str, Any]] = Field(default_factory=dict)
"""The state of the agent for this invocation."""

request_state: dict[str, Any] = Field(default_factory=dict)
"""The ephemeral state of the request.

This state is not persisted to the session and is only available for the
current invocation. It is used to pass sensitive information like tokens
that should not be stored in the session state.
"""

end_of_agents: dict[str, bool] = Field(default_factory=dict)
"""The end of agent status for each agent in this invocation."""

Expand Down
16 changes: 14 additions & 2 deletions src/google/adk/agents/readonly_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

from collections import ChainMap
from types import MappingProxyType
from typing import Any
from typing import Optional
Expand Down Expand Up @@ -53,8 +54,19 @@ def agent_name(self) -> str:

@property
def state(self) -> MappingProxyType[str, Any]:
"""The state of the current session. READONLY field."""
return MappingProxyType(self._invocation_context.session.state)
"""The state of the current session. READONLY field.

Note: This property returns a merged view of ephemeral request_state and
persistent session.state using ChainMap. Changes to the underlying
request_state or session.state dictionaries will be reflected through
this view, but direct writes through this property are prevented.
"""
return MappingProxyType(
ChainMap(
self._invocation_context.request_state,
self._invocation_context.session.state,
)
)

@property
def session(self) -> Session:
Expand Down
3 changes: 3 additions & 0 deletions src/google/adk/cli/adk_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ class RunAgentRequest(common.BaseModel):
state_delta: Optional[dict[str, Any]] = None
# for long-running function resume requests (e.g., OAuth callback)
function_call_event_id: Optional[str] = None
request_state: Optional[dict[str, Any]] = None
# for resume long-running functions
invocation_id: Optional[str] = None

Expand Down Expand Up @@ -1899,6 +1900,7 @@ async def run_agent(req: RunAgentRequest) -> list[Event]:
new_message=req.new_message,
state_delta=req.state_delta,
invocation_id=req.invocation_id,
request_state=req.request_state,
)
) as agen:
events = [event async for event in agen]
Expand Down Expand Up @@ -1942,6 +1944,7 @@ async def event_generator():
state_delta=req.state_delta,
run_config=RunConfig(streaming_mode=stream_mode),
invocation_id=req.invocation_id,
request_state=req.request_state,
)
) as agen:
try:
Expand Down
39 changes: 33 additions & 6 deletions src/google/adk/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,7 @@ async def run_async(
invocation_id: Optional[str] = None,
new_message: Optional[types.Content] = None,
state_delta: Optional[dict[str, Any]] = None,
request_state: Optional[dict[str, Any]] = None,
run_config: Optional[RunConfig] = None,
) -> AsyncGenerator[Event, None]:
"""Main entry method to run the agent in this runner.
Expand All @@ -524,6 +525,7 @@ async def run_async(
interrupted invocation.
new_message: A new message to append to the session.
state_delta: Optional state changes to apply to the session.
request_state: Optional ephemeral state for the request.
run_config: The run config for the agent.

Yields:
Expand Down Expand Up @@ -559,18 +561,32 @@ async def _run_with_trace(
is_resumable = (
self.resumability_config and self.resumability_config.is_resumable
)
if not is_resumable and not new_message:
raise ValueError(
'Running an agent requires a new_message or a resumable app. '
f'Session: {session_id}, User: {user_id}'
if invocation_id:
if not is_resumable:
raise ValueError(
f'invocation_id: {invocation_id} is provided but the app is not'
' resumable.'
)
invocation_context = await self._setup_context_for_resumed_invocation(
session=session,
new_message=new_message,
invocation_id=invocation_id,
run_config=run_config,
state_delta=state_delta,
request_state=request_state,
)

if not is_resumable:
elif not is_resumable:
if not new_message:
raise ValueError(
'Running an agent requires a new_message or a resumable app. '
f'Session: {session_id}, User: {user_id}'
)
invocation_context = await self._setup_context_for_new_invocation(
session=session,
new_message=new_message,
run_config=run_config,
state_delta=state_delta,
request_state=request_state,
)
else:
invocation_id = self._resolve_invocation_id(
Expand All @@ -582,6 +598,7 @@ async def _run_with_trace(
new_message=new_message,
run_config=run_config,
state_delta=state_delta,
request_state=request_state,
)
else:
invocation_context = (
Expand All @@ -591,6 +608,7 @@ async def _run_with_trace(
invocation_id=invocation_id,
run_config=run_config,
state_delta=state_delta,
request_state=request_state,
)
)
if invocation_context.end_of_agents.get(
Expand Down Expand Up @@ -1334,6 +1352,7 @@ async def _setup_context_for_new_invocation(
new_message: types.Content,
run_config: RunConfig,
state_delta: Optional[dict[str, Any]],
request_state: Optional[dict[str, Any]] = None,
) -> InvocationContext:
"""Sets up the context for a new invocation.

Expand All @@ -1342,6 +1361,7 @@ async def _setup_context_for_new_invocation(
new_message: The new message to process and append to the session.
run_config: The run config of the agent.
state_delta: Optional state changes to apply to the session.
request_state: Optional ephemeral state for the request.

Returns:
The invocation context for the new invocation.
Expand All @@ -1351,6 +1371,7 @@ async def _setup_context_for_new_invocation(
session,
new_message=new_message,
run_config=run_config,
request_state=request_state,
)
# Step 2: Handle new message, by running callbacks and appending to
# session.
Expand All @@ -1373,6 +1394,7 @@ async def _setup_context_for_resumed_invocation(
invocation_id: Optional[str],
run_config: RunConfig,
state_delta: Optional[dict[str, Any]],
request_state: Optional[dict[str, Any]] = None,
) -> InvocationContext:
"""Sets up the context for a resumed invocation.

Expand All @@ -1382,6 +1404,7 @@ async def _setup_context_for_resumed_invocation(
invocation_id: The invocation id to resume.
run_config: The run config of the agent.
state_delta: Optional state changes to apply to the session.
request_state: Optional ephemeral state for the request.

Returns:
The invocation context for the resumed invocation.
Expand All @@ -1407,6 +1430,7 @@ async def _setup_context_for_resumed_invocation(
new_message=user_message,
run_config=run_config,
invocation_id=invocation_id,
request_state=request_state,
)
# Step 3: Maybe handle new message.
if new_message:
Expand Down Expand Up @@ -1455,6 +1479,7 @@ def _new_invocation_context(
new_message: Optional[types.Content] = None,
live_request_queue: Optional[LiveRequestQueue] = None,
run_config: Optional[RunConfig] = None,
request_state: Optional[dict[str, Any]] = None,
) -> InvocationContext:
"""Creates a new invocation context.

Expand All @@ -1464,6 +1489,7 @@ def _new_invocation_context(
new_message: The new message for the context.
live_request_queue: The live request queue for the context.
run_config: The run config for the context.
request_state: The ephemeral state for the request.

Returns:
The new invocation context.
Expand Down Expand Up @@ -1498,6 +1524,7 @@ def _new_invocation_context(
live_request_queue=live_request_queue,
run_config=run_config,
resumability_config=self.resumability_config,
request_state=request_state if request_state is not None else {},
)

def _new_invocation_context_for_live(
Expand Down
2 changes: 2 additions & 0 deletions src/google/adk/tools/mcp_tool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .mcp_session_manager import StreamableHTTPConnectionParams
from .mcp_tool import MCPTool
from .mcp_tool import McpTool
from .mcp_toolset import create_session_state_header_provider
from .mcp_toolset import MCPToolset
from .mcp_toolset import McpToolset

Expand All @@ -32,6 +33,7 @@
'MCPTool',
'McpToolset',
'MCPToolset',
'create_session_state_header_provider',
'SseConnectionParams',
'StdioConnectionParams',
'StreamableHTTPConnectionParams',
Expand Down
Loading