Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions splunklib/ai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,8 @@ triggers the retry logic described above. A custom `model_middleware` can interc
to observe, log, or override the retry behavior. A custom `model_middleware` can also raise
the `StructuredOutputGenerationException` manually to reject structured output and force a re-generation.

The maximal number of re-tries is limited per agent loop invocation; see [Default limit middlewares](#default-limit-middlewares).

### Subagents with structured output/input

In addition to output schemas, subagents can define input schemas. These schemas both constrain
Expand Down Expand Up @@ -926,7 +928,7 @@ async with Agent(
) as agent: ...
```

### Default limit middlewares
## Default limit middlewares

Every `Agent` automatically applies sane default limits to prevent runaway execution
or excessive token usage. Default limit middlewares are appended after any user-supplied
Expand All @@ -939,15 +941,17 @@ chain - place it last if you want the same behavior.
| `TokenLimitMiddleware` | 200 000 tokens | token count of messages passed to the model |
| `StepLimitMiddleware` | 100 steps | steps taken |
| `TimeoutLimitMiddleware` | 600 seconds (10 minutes) | per `invoke` call |
| `StructuredOutputRetryLimitMiddleware` | 3 retries | per `invoke` call |

`TokenLimitMiddleware` and `StepLimitMiddleware` check the values from the messages passed to the
model on each call. `TimeoutLimitMiddleware` resets its deadline on each `invoke`, so every call
gets a fresh time budget.
model on each call. `TimeoutLimitMiddleware` and `StructuredOutputRetryLimitMiddleware` reset their
deadline/limit on each `invoke`, so effectively these limits apply only to a single agent loop invocation.

When a limit is exceeded, the agent raises the corresponding exception:
`TokenLimitExceededException`, `StepsLimitExceededException`, or `TimeoutExceededException`.
`TokenLimitExceededException`, `StepsLimitExceededException`, `TimeoutExceededException`, or
`StructuredOutputRetryLimitExceededException`.

#### Overriding defaults
### Overriding defaults

To override a specific limit, pass your own instance of the corresponding middleware
class. The default for that limit is suppressed automatically - the other defaults
Expand All @@ -970,13 +974,18 @@ To override all defaults, pass all four:
async with Agent(
...,
middleware=[
StructuredOutputRetryLimitMiddleware(0),  # no retries
TokenLimitMiddleware(50_000),
StepLimitMiddleware(10),
TimeoutLimitMiddleware(30.0),
],
) as agent: ...
```

**Note**: When overriding limit middlewares, order matters. Place `StructuredOutputRetryLimitMiddleware`
first and `TokenLimitMiddleware`, `StepLimitMiddleware`, and `TimeoutLimitMiddleware` last,
otherwise the limits may not behave as expected.

There is no explicit opt-out - the intent is that agents should always have some guardrails.

## Logger
Expand Down
18 changes: 14 additions & 4 deletions splunklib/ai/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
from splunklib.ai.conversation_store import ConversationStore
from splunklib.ai.hooks import (
DEFAULT_STEP_LIMIT,
DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT,
DEFAULT_TIMEOUT_SECONDS,
DEFAULT_TOKEN_LIMIT,
StepLimitMiddleware,
StructuredOutputRetryLimitMiddleware,
TimeoutLimitMiddleware,
TokenLimitMiddleware,
)
Expand Down Expand Up @@ -79,16 +81,24 @@ def __init__(
self._output_schema = output_schema
user_middleware = tuple(middleware) if middleware else ()
user_middleware_types = {type(m) for m in user_middleware}

# NOTE: we're creating separate instances per agent - TimeoutLimitMiddleware is stateful
# and sharing one would cause agents to overwrite each other's deadline.
predefined: list[AgentMiddleware] = [
predefined_before: list[AgentMiddleware] = [
StructuredOutputRetryLimitMiddleware(DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT),
]
predefined_after: list[AgentMiddleware] = [
TokenLimitMiddleware(DEFAULT_TOKEN_LIMIT),
StepLimitMiddleware(DEFAULT_STEP_LIMIT),
TimeoutLimitMiddleware(DEFAULT_TIMEOUT_SECONDS),
]
# Append predefined middlewares by default if not provided already.
default_middleware = [m for m in predefined if type(m) not in user_middleware_types]
self._middleware = (*user_middleware, *default_middleware)

self._middleware = (
*{m for m in predefined_before if type(m) not in user_middleware_types},
*user_middleware,
*{m for m in predefined_after if type(m) not in user_middleware_types},
)

self._trace_id = secrets.token_hex(16) # 32 Hex characters
self._conversation_store = conversation_store
self._thread_id = thread_id
Expand Down
5 changes: 0 additions & 5 deletions splunklib/ai/engines/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,11 +883,6 @@ async def llm_handler(req: ModelRequest) -> ModelResponse:
except StructuredOutputGenerationException as e:
# Structured output generation failed, retry.

# TODO: we should provide a mechanism to limit the amount of retries
# thath happen sequentially (say 3), otherwise raise a different exception.
# For now this can be done with the use of model middleware that counts
# the amount of StructuredOutputGenerationException that were raised.

ai_msg = _map_message_to_langchain(e.message)
assert isinstance(ai_msg, LC_AIMessage)

Expand Down
49 changes: 49 additions & 0 deletions splunklib/ai/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
ModelRequest,
ModelResponse,
)
from splunklib.ai.structured_output import StructuredOutputGenerationException

DEFAULT_TIMEOUT_SECONDS: float = 600.0
DEFAULT_STEP_LIMIT: int = 100
DEFAULT_TOKEN_LIMIT: int = 200_000
DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT: int = 3


class AgentStopException(Exception):
Expand Down Expand Up @@ -43,6 +45,13 @@ def __init__(self, timeout_seconds: float) -> None:
super().__init__(f"Timed out after {timeout_seconds} seconds.")


class StructuredOutputRetryLimitExceededException(AgentStopException):
    """Raised by `Agent.invoke` when the structured output retry limit is exceeded.

    Emitted by `StructuredOutputRetryLimitMiddleware` once the number of
    structured output re-generation attempts in a single agent loop goes
    over the configured limit.
    """

    def __init__(self, retry_count: int) -> None:
        # retry_count is the configured limit that was exceeded.
        message = f"Structured output retry limit of {retry_count} exceeded"
        super().__init__(message)


def before_model(
func: Callable[[ModelRequest], None | Awaitable[None]],
) -> AgentMiddleware:
Expand Down Expand Up @@ -199,3 +208,43 @@ async def model_middleware(
if self._deadline is not None and monotonic() >= self._deadline:
raise TimeoutExceededException(timeout_seconds=self._seconds)
return await handler(request)


class StructuredOutputRetryLimitMiddleware(AgentMiddleware):
    """Stops agent execution when the agent exceeds the structured output
    retry limit during a single agent loop invocation.

    The retry counter is reset at the start of every agent loop (see
    `agent_middleware`), so the limit applies per `invoke` call rather than
    over the lifetime of the middleware instance. Counters are tracked per
    `thread_id` so one instance can serve concurrent invocations.
    """

    # Maximum number of retries allowed per agent loop invocation.
    _limit: int
    # Retry counters keyed by thread_id; entries live only for the duration
    # of the corresponding agent loop.
    _retries_per_thread_id: dict[str, int]

    def __init__(self, limit: int) -> None:
        """
        Args:
            limit: Maximum number of structured output retries allowed per
                agent loop invocation. Must be non-negative; ``0`` means the
                first failed generation already exceeds the limit.

        Raises:
            ValueError: If ``limit`` is negative.
        """
        if limit < 0:
            raise ValueError(f"limit must be non-negative, got {limit}")
        self._limit = limit
        self._retries_per_thread_id = {}

    @override
    async def agent_middleware(
        self,
        request: AgentRequest,
        handler: AgentMiddlewareHandler,
    ) -> AgentResponse[Any | None]:
        # Agent loop starting: give this thread a fresh retry budget.
        self._retries_per_thread_id[request.thread_id] = 0
        try:
            return await handler(request)
        finally:
            # `pop` instead of `del`: still frees the entry (no memory leak),
            # but never raises if the entry was already removed, e.g. by a
            # re-entrant invocation sharing the same thread_id.
            self._retries_per_thread_id.pop(request.thread_id, None)

    @override
    async def model_middleware(
        self,
        request: ModelRequest,
        handler: ModelMiddlewareHandler,
    ) -> ModelResponse:
        try:
            return await handler(request)
        except StructuredOutputGenerationException as e:
            thread_id = request.state.thread_id
            # Default to 0 so this also works if `agent_middleware` did not
            # run for this thread (e.g. when registered as a bare model
            # middleware), instead of raising KeyError.
            retries = self._retries_per_thread_id.get(thread_id, 0) + 1
            self._retries_per_thread_id[thread_id] = retries
            if retries > self._limit:
                # Chain the triggering failure so the root cause is visible
                # in the traceback.
                raise StructuredOutputRetryLimitExceededException(self._limit) from e
            raise  # re-raise, to retry structured output generation
4 changes: 2 additions & 2 deletions tests/ai_testlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def _json_body_matcher(r1: Any, r2: Any) -> None:
my_vcr = vcr.VCR(
cassette_library_dir=snapshot_dir,
serializer="json-friendly",
record_mode=RecordMode.ONCE,
record_mode=RecordMode.NEW_EPISODES,
match_on=[
"method",
"scheme",
Expand All @@ -184,7 +184,7 @@ def _json_body_matcher(r1: Any, r2: Any) -> None:
],
before_record_request=_before_record_request,
before_record_response=_before_record_response,
record_on_exception=False,
# record_on_exception=False,
drop_unused_requests=True,
)
my_vcr.register_serializer("json-friendly", _JSONFriendlySerializer())
Expand Down
Loading