diff --git a/AUDIT.md b/AUDIT.md index 739c72b..71c6783 100644 --- a/AUDIT.md +++ b/AUDIT.md @@ -182,3 +182,26 @@ Automated contract tests validate sample request/response payloads against the O - **Negative tests:** missing required fields, extra fields (additionalProperties), invalid enum values - **Enum value tests:** UnitEnum, ErrorCode, DecisionEnum, ReservationStatus, CommitOveragePolicy - **Spec fixture:** `tests/fixtures/cycles-protocol-v0.yaml` (copy of canonical spec) + +--- + +## Streaming Convenience Module (added 2026-04-08) + +**Module:** `runcycles/streaming.py` +**Test file:** `tests/test_streaming.py` (64 tests, all passing) +**Version:** 0.3.0 + +Added `StreamReservation` and `AsyncStreamReservation` context managers that automate the reserve → commit/release lifecycle for streaming use cases. This is a DX convenience layer — no protocol changes. + +- **`StreamReservation`** — sync context manager: reserves on `__enter__`, auto-commits on successful `__exit__`, auto-releases on exception +- **`AsyncStreamReservation`** — async equivalent using `__aenter__`/`__aexit__` +- **`StreamUsage`** — mutable accumulator for token counts and cost during streaming +- **Client convenience methods:** `CyclesClient.stream_reservation()` and `AsyncCyclesClient.stream_reservation()` — thin factories that build Subject from config defaults +- **Cost resolution:** explicit `usage.actual_cost` > `cost_fn(usage)` > estimate fallback +- **Heartbeat:** automatic TTL extension, same interval formula as decorator lifecycle (`max(ttl_ms / 2, 1000)` ms) +- **Commit retry:** uses existing `CommitRetryEngine`/`AsyncCommitRetryEngine` +- **Context propagation:** sets/clears `CyclesContext` via `ContextVar`, accessible via `get_cycles_context()`; respects user-set `ctx.metrics` during streaming +- **Spec validation:** `validate_ttl_ms()` (1000–86400000), `validate_grace_period_ms()` (0–60000), `validate_subject()` (at least one standard field) — matches lifecycle.py +- **Error handling:** `RESERVATION_FINALIZED`, `RESERVATION_EXPIRED`, and `IDEMPOTENCY_MISMATCH` do not trigger release; other 4xx client errors do trigger release — matches lifecycle.py behavior exactly + +Protocol conformance: No new endpoints or protocol changes. All reservation, commit, release, and extend calls use the same client methods and body formats as the decorator path. Verified by 64 unit tests covering success, deny, error, retry, heartbeat, cost resolution, context propagation, spec validation, and all commit error-code branches. diff --git a/README.md b/README.md index f27b031..f4cd8b6 100644 --- a/README.md +++ b/README.md @@ -128,6 +128,47 @@ async def call_llm(prompt: str) -> str: result = await call_llm("Hello") ``` +### Streaming + +For streaming LLM responses, use the `stream_reservation()` context manager. It reserves budget on enter, auto-commits on successful exit, and auto-releases on exception: + +```python +from openai import OpenAI +from runcycles import CyclesClient, CyclesConfig, Action, Amount, Unit + +config = CyclesConfig(base_url="http://localhost:7878", api_key="your-api-key", tenant="acme") +cycles_client = CyclesClient(config) +openai_client = OpenAI() +max_tokens = 1024 + +with cycles_client.stream_reservation( + action=Action(kind="llm.completion", name="gpt-4o"), + estimate=Amount(unit=Unit.USD_MICROCENTS, amount=max_tokens * 1000), + cost_fn=lambda u: u.tokens_input * 250 + u.tokens_output * 1000, +) as reservation: + # Caps available immediately after entering the context + if reservation.caps and reservation.caps.max_tokens: + max_tokens = min(max_tokens, reservation.caps.max_tokens) + + stream = openai_client.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "Hello"}], + max_tokens=max_tokens, + stream=True, + stream_options={"include_usage": True}, + ) + + for chunk in stream: + if chunk.choices and chunk.choices[0].delta.content: + print(chunk.choices[0].delta.content, end="", flush=True) + if chunk.usage: + reservation.usage.tokens_input = chunk.usage.prompt_tokens + reservation.usage.tokens_output = chunk.usage.completion_tokens +# Committed automatically with actual cost computed by cost_fn +``` + +Also available as `async with client.stream_reservation(...)` for async clients. See [streaming_usage.py](examples/streaming_usage.py) for a complete example. + ## Configuration ### From environment variables @@ -374,7 +415,7 @@ The [`examples/`](examples/) directory contains runnable integration examples: | [async_usage.py](examples/async_usage.py) | Async client and async decorator | | [openai_integration.py](examples/openai_integration.py) | Guard OpenAI chat completions with budget checks | | [anthropic_integration.py](examples/anthropic_integration.py) | Guard Anthropic messages with per-tool budget tracking | -| [streaming_usage.py](examples/streaming_usage.py) | Budget-managed streaming with token accumulation | +| [streaming_usage.py](examples/streaming_usage.py) | `stream_reservation()` context manager with auto-commit | | [fastapi_integration.py](examples/fastapi_integration.py) | FastAPI middleware, dependency injection, per-tenant budgets | | [langchain_integration.py](examples/langchain_integration.py) | LangChain callback handler for budget-aware agents | diff --git a/examples/README.md b/examples/README.md index 35790d7..d066bb8 100644 --- a/examples/README.md +++ b/examples/README.md @@ -28,7 +28,7 @@ pip install runcycles | [async_usage.py](async_usage.py) | Async client and async decorator | — | | [openai_integration.py](openai_integration.py) | Guard OpenAI chat completions with budget checks | `openai` | | [anthropic_integration.py](anthropic_integration.py) | Guard Anthropic messages with per-tool budget tracking | `anthropic` | -| [streaming_usage.py](streaming_usage.py) | Budget-managed streaming with token accumulation | `openai` | +| [streaming_usage.py](streaming_usage.py) | `stream_reservation()` context manager with auto-commit | `openai` | | [fastapi_integration.py](fastapi_integration.py) | FastAPI middleware, dependency injection, per-tenant budgets | `fastapi`, `uvicorn` | | [langchain_integration.py](langchain_integration.py) | LangChain callback handler for budget-aware agents | `langchain`, `langchain-openai` | diff --git a/examples/streaming_usage.py b/examples/streaming_usage.py index ab2ff7e..d5032bf 100644 --- a/examples/streaming_usage.py +++ b/examples/streaming_usage.py @@ -1,7 +1,7 @@ """Budget-managed streaming with Cycles. -Demonstrates the programmatic reserve → stream → commit pattern where the -actual cost is only known after the stream completes. +Demonstrates the StreamReservation context manager: reserve on enter, +auto-commit on success, auto-release on exception. Requirements: pip install runcycles openai @@ -14,8 +14,6 @@ """ import os -import time -import uuid from openai import OpenAI @@ -23,14 +21,8 @@ Action, Amount, BudgetExceededError, - CommitRequest, CyclesClient, CyclesConfig, - CyclesMetrics, - CyclesProtocolError, - ReleaseRequest, - ReservationCreateRequest, - Subject, Unit, ) @@ -50,7 +42,7 @@ # --------------------------------------------------------------------------- -# 2. Streaming with budget management +# 2. Streaming with budget management (context manager API) # --------------------------------------------------------------------------- def stream_with_budget( prompt: str, @@ -59,65 +51,25 @@ def stream_with_budget( ) -> str: """Stream an OpenAI response with Cycles budget protection. - The pattern: - 1. Reserve budget based on max_tokens (worst case) - 2. Stream the response, accumulating output - 3. Commit the actual cost after the stream completes - 4. Release the reservation if streaming fails + The StreamReservation context manager handles: + - Creating a reservation on enter + - Auto-committing actual cost on successful exit + - Auto-releasing the reservation on exception + - Heartbeat-based TTL extension for long streams """ estimated_input_tokens = len(prompt.split()) * 2 - estimated_cost = ( - estimated_input_tokens * PRICE_PER_INPUT_TOKEN - + max_tokens * PRICE_PER_OUTPUT_TOKEN - ) - - idempotency_key = str(uuid.uuid4()) - - # Step 1: Reserve budget - reserve_response = cycles_client.create_reservation( - ReservationCreateRequest( - idempotency_key=idempotency_key, - subject=Subject(tenant=config.tenant, agent="streaming-agent"), - action=Action(kind="llm.completion", name=model), - estimate=Amount(unit=Unit.USD_MICROCENTS, amount=estimated_cost), - ttl_ms=120_000, # longer TTL for streaming - ) - ) - - if not reserve_response.is_success: - error = reserve_response.get_error_response() - if error and error.error == "BUDGET_EXCEEDED": - raise BudgetExceededError( - error.message, - status=reserve_response.status, - error_code=error.error, - request_id=error.request_id, - details=error.details, - ) - msg = error.message if error else (reserve_response.error_message or "Reservation failed") - raise CyclesProtocolError( - msg, - status=reserve_response.status, - error_code=error.error if error else None, - request_id=error.request_id if error else None, - details=error.details if error else None, - ) - - reservation_id = reserve_response.get_body_attribute("reservation_id") - decision = reserve_response.get_body_attribute("decision") - - # Check for caps - caps = reserve_response.get_body_attribute("caps") - if caps and caps.get("max_tokens"): - max_tokens = min(max_tokens, caps["max_tokens"]) - print(f" Budget authority capped max_tokens to {max_tokens}") + estimated_cost = estimated_input_tokens * PRICE_PER_INPUT_TOKEN + max_tokens * PRICE_PER_OUTPUT_TOKEN + + with cycles_client.stream_reservation( + action=Action(kind="llm.completion", name=model), + estimate=Amount(unit=Unit.USD_MICROCENTS, amount=estimated_cost), + cost_fn=lambda u: u.tokens_input * PRICE_PER_INPUT_TOKEN + u.tokens_output * PRICE_PER_OUTPUT_TOKEN, + ) as reservation: + # Caps are available immediately after entering the context + if reservation.caps and reservation.caps.max_tokens: + max_tokens = min(max_tokens, reservation.caps.max_tokens) + print(f" Budget authority capped max_tokens to {max_tokens}") - # Step 2: Stream the response - start_time = time.time() - chunks: list[str] = [] - completion_tokens = 0 - - try: stream = openai_client.chat.completions.create( model=model, messages=[{"role": "user", "content": prompt}], @@ -126,6 +78,7 @@ def stream_with_budget( stream_options={"include_usage": True}, ) + chunks: list[str] = [] for chunk in stream: if chunk.choices and chunk.choices[0].delta.content: text = chunk.choices[0].delta.content @@ -134,48 +87,12 @@ def stream_with_budget( # The final chunk includes usage stats if chunk.usage: - input_tokens = chunk.usage.prompt_tokens - completion_tokens = chunk.usage.completion_tokens + reservation.usage.tokens_input = chunk.usage.prompt_tokens + reservation.usage.tokens_output = chunk.usage.completion_tokens print() # newline after streaming - except Exception: - # If streaming fails, release the reservation to free budget - cycles_client.release_reservation( - reservation_id, - ReleaseRequest(idempotency_key=f"release-{idempotency_key}"), - ) - raise - - # Step 3: Commit actual cost - elapsed_ms = int((time.time() - start_time) * 1000) - actual_cost = ( - input_tokens * PRICE_PER_INPUT_TOKEN - + completion_tokens * PRICE_PER_OUTPUT_TOKEN - ) - - commit_response = cycles_client.commit_reservation( - reservation_id, - CommitRequest( - idempotency_key=f"commit-{idempotency_key}", - actual=Amount(unit=Unit.USD_MICROCENTS, amount=actual_cost), - metrics=CyclesMetrics( - tokens_input=input_tokens, - tokens_output=completion_tokens, - latency_ms=elapsed_ms, - model_version=model, - custom={"streamed": True, "decision": decision}, - ), - ), - ) - - if not commit_response.is_success: - print(f" Warning: commit failed: {commit_response.error_message}") - - savings = estimated_cost - actual_cost - print(f" Estimated: {estimated_cost} microcents, Actual: {actual_cost} microcents") - print(f" Budget saved by accurate commit: {savings} microcents") - + # Auto-committed on exit with actual cost computed by cost_fn return "".join(chunks) diff --git a/pyproject.toml b/pyproject.toml index 6ca9616..c51a36f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "runcycles" -version = "0.2.0" +version = "0.3.0" description = "Python client for the Cycles budget-management protocol" readme = "README.md" license = "Apache-2.0" diff --git a/runcycles/__init__.py b/runcycles/__init__.py index 10c48d5..4d237ef 100644 --- a/runcycles/__init__.py +++ b/runcycles/__init__.py @@ -51,6 +51,7 @@ Unit, ) from runcycles.response import CyclesResponse +from runcycles.streaming import AsyncStreamReservation, StreamReservation, StreamUsage __all__ = [ # Client @@ -67,6 +68,10 @@ "get_cycles_context", # Response "CyclesResponse", + # Streaming + "StreamReservation", + "AsyncStreamReservation", + "StreamUsage", # Exceptions "CyclesError", "CyclesProtocolError", diff --git a/runcycles/client.py b/runcycles/client.py index 3cadea0..6c95771 100644 --- a/runcycles/client.py +++ b/runcycles/client.py @@ -3,11 +3,16 @@ from __future__ import annotations import logging -from typing import Any +from collections.abc import Callable +from typing import TYPE_CHECKING, Any import httpx from pydantic import BaseModel +if TYPE_CHECKING: + from runcycles.models import Action, Amount, Subject + from runcycles.streaming import AsyncStreamReservation, StreamReservation, StreamUsage + from runcycles._constants import ( API_KEY_HEADER, BALANCES_PATH, @@ -55,8 +60,7 @@ def _validate_balance_filters(params: dict[str, str]) -> None: """Validate that at least one subject filter is provided for balance queries.""" if not any(k in _BALANCE_FILTER_PARAMS for k in params): raise ValueError( - "get_balances requires at least one subject filter" - " (tenant, workspace, app, workflow, agent, or toolset)" + "get_balances requires at least one subject filter (tenant, workspace, app, workflow, agent, or toolset)" ) @@ -106,6 +110,43 @@ def get_balances(self, **query_params: str) -> CyclesResponse: def create_event(self, request: BaseModel | dict[str, Any]) -> CyclesResponse: return self._post(EVENTS_PATH, request) + def stream_reservation( + self, + *, + subject: Subject | None = None, + action: Action, + estimate: Amount, + ttl_ms: int = 120_000, + grace_period_ms: int | None = None, + overage_policy: str = "ALLOW_IF_AVAILABLE", + cost_fn: Callable[[StreamUsage], int] | None = None, + metadata: dict[str, Any] | None = None, + ) -> StreamReservation: + """Return a context manager that reserves budget on enter and commits/releases on exit.""" + from runcycles.models import Subject as _Subject + from runcycles.streaming import StreamReservation as _SR + + if subject is None: + subject = _Subject( + tenant=self._config.tenant, + workspace=self._config.workspace, + app=self._config.app, + workflow=self._config.workflow, + agent=self._config.agent, + toolset=self._config.toolset, + ) + return _SR( + self, + subject=subject, + action=action, + estimate=estimate, + ttl_ms=ttl_ms, + grace_period_ms=grace_period_ms, + overage_policy=overage_policy, + cost_fn=cost_fn, + metadata=metadata, + ) + def close(self) -> None: self._http.close() @@ -152,7 +193,10 @@ def _handle_response(resp: httpx.Response) -> CyclesResponse: if body and isinstance(body, dict): error_msg = body.get("message") or body.get("error") return CyclesResponse.http_error( - resp.status_code, error_msg or resp.reason_phrase or "Unknown error", body, headers=headers, + resp.status_code, + error_msg or resp.reason_phrase or "Unknown error", + body, + headers=headers, ) @@ -202,6 +246,43 @@ async def get_balances(self, **query_params: str) -> CyclesResponse: async def create_event(self, request: BaseModel | dict[str, Any]) -> CyclesResponse: return await self._post(EVENTS_PATH, request) + def stream_reservation( + self, + *, + subject: Subject | None = None, + action: Action, + estimate: Amount, + ttl_ms: int = 120_000, + grace_period_ms: int | None = None, + overage_policy: str = "ALLOW_IF_AVAILABLE", + cost_fn: Callable[[StreamUsage], int] | None = None, + metadata: dict[str, Any] | None = None, + ) -> AsyncStreamReservation: + """Return an async context manager that reserves budget on enter and commits/releases on exit.""" + from runcycles.models import Subject as _Subject + from runcycles.streaming import AsyncStreamReservation as _ASR + + if subject is None: + subject = _Subject( + tenant=self._config.tenant, + workspace=self._config.workspace, + app=self._config.app, + workflow=self._config.workflow, + agent=self._config.agent, + toolset=self._config.toolset, + ) + return _ASR( + self, + subject=subject, + action=action, + estimate=estimate, + ttl_ms=ttl_ms, + grace_period_ms=grace_period_ms, + overage_policy=overage_policy, + cost_fn=cost_fn, + metadata=metadata, + ) + async def aclose(self) -> None: await self._http.aclose() @@ -248,5 +329,8 @@ def _handle_response(resp: httpx.Response) -> CyclesResponse: if body and isinstance(body, dict): error_msg = body.get("message") or body.get("error") return CyclesResponse.http_error( - resp.status_code, error_msg or resp.reason_phrase or "Unknown error", body, headers=headers, + resp.status_code, + error_msg or resp.reason_phrase or "Unknown error", + body, + headers=headers, ) diff --git a/runcycles/streaming.py b/runcycles/streaming.py new file mode 100644 index 0000000..38adb7b --- /dev/null +++ b/runcycles/streaming.py @@ -0,0 +1,564 @@ +"""Streaming convenience: reserve on enter, commit/release on exit.""" + +from __future__ import annotations + +import asyncio +import logging +import threading +import time +import uuid +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any + +from runcycles._validation import validate_grace_period_ms, validate_subject, validate_ttl_ms +from runcycles.client import AsyncCyclesClient, CyclesClient +from runcycles.context import CyclesContext, _clear_context, _set_context +from runcycles.exceptions import CyclesProtocolError +from runcycles.lifecycle import ( + _build_commit_body, + _build_extend_body, + _build_protocol_exception, + _build_release_body, +) +from runcycles.models import ( + Action, + Amount, + Caps, + CyclesMetrics, + Decision, + ReservationCreateResponse, + Subject, +) +from runcycles.retry import AsyncCommitRetryEngine, CommitRetryEngine + +logger = logging.getLogger(__name__) + + +@dataclass +class StreamUsage: + """Mutable accumulator for streaming usage. + + Update fields during streaming; the context manager reads them at commit time. + """ + + tokens_input: int = 0 + tokens_output: int = 0 + actual_cost: int | None = None + model_version: str | None = None + custom: dict[str, Any] = field(default_factory=dict) + + def add_input_tokens(self, count: int) -> None: + self.tokens_input += count + + def add_output_tokens(self, count: int) -> None: + self.tokens_output += count + + def set_actual_cost(self, amount: int) -> None: + self.actual_cost = amount + + +def _build_streaming_reservation_body( + subject: Subject, + action: Action, + estimate: Amount, + ttl_ms: int, + overage_policy: str, + grace_period_ms: int | None, +) -> dict[str, Any]: + validate_subject(subject) + validate_ttl_ms(ttl_ms) + validate_grace_period_ms(grace_period_ms) + + body: dict[str, Any] = { + "idempotency_key": str(uuid.uuid4()), + "subject": subject.model_dump(exclude_none=True), + "action": action.model_dump(exclude_none=True), + "estimate": estimate.model_dump(exclude_none=True), + "ttl_ms": ttl_ms, + "overage_policy": overage_policy, + } + if grace_period_ms is not None: + body["grace_period_ms"] = grace_period_ms + return body + + +def _resolve_actual_cost( + usage: StreamUsage, + cost_fn: Callable[[StreamUsage], int] | None, + estimate_amount: int, +) -> int: + """Resolve the actual cost: explicit > cost_fn > estimate fallback.""" + if usage.actual_cost is not None: + return usage.actual_cost + if cost_fn is not None: + try: + return cost_fn(usage) + except Exception: + logger.warning("cost_fn raised, falling back to estimate", exc_info=True) + return estimate_amount + return estimate_amount + + +def _build_stream_metrics( + usage: StreamUsage, + elapsed_ms: int, + ctx_metrics: CyclesMetrics | None, +) -> CyclesMetrics: + """Build commit metrics, merging user-set ctx.metrics with stream usage.""" + if ctx_metrics is not None: + # User set metrics on context during streaming — respect them, + # but fill in latency if not already set. + if ctx_metrics.latency_ms is None: + ctx_metrics.latency_ms = elapsed_ms + return ctx_metrics + + return CyclesMetrics( + tokens_input=usage.tokens_input if usage.tokens_input else None, + tokens_output=usage.tokens_output if usage.tokens_output else None, + latency_ms=elapsed_ms, + model_version=usage.model_version, + custom=usage.custom or None, + ) + + +class StreamReservation: + """Sync context manager: reserve on ``__enter__``, commit/release on ``__exit__``. + + Usage:: + + with client.stream_reservation( + action=Action(kind="llm.completion", name="gpt-4o"), + estimate=Amount(unit=Unit.USD_MICROCENTS, amount=1_000_000), + cost_fn=lambda u: u.tokens_input * 250 + u.tokens_output * 1000, + ) as reservation: + for chunk in stream: + reservation.usage.tokens_input = chunk.usage.prompt_tokens + reservation.usage.tokens_output = chunk.usage.completion_tokens + # Auto-committed on success, auto-released on exception. + """ + + def __init__( + self, + client: CyclesClient, + *, + subject: Subject, + action: Action, + estimate: Amount, + ttl_ms: int = 120_000, + grace_period_ms: int | None = None, + overage_policy: str = "ALLOW_IF_AVAILABLE", + cost_fn: Callable[[StreamUsage], int] | None = None, + metadata: dict[str, Any] | None = None, + ) -> None: + self._client = client + self._subject = subject + self._action = action + self._estimate = estimate + self._ttl_ms = ttl_ms + self._grace_period_ms = grace_period_ms + self._overage_policy = overage_policy + self._cost_fn = cost_fn + self._metadata = metadata + + self._usage = StreamUsage() + self._reservation_id: str | None = None + self._caps: Caps | None = None + self._decision: Decision = Decision.ALLOW + self._ctx: CyclesContext | None = None + self._start_time: float = 0.0 + + self._heartbeat_stop = threading.Event() + self._heartbeat_thread: threading.Thread | None = None + + self._retry_engine = CommitRetryEngine(client._config) + self._retry_engine.set_client(client) + + @property + def usage(self) -> StreamUsage: + return self._usage + + @property + def reservation_id(self) -> str: + if self._reservation_id is None: + raise RuntimeError("reservation_id not available outside context manager") + return self._reservation_id + + @property + def caps(self) -> Caps | None: + return self._caps + + @property + def decision(self) -> Decision: + return self._decision + + def __enter__(self) -> StreamReservation: + body = _build_streaming_reservation_body( + self._subject, + self._action, + self._estimate, + self._ttl_ms, + self._overage_policy, + self._grace_period_ms, + ) + + response = self._client.create_reservation(body) + + if not response.is_success: + raise _build_protocol_exception("Failed to create reservation", response) + + result = ReservationCreateResponse.model_validate(response.body) + + if result.decision == Decision.DENY: + raise _build_protocol_exception("Reservation denied", response) + + if result.reservation_id is None: + raise CyclesProtocolError( + "Reservation successful but reservation_id missing", + status=response.status, + ) + + self._reservation_id = result.reservation_id + self._decision = result.decision + self._caps = result.caps + + self._ctx = CyclesContext( + reservation_id=result.reservation_id, + estimate=self._estimate.amount, + decision=result.decision, + caps=result.caps, + expires_at_ms=result.expires_at_ms, + affected_scopes=result.affected_scopes, + scope_path=result.scope_path, + reserved=result.reserved, + balances=result.balances, + ) + _set_context(self._ctx) + + self._start_time = time.monotonic() + self._heartbeat_thread = self._start_heartbeat() + + logger.info( + "Stream reservation created: id=%s, decision=%s", + self._reservation_id, + self._decision, + ) + + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Any, + ) -> None: + self._heartbeat_stop.set() + if self._heartbeat_thread and self._heartbeat_thread.is_alive(): + self._heartbeat_thread.join(timeout=1.0) + + assert self._reservation_id is not None + + try: + if exc_type is not None: + self._handle_release("stream_failed") + else: + self._handle_commit() + finally: + _clear_context() + + def _handle_commit(self) -> None: + elapsed_ms = int((time.monotonic() - self._start_time) * 1000) + actual = _resolve_actual_cost(self._usage, self._cost_fn, self._estimate.amount) + ctx_metrics = self._ctx.metrics if self._ctx else None + metrics = _build_stream_metrics(self._usage, elapsed_ms, ctx_metrics) + unit = self._estimate.unit if isinstance(self._estimate.unit, str) else self._estimate.unit.value + commit_body = _build_commit_body(actual, unit, metrics, self._metadata) + + assert self._reservation_id is not None + try: + response = self._client.commit_reservation(self._reservation_id, commit_body) + if response.is_success: + logger.info("Stream commit successful: id=%s", self._reservation_id) + elif response.is_transport_error or response.is_server_error: + logger.warning("Stream commit failed (retryable): id=%s", self._reservation_id) + self._retry_engine.schedule(self._reservation_id, commit_body) + else: + error_code = None + error_resp = response.get_error_response() + if error_resp and error_resp.error_code: + error_code = error_resp.error_code.value + if error_code in ("RESERVATION_FINALIZED", "RESERVATION_EXPIRED"): + logger.warning("Reservation already finalized/expired: id=%s", self._reservation_id) + elif error_code == "IDEMPOTENCY_MISMATCH": + logger.warning("Commit idempotency mismatch (not releasing): id=%s", self._reservation_id) + elif response.is_client_error: + self._handle_release(f"commit_rejected_{error_code}") + else: + logger.warning("Unrecognized commit response: id=%s", self._reservation_id) + except Exception: + logger.exception("Failed to commit stream: id=%s", self._reservation_id) + self._retry_engine.schedule(self._reservation_id, commit_body) + + def _handle_release(self, reason: str) -> None: + assert self._reservation_id is not None + try: + body = _build_release_body(reason) + response = self._client.release_reservation(self._reservation_id, body) + if response.is_success: + logger.info("Stream released: id=%s", self._reservation_id) + else: + logger.warning("Stream release failed: id=%s, status=%d", self._reservation_id, response.status) + except Exception: + logger.exception("Failed to release stream: id=%s", self._reservation_id) + + def _start_heartbeat(self) -> threading.Thread | None: + if self._ttl_ms <= 0: + return None + interval_s = max(self._ttl_ms / 2, 1000) / 1000.0 + assert self._reservation_id is not None + reservation_id: str = self._reservation_id + ctx = self._ctx + + def heartbeat_loop() -> None: + while not self._heartbeat_stop.wait(timeout=interval_s): + try: + body = _build_extend_body(self._ttl_ms) + response = self._client.extend_reservation(reservation_id, body) + if response.is_success: + new_expires = response.get_body_attribute("expires_at_ms") + if new_expires is not None and ctx is not None: + ctx.update_expires_at_ms(int(new_expires)) + else: + logger.warning("Stream heartbeat failed: id=%s", reservation_id) + except Exception: + logger.warning("Stream heartbeat error: id=%s", reservation_id, exc_info=True) + + t = threading.Thread( + target=heartbeat_loop, + daemon=True, + name=f"cycles-stream-hb-{reservation_id[:12] if reservation_id else 'unknown'}", + ) + t.start() + return t + + +class AsyncStreamReservation: + """Async context manager: reserve on ``__aenter__``, commit/release on ``__aexit__``. + + Usage:: + + async with client.stream_reservation( + action=Action(kind="llm.completion", name="gpt-4o"), + estimate=Amount(unit=Unit.USD_MICROCENTS, amount=1_000_000), + cost_fn=lambda u: u.tokens_input * 250 + u.tokens_output * 1000, + ) as reservation: + async for chunk in stream: + reservation.usage.tokens_input = chunk.usage.prompt_tokens + reservation.usage.tokens_output = chunk.usage.completion_tokens + # Auto-committed on success, auto-released on exception. + """ + + def __init__( + self, + client: AsyncCyclesClient, + *, + subject: Subject, + action: Action, + estimate: Amount, + ttl_ms: int = 120_000, + grace_period_ms: int | None = None, + overage_policy: str = "ALLOW_IF_AVAILABLE", + cost_fn: Callable[[StreamUsage], int] | None = None, + metadata: dict[str, Any] | None = None, + ) -> None: + self._client = client + self._subject = subject + self._action = action + self._estimate = estimate + self._ttl_ms = ttl_ms + self._grace_period_ms = grace_period_ms + self._overage_policy = overage_policy + self._cost_fn = cost_fn + self._metadata = metadata + + self._usage = StreamUsage() + self._reservation_id: str | None = None + self._caps: Caps | None = None + self._decision: Decision = Decision.ALLOW + self._ctx: CyclesContext | None = None + self._start_time: float = 0.0 + + self._heartbeat_task: asyncio.Task[None] | None = None + + self._retry_engine = AsyncCommitRetryEngine(client._config) + self._retry_engine.set_client(client) + + @property + def usage(self) -> StreamUsage: + return self._usage + + @property + def reservation_id(self) -> str: + if self._reservation_id is None: + raise RuntimeError("reservation_id not available outside context manager") + return self._reservation_id + + @property + def caps(self) -> Caps | None: + return self._caps + + @property + def decision(self) -> Decision: + return self._decision + + async def __aenter__(self) -> AsyncStreamReservation: + body = _build_streaming_reservation_body( + self._subject, + self._action, + self._estimate, + self._ttl_ms, + self._overage_policy, + self._grace_period_ms, + ) + + response = await self._client.create_reservation(body) + + if not response.is_success: + raise _build_protocol_exception("Failed to create reservation", response) + + result = ReservationCreateResponse.model_validate(response.body) + + if result.decision == Decision.DENY: + raise _build_protocol_exception("Reservation denied", response) + + if result.reservation_id is None: + raise CyclesProtocolError( + "Reservation successful but reservation_id missing", + status=response.status, + ) + + self._reservation_id = result.reservation_id + self._decision = result.decision + self._caps = result.caps + + self._ctx = CyclesContext( + reservation_id=result.reservation_id, + estimate=self._estimate.amount, + decision=result.decision, + caps=result.caps, + expires_at_ms=result.expires_at_ms, + affected_scopes=result.affected_scopes, + scope_path=result.scope_path, + reserved=result.reserved, + balances=result.balances, + ) + _set_context(self._ctx) + + self._start_time = time.monotonic() + self._heartbeat_task = self._start_heartbeat() + + logger.info( + "Async stream reservation created: id=%s, decision=%s", + self._reservation_id, + self._decision, + ) + + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Any, + ) -> None: + if self._heartbeat_task: + self._heartbeat_task.cancel() + try: + await self._heartbeat_task + except asyncio.CancelledError: + pass + + assert self._reservation_id is not None + + try: + if exc_type is not None: + await self._handle_release("stream_failed") + else: + await self._handle_commit() + finally: + _clear_context() + + async def _handle_commit(self) -> None: + elapsed_ms = int((time.monotonic() - self._start_time) * 1000) + actual = _resolve_actual_cost(self._usage, self._cost_fn, self._estimate.amount) + ctx_metrics = self._ctx.metrics if self._ctx else None + metrics = _build_stream_metrics(self._usage, elapsed_ms, ctx_metrics) + unit = self._estimate.unit if isinstance(self._estimate.unit, str) else self._estimate.unit.value + commit_body = _build_commit_body(actual, unit, metrics, self._metadata) + + assert self._reservation_id is not None + try: + response = await self._client.commit_reservation(self._reservation_id, commit_body) + if response.is_success: + logger.info("Async stream commit successful: id=%s", self._reservation_id) + elif response.is_transport_error or response.is_server_error: + logger.warning("Async stream commit failed (retryable): id=%s", self._reservation_id) + self._retry_engine.schedule(self._reservation_id, commit_body) + else: + error_code = None + error_resp = response.get_error_response() + if error_resp and error_resp.error_code: + error_code = error_resp.error_code.value + if error_code in ("RESERVATION_FINALIZED", "RESERVATION_EXPIRED"): + logger.warning("Reservation already finalized/expired: id=%s", self._reservation_id) + elif error_code == "IDEMPOTENCY_MISMATCH": + logger.warning("Commit idempotency mismatch (not releasing): id=%s", self._reservation_id) + elif response.is_client_error: + await self._handle_release(f"commit_rejected_{error_code}") + else: + logger.warning("Unrecognized commit response: id=%s", self._reservation_id) + except Exception: + logger.exception("Failed to commit async stream: id=%s", self._reservation_id) + self._retry_engine.schedule(self._reservation_id, commit_body) + + async def _handle_release(self, reason: str) -> None: + assert self._reservation_id is not None + try: + body = _build_release_body(reason) + response = await self._client.release_reservation(self._reservation_id, body) + if response.is_success: + logger.info("Async stream released: id=%s", self._reservation_id) + else: + logger.warning("Async stream release failed: id=%s, status=%d", self._reservation_id, response.status) + except Exception: + logger.exception("Failed to release async stream: id=%s", self._reservation_id) + + def _start_heartbeat(self) -> asyncio.Task[None] | None: + if self._ttl_ms <= 0: + return None + interval_s = max(self._ttl_ms / 2, 1000) / 1000.0 + assert self._reservation_id is not None + reservation_id: str = self._reservation_id + ctx = self._ctx + client = self._client + ttl_ms = self._ttl_ms + + async def heartbeat_loop() -> None: + try: + while True: + await asyncio.sleep(interval_s) + try: + body = _build_extend_body(ttl_ms) + response = await client.extend_reservation(reservation_id, body) + if response.is_success: + new_expires = response.get_body_attribute("expires_at_ms") + if new_expires is not None and ctx is not None: + ctx.update_expires_at_ms(int(new_expires)) + else: + logger.warning("Async stream heartbeat failed: id=%s", reservation_id) + except Exception: + logger.warning("Async stream heartbeat error: id=%s", reservation_id, exc_info=True) + except asyncio.CancelledError: + return + + return asyncio.create_task(heartbeat_loop()) diff --git a/tests/test_streaming.py b/tests/test_streaming.py new file mode 100644 index 0000000..04625de --- /dev/null +++ b/tests/test_streaming.py @@ -0,0 +1,1308 @@ +"""Tests for the streaming convenience module.""" + +import asyncio +import time +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from runcycles.client import AsyncCyclesClient, CyclesClient +from runcycles.config import CyclesConfig +from runcycles.context import get_cycles_context +from runcycles.exceptions import BudgetExceededError, CyclesProtocolError +from runcycles.models import Action, Amount, Decision, Subject, Unit +from runcycles.response import CyclesResponse +from runcycles.streaming import ( + AsyncStreamReservation, + StreamReservation, + StreamUsage, + _build_streaming_reservation_body, + _resolve_actual_cost, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_config() -> CyclesConfig: + return CyclesConfig( + base_url="http://localhost:7878", + api_key="test-key", + tenant="acme", + retry_enabled=False, + retry_initial_delay=0.001, + retry_max_delay=0.01, + ) + + +def _allow_response(caps: dict | None = None) -> CyclesResponse: + body: dict = { + "decision": "ALLOW", + "reservation_id": "rsv_stream_test", + "expires_at_ms": int(time.time() * 1000) + 600_000, + "affected_scopes": ["tenant:acme"], + "scope_path": "tenant:acme", + "reserved": {"unit": "USD_MICROCENTS", "amount": 1000}, + } + if caps: + body["caps"] = caps + return CyclesResponse.success(200, body) + + +def _deny_response() -> CyclesResponse: + return CyclesResponse.success( + 200, + { + "decision": "DENY", + "affected_scopes": ["tenant:acme"], + "reason_code": "BUDGET_EXCEEDED", + }, + ) + + +def _commit_success() -> CyclesResponse: + return CyclesResponse.success( + 200, + { + "status": "COMMITTED", + "charged": {"unit": "USD_MICROCENTS", "amount": 500}, + }, + ) + + +def _release_success() -> CyclesResponse: + return CyclesResponse.success( + 200, + { + "status": "RELEASED", + "released": {"unit": "USD_MICROCENTS", "amount": 1000}, + }, + ) + + +def _make_mock_client() -> MagicMock: + config = _make_config() + mock = MagicMock(spec=CyclesClient) + mock._config = config + return mock + + +def _make_async_mock_client() -> MagicMock: + config = _make_config() + mock = MagicMock(spec=AsyncCyclesClient) + mock._config = config + mock.create_reservation = AsyncMock() + mock.commit_reservation = AsyncMock() + mock.release_reservation = AsyncMock() + mock.extend_reservation = AsyncMock() + return mock + + +def _default_subject() -> Subject: + return Subject(tenant="acme") + + +def _default_action() -> Action: + return Action(kind="llm.completion", name="gpt-4o") + + +def _default_estimate() -> Amount: + return Amount(unit=Unit.USD_MICROCENTS, amount=1000) + + +# --------------------------------------------------------------------------- +# StreamUsage tests +# --------------------------------------------------------------------------- + + +class TestStreamUsage: + def test_defaults(self) -> None: + u = StreamUsage() + assert u.tokens_input == 0 + assert u.tokens_output == 0 + assert u.actual_cost is None + assert u.model_version is None + assert u.custom == {} + + def test_add_input_tokens(self) -> None: + u = StreamUsage() + u.add_input_tokens(10) + u.add_input_tokens(5) + assert u.tokens_input == 15 + + def test_add_output_tokens(self) -> None: + u = StreamUsage() + u.add_output_tokens(20) + assert u.tokens_output == 20 + + def test_set_actual_cost(self) -> None: + u = StreamUsage() + u.set_actual_cost(999) + assert u.actual_cost == 999 + + +# --------------------------------------------------------------------------- +# _build_streaming_reservation_body tests +# --------------------------------------------------------------------------- + + +class TestBuildStreamingReservationBody: + def test_basic(self) -> None: + body = _build_streaming_reservation_body( + _default_subject(), + _default_action(), + _default_estimate(), + ttl_ms=120_000, + overage_policy="ALLOW_IF_AVAILABLE", + grace_period_ms=None, + ) + assert body["subject"]["tenant"] == "acme" + assert body["action"]["kind"] == "llm.completion" + assert body["estimate"]["amount"] == 1000 + assert body["ttl_ms"] == 120_000 + assert body["overage_policy"] == "ALLOW_IF_AVAILABLE" + assert "idempotency_key" in body + assert "grace_period_ms" not in body + + def test_with_grace_period(self) -> None: + body = _build_streaming_reservation_body( + _default_subject(), + _default_action(), + _default_estimate(), + ttl_ms=120_000, + overage_policy="REJECT", + grace_period_ms=5000, + ) + assert body["grace_period_ms"] == 5000 + + def test_ttl_below_minimum_raises(self) -> None: + with pytest.raises(ValueError, match="ttl_ms"): + _build_streaming_reservation_body( + _default_subject(), + _default_action(), + _default_estimate(), + ttl_ms=500, + overage_policy="ALLOW_IF_AVAILABLE", + grace_period_ms=None, + ) + + def test_ttl_above_maximum_raises(self) -> None: + with pytest.raises(ValueError, match="ttl_ms"): + _build_streaming_reservation_body( + _default_subject(), + _default_action(), + _default_estimate(), + ttl_ms=86_400_001, + overage_policy="ALLOW_IF_AVAILABLE", + grace_period_ms=None, + ) + + def test_grace_period_above_maximum_raises(self) -> None: + with pytest.raises(ValueError, match="grace_period_ms"): + _build_streaming_reservation_body( + _default_subject(), + _default_action(), + _default_estimate(), + ttl_ms=120_000, + overage_policy="ALLOW_IF_AVAILABLE", + grace_period_ms=60_001, + ) + + def test_subject_with_no_standard_fields_raises(self) -> None: + with pytest.raises(ValueError, match="at least one standard field"): + _build_streaming_reservation_body( + Subject(dimensions={"custom": "val"}), + _default_action(), + _default_estimate(), + ttl_ms=120_000, + overage_policy="ALLOW_IF_AVAILABLE", + grace_period_ms=None, + ) + + +# --------------------------------------------------------------------------- +# _resolve_actual_cost tests +# --------------------------------------------------------------------------- + + +class TestResolveActualCost: + def test_explicit_actual_cost(self) -> None: + u = StreamUsage(actual_cost=777) + assert _resolve_actual_cost(u, lambda _: 999, 1000) == 777 + + def test_cost_fn(self) -> None: + u = StreamUsage(tokens_input=100, tokens_output=50) + + def cost_fn(usage: StreamUsage) -> int: + return usage.tokens_input * 2 + usage.tokens_output * 3 + + assert _resolve_actual_cost(u, cost_fn, 1000) == 350 + + def test_cost_fn_error_falls_back_to_estimate(self) -> None: + u = StreamUsage() + + def bad_fn(_: StreamUsage) -> int: + raise ValueError("oops") + + assert _resolve_actual_cost(u, bad_fn, 1000) == 1000 + + def test_fallback_to_estimate(self) -> None: + u = StreamUsage() + assert _resolve_actual_cost(u, None, 500) == 500 + + +# --------------------------------------------------------------------------- +# StreamReservation (sync) tests +# --------------------------------------------------------------------------- + + +class TestStreamReservation: + def test_successful_reserve_and_commit(self) -> None: + mock = _make_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.commit_reservation.return_value = _commit_success() + + sr = StreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, # disable heartbeat for test + ) + + with sr as reservation: + assert reservation.reservation_id == "rsv_stream_test" + assert reservation.decision == Decision.ALLOW + reservation.usage.tokens_input = 50 + reservation.usage.tokens_output = 25 + + mock.create_reservation.assert_called_once() + mock.commit_reservation.assert_called_once() + mock.release_reservation.assert_not_called() + + def test_exception_triggers_release(self) -> None: + mock = _make_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.release_reservation.return_value = _release_success() + + sr = StreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + with pytest.raises(RuntimeError, match="stream error"): + with sr: + raise RuntimeError("stream error") + + mock.commit_reservation.assert_not_called() + mock.release_reservation.assert_called_once() + + def test_deny_raises_protocol_error(self) -> None: + mock = _make_mock_client() + mock.create_reservation.return_value = _deny_response() + + sr = StreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + with pytest.raises(CyclesProtocolError, match="Reservation denied"): + with sr: + pass + + def test_reservation_failure_raises(self) -> None: + mock = _make_mock_client() + mock.create_reservation.return_value = CyclesResponse.http_error( + 500, + "Server error", + body=None, + ) + + sr = StreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + with pytest.raises(CyclesProtocolError, match="Failed to create reservation"): + with sr: + pass + + def test_missing_reservation_id_raises(self) -> None: + mock = _make_mock_client() + mock.create_reservation.return_value = CyclesResponse.success( + 200, + { + "decision": "ALLOW", + "affected_scopes": ["tenant:acme"], + }, + ) + + sr = StreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + with pytest.raises(CyclesProtocolError, match="reservation_id missing"): + with sr: + pass + + def test_cost_fn_used_for_commit(self) -> None: + mock = _make_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.commit_reservation.return_value = _commit_success() + + def cost_fn(usage: StreamUsage) -> int: + return usage.tokens_input * 10 + usage.tokens_output * 20 + + sr = StreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + cost_fn=cost_fn, + ) + + with sr as reservation: + reservation.usage.tokens_input = 100 + reservation.usage.tokens_output = 50 + + # Check the commit body had actual = 100*10 + 50*20 = 2000 + commit_call = mock.commit_reservation.call_args + commit_body = commit_call[0][1] + assert commit_body["actual"]["amount"] == 2000 + + def test_actual_cost_overrides_cost_fn(self) -> None: + mock = _make_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.commit_reservation.return_value = _commit_success() + + sr = StreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + cost_fn=lambda _: 9999, + ) + + with sr as reservation: + reservation.usage.set_actual_cost(42) + + commit_body = mock.commit_reservation.call_args[0][1] + assert commit_body["actual"]["amount"] == 42 + + def test_fallback_to_estimate_when_no_cost_fn(self) -> None: + mock = _make_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.commit_reservation.return_value = _commit_success() + + sr = StreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + with sr: + pass + + commit_body = mock.commit_reservation.call_args[0][1] + assert commit_body["actual"]["amount"] == 1000 # estimate amount + + def test_caps_propagated(self) -> None: + mock = _make_mock_client() + mock.create_reservation.return_value = _allow_response( + caps={"max_tokens": 512}, + ) + mock.commit_reservation.return_value = _commit_success() + + sr = StreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + with sr as reservation: + assert reservation.caps is not None + assert reservation.caps.max_tokens == 512 + + def test_context_set_and_cleared(self) -> None: + mock = _make_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.commit_reservation.return_value = _commit_success() + + sr = StreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + ctx_inside = None + with sr: + ctx_inside = get_cycles_context() + + assert ctx_inside is not None + assert ctx_inside.reservation_id == "rsv_stream_test" + assert get_cycles_context() is None + + def test_context_cleared_on_exception(self) -> None: + mock = _make_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.release_reservation.return_value = _release_success() + + sr = StreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + with pytest.raises(ValueError): + with sr: + raise ValueError("boom") + + assert get_cycles_context() is None + + def test_reservation_id_not_available_outside_context(self) -> None: + mock = _make_mock_client() + sr = StreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + with pytest.raises(RuntimeError, match="not available outside"): + _ = sr.reservation_id + + def test_commit_server_error_schedules_retry(self) -> None: + mock = _make_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.commit_reservation.return_value = CyclesResponse.http_error( + 500, + "Server error", + body=None, + ) + + sr = StreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + # retry_enabled=False in config, so retry is just logged, no crash + with sr: + pass + + mock.commit_reservation.assert_called_once() + + def test_commit_finalized_does_not_release(self) -> None: + mock = _make_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.commit_reservation.return_value = CyclesResponse.http_error( + 409, + "Finalized", + body={"error": "RESERVATION_FINALIZED", "message": "Already committed", "request_id": "r1"}, + ) + + sr = StreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + with sr: + pass + + mock.release_reservation.assert_not_called() + + def test_commit_expired_does_not_release(self) -> None: + mock = _make_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.commit_reservation.return_value = CyclesResponse.http_error( + 410, + "Expired", + body={"error": "RESERVATION_EXPIRED", "message": "Expired", "request_id": "r1"}, + ) + + sr = StreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + with sr: + pass + + mock.release_reservation.assert_not_called() + + def test_commit_client_error_triggers_release(self) -> None: + mock = _make_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.commit_reservation.return_value = CyclesResponse.http_error( + 400, + "Bad request", + body={"error": "VALIDATION_ERROR", "message": "Bad", "request_id": "r1"}, + ) + mock.release_reservation.return_value = _release_success() + + sr = StreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + with sr: + pass + + mock.release_reservation.assert_called_once() + + def test_commit_exception_schedules_retry(self) -> None: + mock = _make_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.commit_reservation.side_effect = Exception("network down") + + sr = StreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + # Should not raise; logs and attempts retry (disabled in test config) + with sr: + pass + + def test_metadata_passed_to_commit(self) -> None: + mock = _make_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.commit_reservation.return_value = _commit_success() + + sr = StreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + metadata={"source": "test"}, + ) + + with sr: + pass + + commit_body = mock.commit_reservation.call_args[0][1] + assert commit_body["metadata"] == {"source": "test"} + + def test_metrics_include_tokens(self) -> None: + mock = _make_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.commit_reservation.return_value = _commit_success() + + sr = StreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + with sr as reservation: + reservation.usage.tokens_input = 100 + reservation.usage.tokens_output = 50 + reservation.usage.model_version = "gpt-4o-2024" + + commit_body = mock.commit_reservation.call_args[0][1] + assert commit_body["metrics"]["tokens_input"] == 100 + assert commit_body["metrics"]["tokens_output"] == 50 + assert commit_body["metrics"]["model_version"] == "gpt-4o-2024" + + def test_heartbeat_starts_and_stops(self) -> None: + mock = _make_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.commit_reservation.return_value = _commit_success() + mock.extend_reservation.return_value = CyclesResponse.success( + 200, + { + "status": "EXTENDED", + "expires_at_ms": int(time.time() * 1000) + 600_000, + }, + ) + + sr = StreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=2000, # heartbeat every 1s + ) + + with sr: + # Wait long enough for at least one heartbeat + time.sleep(1.2) + + mock.extend_reservation.assert_called() + + +# --------------------------------------------------------------------------- +# AsyncStreamReservation tests +# --------------------------------------------------------------------------- + + +class TestAsyncStreamReservation: + @pytest.mark.asyncio + async def test_successful_reserve_and_commit(self) -> None: + mock = _make_async_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.commit_reservation.return_value = _commit_success() + + asr = AsyncStreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + async with asr as reservation: + assert reservation.reservation_id == "rsv_stream_test" + reservation.usage.tokens_input = 50 + + mock.create_reservation.assert_called_once() + mock.commit_reservation.assert_called_once() + mock.release_reservation.assert_not_called() + + @pytest.mark.asyncio + async def test_exception_triggers_release(self) -> None: + mock = _make_async_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.release_reservation.return_value = _release_success() + + asr = AsyncStreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + with pytest.raises(RuntimeError, match="async stream error"): + async with asr: + raise RuntimeError("async stream error") + + mock.commit_reservation.assert_not_called() + mock.release_reservation.assert_called_once() + + @pytest.mark.asyncio + async def test_deny_raises(self) -> None: + mock = _make_async_mock_client() + mock.create_reservation.return_value = _deny_response() + + asr = AsyncStreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + with pytest.raises(CyclesProtocolError, match="Reservation denied"): + async with asr: + pass + + @pytest.mark.asyncio + async def test_cost_fn_used(self) -> None: + mock = _make_async_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.commit_reservation.return_value = _commit_success() + + asr = AsyncStreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + cost_fn=lambda u: u.tokens_input * 5, + ) + + async with asr as reservation: + reservation.usage.tokens_input = 300 + + commit_body = mock.commit_reservation.call_args[0][1] + # 300 * 5 = 1500, distinct from estimate (1000) + assert commit_body["actual"]["amount"] == 1500 + + @pytest.mark.asyncio + async def test_context_set_and_cleared(self) -> None: + mock = _make_async_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.commit_reservation.return_value = _commit_success() + + asr = AsyncStreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + ctx_inside = None + async with asr: + ctx_inside = get_cycles_context() + + assert ctx_inside is not None + assert ctx_inside.reservation_id == "rsv_stream_test" + assert get_cycles_context() is None + + @pytest.mark.asyncio + async def test_reservation_id_not_available_outside(self) -> None: + mock = _make_async_mock_client() + asr = AsyncStreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + with pytest.raises(RuntimeError, match="not available outside"): + _ = asr.reservation_id + + @pytest.mark.asyncio + async def test_commit_server_error_schedules_retry(self) -> None: + mock = _make_async_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.commit_reservation.return_value = CyclesResponse.http_error( + 500, + "Server error", + body=None, + ) + + asr = AsyncStreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + async with asr: + pass + + mock.commit_reservation.assert_called_once() + + @pytest.mark.asyncio + async def test_commit_client_error_triggers_release(self) -> None: + mock = _make_async_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.commit_reservation.return_value = CyclesResponse.http_error( + 400, + "Bad request", + body={"error": "VALIDATION_ERROR", "message": "Bad", "request_id": "r1"}, + ) + mock.release_reservation.return_value = _release_success() + + asr = AsyncStreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + async with asr: + pass + + mock.release_reservation.assert_called_once() + + @pytest.mark.asyncio + async def test_missing_reservation_id_raises(self) -> None: + mock = _make_async_mock_client() + mock.create_reservation.return_value = CyclesResponse.success( + 200, + { + "decision": "ALLOW", + "affected_scopes": ["tenant:acme"], + }, + ) + + asr = AsyncStreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + with pytest.raises(CyclesProtocolError, match="reservation_id missing"): + async with asr: + pass + + @pytest.mark.asyncio + async def test_reservation_failure_raises(self) -> None: + mock = _make_async_mock_client() + mock.create_reservation.return_value = CyclesResponse.http_error( + 500, + "Server error", + body=None, + ) + + asr = AsyncStreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + with pytest.raises(CyclesProtocolError, match="Failed to create reservation"): + async with asr: + pass + + +# --------------------------------------------------------------------------- +# Client convenience method tests +# --------------------------------------------------------------------------- + + +class TestClientStreamReservation: + def test_sync_client_returns_stream_reservation(self) -> None: + config = _make_config() + mock_http = MagicMock() + with patch("runcycles.client.httpx.Client", return_value=mock_http): + client = CyclesClient(config) + sr = client.stream_reservation( + action=_default_action(), + estimate=_default_estimate(), + ) + assert isinstance(sr, StreamReservation) + + def test_sync_client_uses_config_subject(self) -> None: + config = CyclesConfig( + base_url="http://localhost:7878", + api_key="test", + tenant="acme", + workspace="prod", + ) + mock_http = MagicMock() + with patch("runcycles.client.httpx.Client", return_value=mock_http): + client = CyclesClient(config) + sr = client.stream_reservation( + action=_default_action(), + estimate=_default_estimate(), + ) + assert sr._subject.tenant == "acme" + assert sr._subject.workspace == "prod" + + def test_sync_client_explicit_subject_overrides(self) -> None: + config = _make_config() + mock_http = MagicMock() + with patch("runcycles.client.httpx.Client", return_value=mock_http): + client = CyclesClient(config) + custom_subject = Subject(tenant="other") + sr = client.stream_reservation( + subject=custom_subject, + action=_default_action(), + estimate=_default_estimate(), + ) + assert sr._subject.tenant == "other" + + def test_async_client_returns_async_stream_reservation(self) -> None: + config = _make_config() + mock_http = MagicMock() + with patch("runcycles.client.httpx.AsyncClient", return_value=mock_http): + client = AsyncCyclesClient(config) + asr = client.stream_reservation( + action=_default_action(), + estimate=_default_estimate(), + ) + assert isinstance(asr, AsyncStreamReservation) + + +# --------------------------------------------------------------------------- +# Budget-exceeded (typed exception) test +# --------------------------------------------------------------------------- + + +class TestStreamReservationEdgeCases: + def test_unrecognized_commit_response(self) -> None: + mock = _make_mock_client() + mock.create_reservation.return_value = _allow_response() + # Status 302 is neither client error, server error, nor transport error + mock.commit_reservation.return_value = CyclesResponse.http_error( + 302, + "Redirect", + body=None, + ) + + sr = StreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + with sr: + pass + + mock.release_reservation.assert_not_called() + + def test_release_failure_logged(self) -> None: + mock = _make_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.release_reservation.return_value = CyclesResponse.http_error( + 500, + "Server error", + body=None, + ) + + sr = StreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + with pytest.raises(ValueError): + with sr: + raise ValueError("boom") + + mock.release_reservation.assert_called_once() + + def test_release_exception_logged(self) -> None: + mock = _make_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.release_reservation.side_effect = Exception("network down") + + sr = StreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + with pytest.raises(ValueError): + with sr: + raise ValueError("boom") + + def test_commit_idempotency_mismatch_does_not_release(self) -> None: + mock = _make_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.commit_reservation.return_value = CyclesResponse.http_error( + 409, + "Idempotency mismatch", + body={"error": "IDEMPOTENCY_MISMATCH", "message": "Mismatch", "request_id": "r1"}, + ) + + sr = StreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + with sr: + pass + + mock.release_reservation.assert_not_called() + + def test_ctx_metrics_respected(self) -> None: + """If user sets ctx.metrics during streaming, those should be used instead of StreamUsage.""" + mock = _make_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.commit_reservation.return_value = _commit_success() + + from runcycles.models import CyclesMetrics + + sr = StreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + with sr: + ctx = get_cycles_context() + assert ctx is not None + ctx.metrics = CyclesMetrics(tokens_input=999, tokens_output=888, model_version="custom") + + commit_body = mock.commit_reservation.call_args[0][1] + assert commit_body["metrics"]["tokens_input"] == 999 + assert commit_body["metrics"]["tokens_output"] == 888 + assert commit_body["metrics"]["model_version"] == "custom" + + +class TestAsyncStreamReservationEdgeCases: + @pytest.mark.asyncio + async def test_commit_finalized_does_not_release(self) -> None: + mock = _make_async_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.commit_reservation.return_value = CyclesResponse.http_error( + 409, + "Finalized", + body={"error": "RESERVATION_FINALIZED", "message": "Done", "request_id": "r1"}, + ) + + asr = AsyncStreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + async with asr: + pass + + mock.release_reservation.assert_not_called() + + @pytest.mark.asyncio + async def test_commit_expired_does_not_release(self) -> None: + mock = _make_async_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.commit_reservation.return_value = CyclesResponse.http_error( + 410, + "Expired", + body={"error": "RESERVATION_EXPIRED", "message": "Expired", "request_id": "r1"}, + ) + + asr = AsyncStreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + async with asr: + pass + + mock.release_reservation.assert_not_called() + + @pytest.mark.asyncio + async def test_unrecognized_commit_response(self) -> None: + mock = _make_async_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.commit_reservation.return_value = CyclesResponse.http_error( + 302, + "Redirect", + body=None, + ) + + asr = AsyncStreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + async with asr: + pass + + mock.release_reservation.assert_not_called() + + @pytest.mark.asyncio + async def test_commit_exception_schedules_retry(self) -> None: + mock = _make_async_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.commit_reservation.side_effect = Exception("network down") + + asr = AsyncStreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + async with asr: + pass + + @pytest.mark.asyncio + async def test_commit_idempotency_mismatch_does_not_release(self) -> None: + mock = _make_async_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.commit_reservation.return_value = CyclesResponse.http_error( + 409, + "Idempotency mismatch", + body={"error": "IDEMPOTENCY_MISMATCH", "message": "Mismatch", "request_id": "r1"}, + ) + + asr = AsyncStreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + async with asr: + pass + + mock.release_reservation.assert_not_called() + + @pytest.mark.asyncio + async def test_release_failure_logged(self) -> None: + mock = _make_async_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.release_reservation.return_value = CyclesResponse.http_error( + 500, + "Server error", + body=None, + ) + + asr = AsyncStreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + with pytest.raises(ValueError): + async with asr: + raise ValueError("boom") + + @pytest.mark.asyncio + async def test_release_exception_logged(self) -> None: + mock = _make_async_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.release_reservation.side_effect = Exception("network down") + + asr = AsyncStreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + with pytest.raises(ValueError): + async with asr: + raise ValueError("boom") + + @pytest.mark.asyncio + async def test_heartbeat_starts_and_stops(self) -> None: + mock = _make_async_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.commit_reservation.return_value = _commit_success() + mock.extend_reservation.return_value = CyclesResponse.success( + 200, + { + "status": "EXTENDED", + "expires_at_ms": int(time.time() * 1000) + 600_000, + }, + ) + + asr = AsyncStreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=2000, + ) + + async with asr: + await asyncio.sleep(1.2) + + mock.extend_reservation.assert_called() + + @pytest.mark.asyncio + async def test_metadata_passed_to_commit(self) -> None: + mock = _make_async_mock_client() + mock.create_reservation.return_value = _allow_response() + mock.commit_reservation.return_value = _commit_success() + + asr = AsyncStreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + metadata={"key": "val"}, + ) + + async with asr: + pass + + commit_body = mock.commit_reservation.call_args[0][1] + assert commit_body["metadata"] == {"key": "val"} + + @pytest.mark.asyncio + async def test_caps_propagated(self) -> None: + mock = _make_async_mock_client() + mock.create_reservation.return_value = _allow_response( + caps={"max_tokens": 256}, + ) + mock.commit_reservation.return_value = _commit_success() + + asr = AsyncStreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + async with asr as reservation: + assert reservation.caps is not None + assert reservation.caps.max_tokens == 256 + + +class TestBudgetExceeded: + def test_budget_exceeded_raises_typed_error(self) -> None: + mock = _make_mock_client() + mock.create_reservation.return_value = CyclesResponse.http_error( + 409, + "Budget exceeded", + body={ + "error": "BUDGET_EXCEEDED", + "message": "No budget", + "request_id": "req-1", + }, + ) + + sr = StreamReservation( + mock, + subject=_default_subject(), + action=_default_action(), + estimate=_default_estimate(), + ttl_ms=1000, + ) + + with pytest.raises(BudgetExceededError): + with sr: + pass