Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/itk.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@ jobs:
run: bash run_itk.sh
working-directory: itk
env:
A2A_SAMPLES_REVISION: itk-v.02-alpha
A2A_SAMPLES_REVISION: itk-v.021-alpha
2 changes: 1 addition & 1 deletion itk/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ You must set the `A2A_SAMPLES_REVISION` environment variable to specify which re

Example:
```
export A2A_SAMPLES_REVISION=itk-v.02-alpha
export A2A_SAMPLES_REVISION=itk-v.021-alpha
```

### 2. Execute Tests
Expand Down
249 changes: 197 additions & 52 deletions itk/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,20 @@
import base64
import logging
import os
import signal
import uuid

import grpc
import httpx
import uvicorn

from fastapi import FastAPI
from typing import Any

from pyproto import instruction_pb2

from a2a.client import ClientConfig, create_client
from a2a.client import Client, ClientConfig, create_client
from a2a.client.errors import A2AClientError
from a2a.compat.v0_3 import a2a_v0_3_pb2_grpc
from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler
from a2a.server.agent_execution import AgentExecutor, RequestContext
Expand All @@ -36,9 +39,11 @@
AgentCapabilities,
AgentCard,
AgentInterface,
CancelTaskRequest,
Message,
Part,
SendMessageRequest,
SubscribeToTaskRequest,
Task,
TaskState,
TaskStatus,
Expand Down Expand Up @@ -98,6 +103,95 @@ def extract_instruction(
return None


def _extract_text_from_event(event: Any) -> list[str]:
"""Extracts text parts from an event's message."""
if isinstance(event, tuple):
results = []
for item in event:
results.extend(_extract_text_from_event(item))
return results

message = None
if hasattr(event, 'HasField'):
if event.HasField('message'):
message = event.message
elif event.HasField('task') and event.task.status.HasField('message'):
message = event.task.status.message
elif event.HasField(
'status_update'
) and event.status_update.status.HasField('message'):
message = event.status_update.status.message

results = []
if message:
results.extend(part.text for part in message.parts if part.text)
return results


async def _handle_call_agent_with_resubscribe(
client: Client, request: SendMessageRequest
) -> list[str]:
"""Handles the send-disconnect-resubscribe flow."""
results = []
logger.info('Executing re-subscribe behavior')
agen = client.send_message(request)
task_id = None

async for event in agen:
logger.info('Event before disconnect: %s', event)
if event.HasField('task'):
task_id = event.task.id
elif event.HasField('status_update'):
task_id = event.status_update.task_id
break

await agen.aclose()
logger.info('Disconnected from task %s. Now re-subscribing.', task_id)

resub_agen = client.subscribe(SubscribeToTaskRequest(id=task_id))

task_obj = None
finished = False
async for event in resub_agen:
logger.info('Event after re-subscribe: %s', event)
if hasattr(event, 'HasField') and event.HasField('task'):
task_obj = event.task

extracted_text = _extract_text_from_event(event)
for text in extracted_text:
processed_text = text.replace('task-finished', '')
results.append(processed_text)
if any('task-finished' in text for text in extracted_text):
logger.info(
'Received task-finished after re-subscribe, breaking loop.'
)
finished = True
break

if not results and task_obj and hasattr(task_obj, 'history'):
logger.info('Results empty after loop, reading from history.')
for msg in task_obj.history:
# Check stringified role to support protobuf enums (2 for ROLE_AGENT in v0.3 and v1.0)
# as well as string descriptors from dict/JSON forms.
if str(msg.role) in {'2', 'ROLE_AGENT', 'agent'}:
results.extend(
part.text.replace('task-finished', '')
for part in msg.parts
if part.text
)

if not finished:
logger.info('Canceling task %s after retrieval.', task_id)
try:
await client.cancel_task(CancelTaskRequest(id=task_id))
logger.info('Task cancelled successfully: %s', task_id)
except A2AClientError:
logger.exception('Failed to cancel task %s', task_id)
raise

return results


def wrap_instruction_to_request(inst: instruction_pb2.Instruction) -> Message:
"""Wraps an Instruction proto into an A2A Message."""
inst_bytes = inst.SerializeToString()
Expand Down Expand Up @@ -129,18 +223,22 @@ async def handle_call_agent(
'GRPC': TransportProtocol.GRPC,
}

selected_transport = transport_map.get(call.transport.upper())
selected_transport = transport_map.get(
call.transport.upper(), TransportProtocol.JSONRPC
)
if selected_transport is None:
raise ValueError(f'Unsupported transport: {call.transport}')

config = ClientConfig()
config.httpx_client = httpx.AsyncClient(timeout=30.0)
config.grpc_channel_factory = grpc.aio.insecure_channel
config.supported_protocol_bindings = [selected_transport]
config.streaming = call.streaming or (
selected_transport == TransportProtocol.GRPC
)

if call.HasField('resubscribe') and not config.streaming:
raise ValueError('Re-subscription requires streaming to be enabled')

if call.HasField('push_notification'):
url = call.push_notification.url
if not url:
Expand All @@ -152,44 +250,45 @@ async def handle_call_agent(
token='itk-token', # noqa: S106
)

try:
client = await create_client(
call.agent_card_uri,
client_config=config,
)
async with httpx.AsyncClient(timeout=30.0) as httpx_client:
config.httpx_client = httpx_client
try:
client = await create_client(
call.agent_card_uri,
client_config=config,
)

# Wrap nested instruction
nested_msg = wrap_instruction_to_request(call.instruction)
request = SendMessageRequest(message=nested_msg)
# Wrap nested instruction
nested_msg = wrap_instruction_to_request(call.instruction)
request = SendMessageRequest(message=nested_msg)

results = []
async for event in client.send_message(request):
# Event is streaming response and task
logger.info('Event: %s', event)
stream_resp = event

message = None
if stream_resp.HasField('message'):
message = stream_resp.message
elif stream_resp.HasField(
'task'
) and stream_resp.task.status.HasField('message'):
message = stream_resp.task.status.message
elif stream_resp.HasField(
'status_update'
) and stream_resp.status_update.status.HasField('message'):
message = stream_resp.status_update.status.message

if message:
results.extend(part.text for part in message.parts if part.text)

except Exception as e:
logger.exception('Failed to call outbound agent')
raise RuntimeError(
f'Outbound call to {call.agent_card_uri} failed: {e!s}'
) from e
else:
return results
results = []

if call.HasField('resubscribe'):
results.extend(
await _handle_call_agent_with_resubscribe(client, request)
)
else:
async for event in client.send_message(request):
logger.info('Event: %s', event)
results.extend(_extract_text_from_event(event))

except Exception as e:
logger.exception('Failed to call outbound agent')
raise RuntimeError(
f'Outbound call to {call.agent_card_uri} failed: {e!s}'
) from e
else:
return results


def _should_hold(inst: instruction_pb2.Instruction) -> bool:
"""Recursively checks if any part of the instruction requests holding the task."""
if inst.HasField('return_response') and inst.return_response.hold_task:
return True
if inst.HasField('steps'):
return any(_should_hold(step) for step in inst.steps.instructions)
return False


async def handle_instruction(
Expand Down Expand Up @@ -245,23 +344,58 @@ async def execute(
)
return

should_hold_task = _should_hold(instruction)

try:
logger.info('Instruction: %s', instruction)
results = await handle_instruction(instruction)

response_text = '\n'.join(results)
logger.info('Response: %s', response_text)
await task_updater.update_status(
TaskState.TASK_STATE_COMPLETED,
message=task_updater.new_agent_message(
[Part(text=response_text)]
),
)
logger.info('Task %s completed', context.task_id)
except Exception as e:

if should_hold_task:
logger.info('Holding task %s as requested', context.task_id)
# Emitted event: response + task-finished
logger.info(
'Emitting response and task-finished for held task %s',
context.task_id,
)
await task_updater.update_status(
TaskState.TASK_STATE_WORKING,
message=task_updater.new_agent_message(
[Part(text=response_text + '\n' + 'task-finished')]
),
)
await asyncio.sleep(2)

# Continue emitting "task-finished" every 2 seconds
try:
while True:
logger.info(
'Emitting periodic status update for held task %s',
context.task_id,
)
await task_updater.update_status(
TaskState.TASK_STATE_WORKING,
message=None,
)
await asyncio.sleep(2)
except asyncio.CancelledError:
logger.info('Task %s cancelled', context.task_id)
return
else:
await task_updater.update_status(
TaskState.TASK_STATE_COMPLETED,
message=task_updater.new_agent_message(
[Part(text=response_text)]
),
)
logger.info('Task %s completed', context.task_id)
except Exception:
logger.exception('Error during instruction handling')
await task_updater.update_status(
TaskState.TASK_STATE_FAILED,
message=task_updater.new_agent_message([Part(text=str(e))]),
message=None,
)

async def cancel(
Expand Down Expand Up @@ -325,18 +459,17 @@ async def main_async(http_port: int, grpc_port: int) -> None:
name='ITK v10 Agent',
description='Python agent using SDK 1.0.',
version='1.0.0',
capabilities=AgentCapabilities(
streaming=True, push_notifications=True, extended_agent_card=True
),
capabilities=AgentCapabilities(streaming=True),
default_input_modes=['text/plain'],
default_output_modes=['text/plain'],
supported_interfaces=interfaces,
)

task_store = InMemoryTaskStore()
push_config_store = InMemoryPushNotificationConfigStore()
httpx_client = httpx.AsyncClient()
push_sender = BasePushNotificationSender(
httpx_client=httpx.AsyncClient(),
httpx_client=httpx_client,
config_store=push_config_store,
)
Comment thread
kdziedzic70 marked this conversation as resolved.

Expand Down Expand Up @@ -400,6 +533,18 @@ async def main_async(http_port: int, grpc_port: int) -> None:
)
uvicorn_server = uvicorn.Server(config)

# Signal handling
loop = asyncio.get_running_loop()

async def shutdown() -> None:
logger.info('Shutting down...')
uvicorn_server.should_exit = True
await server.stop(5)
await httpx_client.aclose()

for sig in (signal.SIGINT, signal.SIGTERM):
loop.add_signal_handler(sig, lambda: asyncio.create_task(shutdown()))

await uvicorn_server.serve()


Expand Down
18 changes: 18 additions & 0 deletions itk/run_itk.sh
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,24 @@ RESPONSE=$(curl -s -X POST http://127.0.0.1:8000/run \
"edges": ["0->1", "0->2", "1->0", "2->0"],
"protocols": ["http_json"],
"behavior": "push_notification"
},
{
"name": "Resubscribe Test - JSONRPC",
"sdks": ["current", "python_v10", "python_v03", "go_v10", "go_v03"],
"traversal": "euler",
"edges": ["0->1", "0->2", "0->3", "0->4", "1->0", "2->0", "3->0", "4->0"],
"protocols": ["jsonrpc"],
"streaming": true,
"behavior": "resubscribe"
},
{
"name": "Resubscribe Test - Python & Go Non-JSONRPC Protocols",
"sdks": ["current", "python_v10", "python_v03", "go_v10"],
"traversal": "euler",
"edges": ["0->1", "0->2", "0->3", "1->0", "2->0", "3->0"],
"protocols": ["grpc", "http_json"],
"streaming": true,
"behavior": "resubscribe"
}
]
}')
Expand Down
Loading