Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 48 additions & 30 deletions src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import override

import scale_gp_beta.lib.tracing as tracing
Expand Down Expand Up @@ -125,48 +127,64 @@ def _add_source_to_span(self, span: Span) -> None:

@override
async def on_span_start(self, span: Span) -> None:
self._add_source_to_span(span)
sgp_span = create_span(
name=span.name,
span_type=_get_span_type(span),
span_id=span.id,
parent_id=span.parent_id,
trace_id=span.trace_id,
input=span.input,
output=span.output,
metadata=span.data,
)
sgp_span.start_time = span.start_time.isoformat() # type: ignore[union-attr]
await self.on_spans_start([span])

@override
async def on_span_end(self, span: Span) -> None:
await self.on_spans_end([span])

@override
async def on_spans_start(self, spans: list[Span]) -> None:
if not spans:
return

sgp_spans: list[SGPSpan] = []
for span in spans:
self._add_source_to_span(span)
sgp_span = create_span(
name=span.name,
span_type=_get_span_type(span),
span_id=span.id,
parent_id=span.parent_id,
trace_id=span.trace_id,
input=span.input,
output=span.output,
metadata=span.data,
)
sgp_span.start_time = span.start_time.isoformat() # type: ignore[union-attr]
self._spans[span.id] = sgp_span
sgp_spans.append(sgp_span)

if self.disabled:
logger.warning("SGP is disabled, skipping span upsert")
return
# TODO(AGX1-198): Batch multiple spans into a single upsert_batch call
# instead of one span per HTTP request.
# https://linear.app/scale-epd/issue/AGX1-198/actually-use-sgp-batching-for-spans
await self.sgp_async_client.spans.upsert_batch( # type: ignore[union-attr]
items=[sgp_span.to_request_params()]
items=[s.to_request_params() for s in sgp_spans]
)

self._spans[span.id] = sgp_span

@override
async def on_span_end(self, span: Span) -> None:
sgp_span = self._spans.pop(span.id, None)
if sgp_span is None:
logger.warning(f"Span {span.id} not found in stored spans, skipping span end")
async def on_spans_end(self, spans: list[Span]) -> None:
if not spans:
return

self._add_source_to_span(span)
sgp_span.input = span.input # type: ignore[assignment]
sgp_span.output = span.output # type: ignore[assignment]
sgp_span.metadata = span.data # type: ignore[assignment]
sgp_span.end_time = span.end_time.isoformat() # type: ignore[union-attr]

if self.disabled:
to_upsert: list[SGPSpan] = []
for span in spans:
sgp_span = self._spans.pop(span.id, None)
if sgp_span is None:
logger.warning(f"Span {span.id} not found in stored spans, skipping span end")
continue

self._add_source_to_span(span)
sgp_span.input = span.input # type: ignore[assignment]
sgp_span.output = span.output # type: ignore[assignment]
sgp_span.metadata = span.data # type: ignore[assignment]
sgp_span.end_time = span.end_time.isoformat() # type: ignore[union-attr]
to_upsert.append(sgp_span)

if self.disabled or not to_upsert:
return
await self.sgp_async_client.spans.upsert_batch( # type: ignore[union-attr]
items=[sgp_span.to_request_params()]
items=[s.to_request_params() for s in to_upsert]
)

@override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from __future__ import annotations

import asyncio
from abc import ABC, abstractmethod

from agentex.types.span import Span
from agentex.lib.types.tracing import TracingProcessorConfig
from agentex.lib.utils.logging import make_logger

logger = make_logger(__name__)


class SyncTracingProcessor(ABC):
Expand Down Expand Up @@ -35,6 +41,43 @@ async def on_span_start(self, span: Span) -> None:
async def on_span_end(self, span: Span) -> None:
pass

async def on_spans_start(self, spans: list[Span]) -> None:
    """Batched variant of on_span_start.

    The default implementation fans the batch out to the single-span
    hook concurrently, so existing processors keep working unchanged.
    Processors with real batch support (e.g. one HTTP call for the
    whole batch) should override this to avoid the per-span round trip.

    Each span's exception is captured and logged individually, so one
    failing span never prevents the others from being processed.
    """
    outcomes = await asyncio.gather(
        *[self.on_span_start(span) for span in spans],
        return_exceptions=True,
    )
    for span, outcome in zip(spans, outcomes):
        if not isinstance(outcome, Exception):
            continue
        logger.error(
            "Tracing processor %s failed on_span_start for span %s",
            type(self).__name__,
            span.id,
            exc_info=outcome,
        )

async def on_spans_end(self, spans: list[Span]) -> None:
    """Batched variant of on_span_end. See on_spans_start for details."""
    outcomes = await asyncio.gather(
        *[self.on_span_end(span) for span in spans],
        return_exceptions=True,
    )
    for span, outcome in zip(spans, outcomes):
        if not isinstance(outcome, Exception):
            continue
        logger.error(
            "Tracing processor %s failed on_span_end for span %s",
            type(self).__name__,
            span.id,
            exc_info=outcome,
        )

@abstractmethod
async def shutdown(self) -> None:
pass
43 changes: 27 additions & 16 deletions src/agentex/lib/core/tracing/span_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,29 +95,40 @@ async def _drain_loop(self) -> None:

@staticmethod
async def _process_items(items: list[_SpanQueueItem]) -> None:
"""Process a list of span events concurrently."""
"""Dispatch a batch of same-event-type items to each processor in one call.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it guaranteed that all these items are the same event-type?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nevermind, see we dispatch them ourselves


async def _handle(item: _SpanQueueItem) -> None:
Groups spans by processor so each processor sees its full slice of the
drain batch at once. Processors that override the batched methods can
then send a single HTTP request per drain cycle instead of N.
"""
if not items:
return

event_type = items[0].event_type
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are all the event_types the same?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe they should be but added an assert to check

assert all(i.event_type == event_type for i in items), (
"_process_items requires all items to share the same event_type; "
"callers must split START and END batches before dispatching."
)
by_processor: dict[AsyncTracingProcessor, list[Span]] = {}
for item in items:
for p in item.processors:
by_processor.setdefault(p, []).append(item.span)

async def _handle(p: AsyncTracingProcessor, spans: list[Span]) -> None:
try:
if item.event_type == SpanEventType.START:
coros = [p.on_span_start(item.span) for p in item.processors]
if event_type == SpanEventType.START:
await p.on_spans_start(spans)
else:
coros = [p.on_span_end(item.span) for p in item.processors]
results = await asyncio.gather(*coros, return_exceptions=True)
for result in results:
if isinstance(result, Exception):
logger.error(
"Tracing processor error during %s for span %s",
item.event_type.value,
item.span.id,
exc_info=result,
)
await p.on_spans_end(spans)
except Exception:
logger.exception(
"Unexpected error in span queue for span %s", item.span.id
"Tracing processor %s failed handling %d spans during %s",
type(p).__name__,
len(spans),
event_type.value,
)

await asyncio.gather(*[_handle(item) for item in items])
await asyncio.gather(*[_handle(p, spans) for p, spans in by_processor.items()])

# ------------------------------------------------------------------
# Shutdown
Expand Down
39 changes: 39 additions & 0 deletions tests/lib/core/tracing/processors/test_sgp_tracing_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,42 @@ async def test_sgp_span_input_updated_on_end(self):
assert len(processor._spans) == 0
# The end upsert should have been called
assert processor.sgp_async_client.spans.upsert_batch.call_count == 2 # start + end

async def test_on_spans_start_sends_single_upsert_for_batch(self):
    """Given N spans at once, on_spans_start should make ONE upsert_batch HTTP call."""
    processor, _ = self._make_processor()

    batch_size = 10
    batch = [_make_span() for _ in range(batch_size)]
    with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()):
        await processor.on_spans_start(batch)

    upsert = processor.sgp_async_client.spans.upsert_batch
    assert upsert.call_count == 1, (
        "Batched on_spans_start must make exactly one upsert_batch HTTP call"
    )
    sent_items = upsert.call_args.kwargs["items"]
    assert len(sent_items) == batch_size
    # All spans should be tracked for the subsequent end call
    assert len(processor._spans) == batch_size

async def test_on_spans_end_sends_single_upsert_for_batch(self):
    """Given N spans at once, on_spans_end should make ONE upsert_batch HTTP call."""
    processor, _ = self._make_processor()

    batch_size = 10
    batch = [_make_span() for _ in range(batch_size)]
    with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()):
        await processor.on_spans_start(batch)

    # Only count the HTTP calls made by the end phase.
    processor.sgp_async_client.spans.upsert_batch.reset_mock()

    for span in batch:
        span.end_time = datetime.now(UTC)
    await processor.on_spans_end(batch)

    upsert = processor.sgp_async_client.spans.upsert_batch
    assert upsert.call_count == 1, (
        "Batched on_spans_end must make exactly one upsert_batch HTTP call"
    )
    sent_items = upsert.call_args.kwargs["items"]
    assert len(sent_items) == batch_size
    # Ending the batch must clear the processor's span bookkeeping.
    assert len(processor._spans) == 0
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from __future__ import annotations

import uuid
import logging
from typing import override
from datetime import UTC, datetime

from agentex.types.span import Span
from agentex.lib.types.tracing import TracingProcessorConfig
from agentex.lib.core.tracing.processors.tracing_processor_interface import (
AsyncTracingProcessor,
)


def _make_span(span_id: str | None = None) -> Span:
    """Build a minimal Span for tests, generating a random id when none is given."""
    if span_id is None:
        span_id = str(uuid.uuid4())
    return Span(
        id=span_id,
        name="test-span",
        start_time=datetime.now(UTC),
        trace_id="trace-1",
    )


class _RecordingProcessor(AsyncTracingProcessor):
    """Test processor that records every on_span_* call and fails on demand."""

    def __init__(self, fail_ids: set[str] | None = None) -> None:
        # Ids seen by each hook, in call order.
        self.started_ids: list[str] = []
        self.ended_ids: list[str] = []
        # Spans whose hooks should raise, to exercise error handling.
        self._fail_ids: set[str] = set() if fail_ids is None else fail_ids

    def _maybe_fail(self, span: Span, phase: str) -> None:
        # Raise only for spans the test explicitly marked as failing.
        if span.id in self._fail_ids:
            raise RuntimeError(f"boom-{phase}-{span.id}")

    @override
    async def on_span_start(self, span: Span) -> None:
        self.started_ids.append(span.id)
        self._maybe_fail(span, "start")

    @override
    async def on_span_end(self, span: Span) -> None:
        self.ended_ids.append(span.id)
        self._maybe_fail(span, "end")

    @override
    async def shutdown(self) -> None:
        pass


class TestDefaultBatchedFanout:
    """The default on_spans_start / on_spans_end in AsyncTracingProcessor must:
    - dispatch to the single-span method for every span
    - continue after individual failures (not short-circuit)
    - log each failure individually
    - not propagate exceptions to the caller
    """

    async def test_on_spans_start_runs_every_span_despite_failures(self, caplog):
        processor = _RecordingProcessor(fail_ids={"span-1"})
        batch = [_make_span(f"span-{i}") for i in range(3)]

        with caplog.at_level(logging.ERROR):
            # Must not raise, even though span-1 fails.
            await processor.on_spans_start(batch)

        # Every span's on_span_start was invoked
        assert processor.started_ids == ["span-0", "span-1", "span-2"]

    async def test_on_spans_start_logs_each_failure(self, caplog):
        processor = _RecordingProcessor(fail_ids={"span-0", "span-2"})
        batch = [_make_span(f"span-{i}") for i in range(3)]

        with caplog.at_level(logging.ERROR):
            await processor.on_spans_start(batch)

        # Two distinct error log records, one per failing span
        errors = [rec for rec in caplog.records if rec.levelno == logging.ERROR]
        combined = " ".join(rec.getMessage() for rec in errors)
        assert "span-0" in combined
        assert "span-2" in combined

    async def test_on_spans_end_runs_every_span_despite_failures(self, caplog):
        processor = _RecordingProcessor(fail_ids={"span-1"})
        batch = [_make_span(f"span-{i}") for i in range(3)]

        with caplog.at_level(logging.ERROR):
            await processor.on_spans_end(batch)

        assert processor.ended_ids == ["span-0", "span-1", "span-2"]

    async def test_dummy_config_construction(self):
        """AsyncTracingProcessor's __init__ is abstract — verify concrete
        subclass above satisfies the interface."""
        _ = TracingProcessorConfig
        processor = _RecordingProcessor()
        # Empty batches must be no-ops.
        await processor.on_spans_start([])
        await processor.on_spans_end([])
        assert processor.started_ids == []
        assert processor.ended_ids == []
Loading
Loading