From 0a01256b8b5d9d582d586dc1d61ee32ae878d2d5 Mon Sep 17 00:00:00 2001 From: alvinkam2001 Date: Tue, 21 Apr 2026 17:33:40 -0400 Subject: [PATCH 1/5] batch call for traces --- .../processors/sgp_tracing_processor.py | 76 +++++++++++-------- .../processors/tracing_processor_interface.py | 15 ++++ src/agentex/lib/core/tracing/span_queue.py | 39 ++++++---- .../processors/test_sgp_tracing_processor.py | 39 ++++++++++ tests/lib/core/tracing/test_span_queue.py | 71 +++++++++++++++++ 5 files changed, 194 insertions(+), 46 deletions(-) diff --git a/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py b/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py index 1376df06c..3d172c46c 100644 --- a/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py +++ b/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py @@ -125,48 +125,64 @@ def _add_source_to_span(self, span: Span) -> None: @override async def on_span_start(self, span: Span) -> None: - self._add_source_to_span(span) - sgp_span = create_span( - name=span.name, - span_type=_get_span_type(span), - span_id=span.id, - parent_id=span.parent_id, - trace_id=span.trace_id, - input=span.input, - output=span.output, - metadata=span.data, - ) - sgp_span.start_time = span.start_time.isoformat() # type: ignore[union-attr] + await self.on_spans_start([span]) + + @override + async def on_span_end(self, span: Span) -> None: + await self.on_spans_end([span]) + + @override + async def on_spans_start(self, spans: list[Span]) -> None: + if not spans: + return + + sgp_spans: list[SGPSpan] = [] + for span in spans: + self._add_source_to_span(span) + sgp_span = create_span( + name=span.name, + span_type=_get_span_type(span), + span_id=span.id, + parent_id=span.parent_id, + trace_id=span.trace_id, + input=span.input, + output=span.output, + metadata=span.data, + ) + sgp_span.start_time = span.start_time.isoformat() # type: ignore[union-attr] + self._spans[span.id] = sgp_span + sgp_spans.append(sgp_span) if self.disabled: logger.warning("SGP is disabled, skipping span upsert") return - # TODO(AGX1-198): Batch multiple spans into a single upsert_batch call - # instead of one span per HTTP request. - # https://linear.app/scale-epd/issue/AGX1-198/actually-use-sgp-batching-for-spans await self.sgp_async_client.spans.upsert_batch( # type: ignore[union-attr] - items=[sgp_span.to_request_params()] + items=[s.to_request_params() for s in sgp_spans] ) - self._spans[span.id] = sgp_span - @override - async def on_span_end(self, span: Span) -> None: - sgp_span = self._spans.pop(span.id, None) - if sgp_span is None: - logger.warning(f"Span {span.id} not found in stored spans, skipping span end") + async def on_spans_end(self, spans: list[Span]) -> None: + if not spans: return - self._add_source_to_span(span) - sgp_span.input = span.input # type: ignore[assignment] - sgp_span.output = span.output # type: ignore[assignment] - sgp_span.metadata = span.data # type: ignore[assignment] - sgp_span.end_time = span.end_time.isoformat() # type: ignore[union-attr] - - if self.disabled: + to_upsert: list[SGPSpan] = [] + for span in spans: + sgp_span = self._spans.pop(span.id, None) + if sgp_span is None: + logger.warning(f"Span {span.id} not found in stored spans, skipping span end") + continue + + self._add_source_to_span(span) + sgp_span.input = span.input # type: ignore[assignment] + sgp_span.output = span.output # type: ignore[assignment] + sgp_span.metadata = span.data # type: ignore[assignment] + sgp_span.end_time = span.end_time.isoformat() # type: ignore[union-attr] + to_upsert.append(sgp_span) + + if self.disabled or not to_upsert: return await self.sgp_async_client.spans.upsert_batch( # type: ignore[union-attr] - items=[sgp_span.to_request_params()] + items=[s.to_request_params() for s in to_upsert] ) @override diff --git a/src/agentex/lib/core/tracing/processors/tracing_processor_interface.py b/src/agentex/lib/core/tracing/processors/tracing_processor_interface.py index 4ab85dcf4..623e7bd51 100644 --- a/src/agentex/lib/core/tracing/processors/tracing_processor_interface.py +++ b/src/agentex/lib/core/tracing/processors/tracing_processor_interface.py @@ -1,3 +1,4 @@ +import asyncio from abc import ABC, abstractmethod from agentex.types.span import Span @@ -35,6 +36,20 @@ async def on_span_start(self, span: Span) -> None: async def on_span_end(self, span: Span) -> None: pass + async def on_spans_start(self, spans: list[Span]) -> None: + """Batched variant of on_span_start. + + Default fallback fans out to the single-span method in parallel so + existing processors keep working unchanged. Processors that support + real batching (e.g. sending all spans in one HTTP call) should + override this to avoid the per-span round trip. + """ + await asyncio.gather(*(self.on_span_start(s) for s in spans), return_exceptions=False) + + async def on_spans_end(self, spans: list[Span]) -> None: + """Batched variant of on_span_end. See on_spans_start for details.""" + await asyncio.gather(*(self.on_span_end(s) for s in spans), return_exceptions=False) + @abstractmethod async def shutdown(self) -> None: pass diff --git a/src/agentex/lib/core/tracing/span_queue.py b/src/agentex/lib/core/tracing/span_queue.py index d5d09dd0f..5dd3f2aec 100644 --- a/src/agentex/lib/core/tracing/span_queue.py +++ b/src/agentex/lib/core/tracing/span_queue.py @@ -95,29 +95,36 @@ async def _drain_loop(self) -> None: @staticmethod async def _process_items(items: list[_SpanQueueItem]) -> None: - """Process a list of span events concurrently.""" + """Dispatch a batch of same-event-type items to each processor in one call. - async def _handle(item: _SpanQueueItem) -> None: + Groups spans by processor so each processor sees its full slice of the + drain batch at once. Processors that override the batched methods can + then send a single HTTP request per drain cycle instead of N. + """ + if not items: + return + + event_type = items[0].event_type + by_processor: dict[AsyncTracingProcessor, list[Span]] = {} + for item in items: + for p in item.processors: + by_processor.setdefault(p, []).append(item.span) + + async def _handle(p: AsyncTracingProcessor, spans: list[Span]) -> None: try: - if item.event_type == SpanEventType.START: - coros = [p.on_span_start(item.span) for p in item.processors] + if event_type == SpanEventType.START: + await p.on_spans_start(spans) else: - coros = [p.on_span_end(item.span) for p in item.processors] - results = await asyncio.gather(*coros, return_exceptions=True) - for result in results: - if isinstance(result, Exception): - logger.error( - "Tracing processor error during %s for span %s", - item.event_type.value, - item.span.id, - exc_info=result, - ) + await p.on_spans_end(spans) except Exception: logger.exception( - "Unexpected error in span queue for span %s", item.span.id + "Tracing processor %s failed handling %d spans during %s", + type(p).__name__, + len(spans), + event_type.value, ) - await asyncio.gather(*[_handle(item) for item in items]) + await asyncio.gather(*[_handle(p, spans) for p, spans in by_processor.items()]) # ------------------------------------------------------------------ # Shutdown diff --git a/tests/lib/core/tracing/processors/test_sgp_tracing_processor.py b/tests/lib/core/tracing/processors/test_sgp_tracing_processor.py index 818fed375..50d615e0d 100644 --- a/tests/lib/core/tracing/processors/test_sgp_tracing_processor.py +++ b/tests/lib/core/tracing/processors/test_sgp_tracing_processor.py @@ -188,3 +188,42 @@ async def test_sgp_span_input_updated_on_end(self): assert len(processor._spans) == 0 # The end upsert should have been called assert processor.sgp_async_client.spans.upsert_batch.call_count == 2 # start + end + + async def test_on_spans_start_sends_single_upsert_for_batch(self): + """Given N spans at once, on_spans_start should make ONE upsert_batch HTTP call.""" + processor, _ = self._make_processor() + + n = 10 + spans = [_make_span() for _ in range(n)] + with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()): + await processor.on_spans_start(spans) + + assert processor.sgp_async_client.spans.upsert_batch.call_count == 1, ( + "Batched on_spans_start must make exactly one upsert_batch HTTP call" + ) + items = processor.sgp_async_client.spans.upsert_batch.call_args.kwargs["items"] + assert len(items) == n + # All spans should be tracked for the subsequent end call + assert len(processor._spans) == n + + async def test_on_spans_end_sends_single_upsert_for_batch(self): + """Given N spans at once, on_spans_end should make ONE upsert_batch HTTP call.""" + processor, _ = self._make_processor() + + n = 10 + spans = [_make_span() for _ in range(n)] + with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()): + await processor.on_spans_start(spans) + + processor.sgp_async_client.spans.upsert_batch.reset_mock() + + for span in spans: + span.end_time = datetime.now(UTC) + await processor.on_spans_end(spans) + + assert processor.sgp_async_client.spans.upsert_batch.call_count == 1, ( + "Batched on_spans_end must make exactly one upsert_batch HTTP call" + ) + items = processor.sgp_async_client.spans.upsert_batch.call_args.kwargs["items"] + assert len(items) == n + assert len(processor._spans) == 0 diff --git a/tests/lib/core/tracing/test_span_queue.py b/tests/lib/core/tracing/test_span_queue.py index 4524ba187..b60b93095 100644 --- a/tests/lib/core/tracing/test_span_queue.py +++ b/tests/lib/core/tracing/test_span_queue.py @@ -21,9 +21,25 @@ def _make_span(span_id: str | None = None) -> Span: def _make_processor(**overrides: AsyncMock) -> AsyncMock: + """Build a mock processor compatible with the queue's batched dispatch. + + The queue now calls on_spans_start(list) / on_spans_end(list) on each + processor. Mirror the behavior of AsyncTracingProcessor's default fallback + by fanning out the list to per-span calls concurrently, so tests that + assert on on_span_start / on_span_end continue to observe per-span calls. + """ proc = AsyncMock() proc.on_span_start = overrides.get("on_span_start", AsyncMock()) proc.on_span_end = overrides.get("on_span_end", AsyncMock()) + + async def _fanout_start(spans: list[Span]) -> None: + await asyncio.gather(*(proc.on_span_start(s) for s in spans), return_exceptions=True) + + async def _fanout_end(spans: list[Span]) -> None: + await asyncio.gather(*(proc.on_span_end(s) for s in spans), return_exceptions=True) + + proc.on_spans_start = AsyncMock(side_effect=_fanout_start) + proc.on_spans_end = AsyncMock(side_effect=_fanout_end) return proc @@ -218,6 +234,61 @@ async def slow_start(span: Span) -> None: ) +class TestAsyncSpanQueueBatchedDispatch: + """The queue should dispatch a whole drain batch to each processor via the + batched methods (on_spans_start / on_spans_end) in one call per processor, + so processors that support real HTTP batching can send one request instead + of N. + """ + + async def test_batched_start_dispatch_single_call_per_drain(self): + received: list[list[str]] = [] + + async def capture_starts(spans: list[Span]) -> None: + received.append([s.id for s in spans]) + + proc = AsyncMock() + proc.on_spans_start = AsyncMock(side_effect=capture_starts) + proc.on_spans_end = AsyncMock() + + queue = AsyncSpanQueue() + + # Enqueue several spans synchronously before the drain has a chance to + # run — they should all land in a single drain batch. + ids = [f"span-{i}" for i in range(5)] + for i in ids: + queue.enqueue(SpanEventType.START, _make_span(i), [proc]) + + await queue.shutdown() + + # on_spans_start must have been called exactly once with all 5 spans. + assert proc.on_spans_start.call_count == 1, ( + f"Expected one batched call, got {proc.on_spans_start.call_count}" + ) + assert received == [ids] + + async def test_batched_end_dispatch_single_call_per_drain(self): + received: list[list[str]] = [] + + async def capture_ends(spans: list[Span]) -> None: + received.append([s.id for s in spans]) + + proc = AsyncMock() + proc.on_spans_start = AsyncMock() + proc.on_spans_end = AsyncMock(side_effect=capture_ends) + + queue = AsyncSpanQueue() + + ids = [f"span-{i}" for i in range(5)] + for i in ids: + queue.enqueue(SpanEventType.END, _make_span(i), [proc]) + + await queue.shutdown() + + assert proc.on_spans_end.call_count == 1 + assert received == [ids] + + class TestAsyncSpanQueueIntegration: async def test_integration_with_async_trace(self): call_log: list[tuple[str, str]] = [] From aea049bfadd6354128abd9165e33d11d18e84424 Mon Sep 17 00:00:00 2001 From: alvinkam2001 Date: Tue, 21 Apr 2026 17:37:18 -0400 Subject: [PATCH 2/5] fix linting issues --- .../lib/core/tracing/processors/sgp_tracing_processor.py | 2 ++ .../lib/core/tracing/processors/tracing_processor_interface.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py b/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py index 3d172c46c..187dedcbc 100644 --- a/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py +++ b/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import override import scale_gp_beta.lib.tracing as tracing diff --git a/src/agentex/lib/core/tracing/processors/tracing_processor_interface.py b/src/agentex/lib/core/tracing/processors/tracing_processor_interface.py index 623e7bd51..526464ecc 100644 --- a/src/agentex/lib/core/tracing/processors/tracing_processor_interface.py +++ b/src/agentex/lib/core/tracing/processors/tracing_processor_interface.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio from abc import ABC, abstractmethod From dbce95c2bcb7c604afb473866b6ed9d4ef3f3cc8 Mon Sep 17 00:00:00 2001 From: alvinkam2001 Date: Wed, 22 Apr 2026 11:21:52 -0400 Subject: [PATCH 3/5] assert on event types --- src/agentex/lib/core/tracing/span_queue.py | 4 ++++ tests/lib/core/tracing/test_span_queue.py | 26 ++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/src/agentex/lib/core/tracing/span_queue.py b/src/agentex/lib/core/tracing/span_queue.py index 5dd3f2aec..d0d92669e 100644 --- a/src/agentex/lib/core/tracing/span_queue.py +++ b/src/agentex/lib/core/tracing/span_queue.py @@ -105,6 +105,10 @@ async def _process_items(items: list[_SpanQueueItem]) -> None: return event_type = items[0].event_type + assert all(i.event_type == event_type for i in items), ( + "_process_items requires all items to share the same event_type; " + "callers must split START and END batches before dispatching." + ) by_processor: dict[AsyncTracingProcessor, list[Span]] = {} for item in items: for p in item.processors: diff --git a/tests/lib/core/tracing/test_span_queue.py b/tests/lib/core/tracing/test_span_queue.py index b60b93095..0d6b2fd0a 100644 --- a/tests/lib/core/tracing/test_span_queue.py +++ b/tests/lib/core/tracing/test_span_queue.py @@ -234,6 +234,32 @@ async def slow_start(span: Span) -> None: ) +class TestProcessItemsPreconditions: + """_process_items assumes every item in the list has the same event_type. + Violating that precondition silently causes END events to be treated as + STARTs (or vice versa), which is a silent data-corruption bug. Guard it + with an assertion.""" + + async def test_mixed_event_types_raise_assertion(self): + from agentex.lib.core.tracing.span_queue import _SpanQueueItem + + proc = AsyncMock() + proc.on_spans_start = AsyncMock() + proc.on_spans_end = AsyncMock() + + mixed = [ + _SpanQueueItem(event_type=SpanEventType.START, span=_make_span("a"), processors=[proc]), + _SpanQueueItem(event_type=SpanEventType.END, span=_make_span("b"), processors=[proc]), + ] + + try: + await AsyncSpanQueue._process_items(mixed) + except AssertionError: + return + else: + raise AssertionError("Expected AssertionError for mixed event types") + + class TestAsyncSpanQueueBatchedDispatch: """The queue should dispatch a whole drain batch to each processor via the batched methods (on_spans_start / on_spans_end) in one call per processor, From 3c3bc63a1d2e5038f64a25387e1847bb1f37669f Mon Sep 17 00:00:00 2001 From: alvinkam2001 Date: Wed, 22 Apr 2026 11:22:28 -0400 Subject: [PATCH 4/5] tracing processor logger improvements --- .../processors/tracing_processor_interface.py | 30 +++++- .../test_tracing_processor_interface.py | 98 +++++++++++++++++++ 2 files changed, 126 insertions(+), 2 deletions(-) create mode 100644 tests/lib/core/tracing/processors/test_tracing_processor_interface.py diff --git a/src/agentex/lib/core/tracing/processors/tracing_processor_interface.py b/src/agentex/lib/core/tracing/processors/tracing_processor_interface.py index 526464ecc..f352f38c4 100644 --- a/src/agentex/lib/core/tracing/processors/tracing_processor_interface.py +++ b/src/agentex/lib/core/tracing/processors/tracing_processor_interface.py @@ -5,6 +5,9 @@ from agentex.types.span import Span from agentex.lib.types.tracing import TracingProcessorConfig +from agentex.lib.utils.logging import make_logger + +logger = make_logger(__name__) class SyncTracingProcessor(ABC): @@ -45,12 +48,35 @@ async def on_spans_start(self, spans: list[Span]) -> None: existing processors keep working unchanged. Processors that support real batching (e.g. sending all spans in one HTTP call) should override this to avoid the per-span round trip. + + Per-span exceptions are captured and logged individually so that one + failing span does not prevent the others from being processed. """ - await asyncio.gather(*(self.on_span_start(s) for s in spans), return_exceptions=False) + results = await asyncio.gather( + *(self.on_span_start(s) for s in spans), return_exceptions=True + ) + for span, result in zip(spans, results): + if isinstance(result, Exception): + logger.error( + "Tracing processor %s failed on_span_start for span %s", + type(self).__name__, + span.id, + exc_info=result, + ) async def on_spans_end(self, spans: list[Span]) -> None: """Batched variant of on_span_end. See on_spans_start for details.""" - await asyncio.gather(*(self.on_span_end(s) for s in spans), return_exceptions=False) + results = await asyncio.gather( + *(self.on_span_end(s) for s in spans), return_exceptions=True + ) + for span, result in zip(spans, results): + if isinstance(result, Exception): + logger.error( + "Tracing processor %s failed on_span_end for span %s", + type(self).__name__, + span.id, + exc_info=result, + ) @abstractmethod async def shutdown(self) -> None: diff --git a/tests/lib/core/tracing/processors/test_tracing_processor_interface.py b/tests/lib/core/tracing/processors/test_tracing_processor_interface.py new file mode 100644 index 000000000..5229272bf --- /dev/null +++ b/tests/lib/core/tracing/processors/test_tracing_processor_interface.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import uuid +import logging +from datetime import UTC, datetime +from typing import override + +from agentex.types.span import Span +from agentex.lib.types.tracing import TracingProcessorConfig +from agentex.lib.core.tracing.processors.tracing_processor_interface import ( + AsyncTracingProcessor, +) + + +def _make_span(span_id: str | None = None) -> Span: + return Span( + id=span_id or str(uuid.uuid4()), + name="test-span", + start_time=datetime.now(UTC), + trace_id="trace-1", + ) + + +class _RecordingProcessor(AsyncTracingProcessor): + """Test processor that records every on_span_* call and fails on demand.""" + + def __init__(self, fail_ids: set[str] | None = None) -> None: + self.started_ids: list[str] = [] + self.ended_ids: list[str] = [] + self._fail_ids = fail_ids or set() + + @override + async def on_span_start(self, span: Span) -> None: + self.started_ids.append(span.id) + if span.id in self._fail_ids: + raise RuntimeError(f"boom-start-{span.id}") + + @override + async def on_span_end(self, span: Span) -> None: + self.ended_ids.append(span.id) + if span.id in self._fail_ids: + raise RuntimeError(f"boom-end-{span.id}") + + @override + async def shutdown(self) -> None: + pass + + +class TestDefaultBatchedFanout: + """The default on_spans_start / on_spans_end in AsyncTracingProcessor must: + - dispatch to the single-span method for every span + - continue after individual failures (not short-circuit) + - log each failure individually + - not propagate exceptions to the caller + """ + + async def test_on_spans_start_runs_every_span_despite_failures(self, caplog): + proc = _RecordingProcessor(fail_ids={"span-1"}) + spans = [_make_span(f"span-{i}") for i in range(3)] + + with caplog.at_level(logging.ERROR): + # Must not raise, even though span-1 fails. + await proc.on_spans_start(spans) + + # Every span's on_span_start was invoked + assert proc.started_ids == ["span-0", "span-1", "span-2"] + + async def test_on_spans_start_logs_each_failure(self, caplog): + proc = _RecordingProcessor(fail_ids={"span-0", "span-2"}) + spans = [_make_span(f"span-{i}") for i in range(3)] + + with caplog.at_level(logging.ERROR): + await proc.on_spans_start(spans) + + # Two distinct error log records, one per failing span + error_records = [r for r in caplog.records if r.levelno == logging.ERROR] + messages = " ".join(r.getMessage() for r in error_records) + assert "span-0" in messages + assert "span-2" in messages + + async def test_on_spans_end_runs_every_span_despite_failures(self, caplog): + proc = _RecordingProcessor(fail_ids={"span-1"}) + spans = [_make_span(f"span-{i}") for i in range(3)] + + with caplog.at_level(logging.ERROR): + await proc.on_spans_end(spans) + + assert proc.ended_ids == ["span-0", "span-1", "span-2"] + + async def test_dummy_config_construction(self): + """AsyncTracingProcessor's __init__ is abstract — verify concrete + subclass above satisfies the interface.""" + _ = TracingProcessorConfig + proc = _RecordingProcessor() + await proc.on_spans_start([]) + await proc.on_spans_end([]) + assert proc.started_ids == [] + assert proc.ended_ids == [] From 21b790f9d0df4a77d760da15a307b3a8d2568f38 Mon Sep 17 00:00:00 2001 From: alvinkam2001 Date: Wed, 22 Apr 2026 17:41:42 -0400 Subject: [PATCH 5/5] lint fixes --- .../core/tracing/processors/test_tracing_processor_interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/lib/core/tracing/processors/test_tracing_processor_interface.py b/tests/lib/core/tracing/processors/test_tracing_processor_interface.py index 5229272bf..12847b70d 100644 --- a/tests/lib/core/tracing/processors/test_tracing_processor_interface.py +++ b/tests/lib/core/tracing/processors/test_tracing_processor_interface.py @@ -2,8 +2,8 @@ import uuid import logging -from datetime import UTC, datetime from typing import override +from datetime import UTC, datetime from agentex.types.span import Span from agentex.lib.types.tracing import TracingProcessorConfig