diff --git a/src/strands/agent/conversation_manager/__init__.py b/src/strands/agent/conversation_manager/__init__.py index c59623215..9f6d54ff9 100644 --- a/src/strands/agent/conversation_manager/__init__.py +++ b/src/strands/agent/conversation_manager/__init__.py @@ -3,6 +3,7 @@ It includes: - ConversationManager: Abstract base class defining the conversation management interface +- ProactiveCompressionConfig: Configuration type for proactive compression settings - NullConversationManager: A no-op implementation that does not modify conversation history - SlidingWindowConversationManager: An implementation that maintains a sliding window of messages to control context size while preserving conversation coherence @@ -13,7 +14,7 @@ is critical for effective agent interactions. """ -from .conversation_manager import ConversationManager +from .conversation_manager import ConversationManager, ProactiveCompressionConfig from .null_conversation_manager import NullConversationManager from .sliding_window_conversation_manager import SlidingWindowConversationManager from .summarizing_conversation_manager import SummarizingConversationManager @@ -21,6 +22,7 @@ __all__ = [ "ConversationManager", "NullConversationManager", + "ProactiveCompressionConfig", "SlidingWindowConversationManager", "SummarizingConversationManager", ] diff --git a/src/strands/agent/conversation_manager/conversation_manager.py b/src/strands/agent/conversation_manager/conversation_manager.py index 690ecbde5..7e2283883 100644 --- a/src/strands/agent/conversation_manager/conversation_manager.py +++ b/src/strands/agent/conversation_manager/conversation_manager.py @@ -1,14 +1,33 @@ """Abstract interface for conversation history management.""" +import logging from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypedDict, Union +from ...hooks.events import BeforeModelCallEvent from ...hooks.registry import HookProvider, HookRegistry from ...types.content import Message if TYPE_CHECKING: from ...agent.agent import Agent +logger = logging.getLogger(__name__) + +DEFAULT_COMPRESSION_THRESHOLD = 0.7 +DEFAULT_CONTEXT_WINDOW_LIMIT = 200_000 + + +class ProactiveCompressionConfig(TypedDict, total=False): + """Configuration for proactive compression when passed as an object. + + Attributes: + compression_threshold: Ratio of context window usage that triggers proactive compression. + Value between 0 (exclusive) and 1 (inclusive). + Defaults to 0.7 (compress when 70% of the context window is used). + """ + + compression_threshold: float + class ConversationManager(ABC, HookProvider): """Abstract base class for managing conversation history. @@ -22,45 +41,122 @@ class ConversationManager(ABC, HookProvider): ConversationManager implements the HookProvider protocol, allowing derived classes to register hooks for agent lifecycle events. Derived classes that override register_hooks must call the base implementation to ensure proper - hook registration. + hook registration chain. + + The primary responsibility of a ConversationManager is overflow recovery: when the model encounters a context + window overflow, :meth:`reduce_context` is called with ``e`` set and MUST reduce the history enough for the next + model call to succeed. + + Subclasses can enable proactive compression by passing ``proactive_compression`` in the constructor. + When enabled, the base class registers a ``BeforeModelCallEvent`` hook that checks projected input tokens + against the model's context window limit and calls :meth:`reduce_context` (without ``e``) when the + threshold is exceeded. This is a best-effort operation — errors are swallowed so the model call can + still proceed. Example: ```python - class MyConversationManager(ConversationManager): - def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: - super().register_hooks(registry, **kwargs) - # Register additional hooks here + # Enable proactive compression with default threshold (0.7) + SlidingWindowConversationManager(window_size=50, proactive_compression=True) + + # Enable proactive compression with custom threshold + SummarizingConversationManager(proactive_compression={"compression_threshold": 0.8}) ``` """ - def __init__(self) -> None: + def __init__(self, *, proactive_compression: Union[bool, "ProactiveCompressionConfig", None] = None) -> None: """Initialize the ConversationManager. + Args: + proactive_compression: Enable proactive context compression before the model call. + - ``True``: compress when 70% of the context window is used (default threshold). + - ``{"compression_threshold": float}``: compress at the specified ratio (0, 1]. + - ``False`` or ``None``: disabled, only reactive overflow recovery is used. + + Raises: + ValueError: If compression_threshold is not in the valid range (0, 1]. + Attributes: removed_message_count: The messages that have been removed from the agents messages array. These represent messages provided by the user or LLM that have been removed, not messages included by the conversation manager through something like summarization. """ + # Resolve the threshold from proactive_compression parameter + if proactive_compression is True: + threshold: float | None = DEFAULT_COMPRESSION_THRESHOLD + elif isinstance(proactive_compression, dict): + threshold = proactive_compression.get("compression_threshold", DEFAULT_COMPRESSION_THRESHOLD) + else: + threshold = None + + if threshold is not None and (threshold <= 0 or threshold > 1): + raise ValueError( + f"compression_threshold must be between 0 (exclusive) and 1 (inclusive), got {threshold}" + ) + self.removed_message_count = 0 + self._compression_threshold = threshold + self._context_window_limit_warned = False def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: """Register hooks for agent lifecycle events. + Always registers a ``BeforeModelCallEvent`` hook for proactive compression. + When ``proactive_compression`` is not configured, the handler is a no-op (early return). + Derived classes that override this method must call the base implementation to ensure proper hook registration chain. Args: registry: The hook registry to register callbacks with. **kwargs: Additional keyword arguments for future extensibility. + """ + # Always subscribe — the threshold check happens inside the handler + registry.add_callback(BeforeModelCallEvent, self._on_before_model_call_threshold) - Example: - ```python - def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: - super().register_hooks(registry, **kwargs) - registry.add_callback(SomeEvent, self.on_some_event) - ``` + def _on_before_model_call_threshold(self, event: BeforeModelCallEvent) -> None: + """Handle BeforeModelCallEvent for proactive compression. + + When proactive compression is not configured, this is a no-op. + When configured, checks projected input tokens against the context window limit + and calls reduce_context() without error (best-effort) when threshold is exceeded. + + Args: + event: The before model call event. """ - pass + # Early return if proactive compression is not enabled + if self._compression_threshold is None: + return + + context_window_limit = event.agent.model.context_window_limit + if context_window_limit is None: + context_window_limit = DEFAULT_CONTEXT_WINDOW_LIMIT + if not self._context_window_limit_warned: + self._context_window_limit_warned = True + logger.warning( + "context_window_limit=<%s> | context_window_limit not set on model, using default." + " Set context_window_limit in your model config for accurate proactive compression", + DEFAULT_CONTEXT_WINDOW_LIMIT, + ) + + if event.projected_input_tokens is None: + logger.debug("projected_input_tokens= | skipping proactive compression") + return + + ratio = event.projected_input_tokens / context_window_limit + if ratio >= self._compression_threshold: + logger.debug( + "projected_tokens=<%s>, limit=<%s>, ratio=<%.2f>, compression_threshold=<%s>" + " | compression threshold exceeded, reducing context", + event.projected_input_tokens, + context_window_limit, + ratio, + self._compression_threshold, + ) + # Proactive compression is best-effort: swallow errors so the model call can still proceed. + try: + self.reduce_context(agent=event.agent) + except Exception: + logger.debug("proactive compression failed, will proceed with model call", exc_info=True) def restore_from_session(self, state: dict[str, Any]) -> list[Message] | None: """Restore the Conversation Manager's state from a session. @@ -99,22 +195,24 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None: @abstractmethod def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: - """Called when the model's context window is exceeded. - - This method should implement the specific strategy for reducing the window size when a context overflow occurs. - It is typically called after a ContextWindowOverflowException is caught. + """Reduce the conversation history. - Implementations might use strategies such as: + Called in two scenarios: + 1. **Reactive** (e is set): A context window overflow occurred. The implementation + MUST remove enough history for the next model call to succeed, or re-raise the error. + 2. **Proactive** (e is None): The compression threshold was exceeded. This is best-effort — + returning without reduction or raising is acceptable; the model call proceeds regardless. - - Removing the N oldest messages - - Summarizing older context - - Applying importance-based filtering - - Maintaining critical conversation markers + Implementations should modify ``agent.messages`` in-place. Args: agent: The agent whose conversation history will be reduced. This list is modified in-place. e: The exception that triggered the context reduction, if any. + When set, this is a reactive overflow recovery call — the implementation MUST + reduce enough history for the next model call to succeed. + When None, this is a proactive compression call — best-effort reduction to avoid + hitting the context window limit. **kwargs: Additional keyword arguments for future extensibility. """ pass diff --git a/src/strands/agent/conversation_manager/null_conversation_manager.py b/src/strands/agent/conversation_manager/null_conversation_manager.py index 11632525d..4077cb08b 100644 --- a/src/strands/agent/conversation_manager/null_conversation_manager.py +++ b/src/strands/agent/conversation_manager/null_conversation_manager.py @@ -5,7 +5,6 @@ if TYPE_CHECKING: from ...agent.agent import Agent -from ...types.exceptions import ContextWindowOverflowException from .conversation_manager import ConversationManager @@ -29,7 +28,10 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None: pass def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: - """Does not reduce context and raises an exception. + """Does not reduce context. + + When called reactively (e is not None), re-raises the overflow exception since this + manager cannot reduce context. When called proactively (e is None), returns silently. Args: agent: The agent whose conversation history will remain unmodified. @@ -37,10 +39,7 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: A **kwargs: Additional keyword arguments for future extensibility. Raises: - e: If provided. - ContextWindowOverflowException: If e is None. + e: If provided (reactive overflow). """ if e: raise e - else: - raise ContextWindowOverflowException("Context window overflowed!") diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index 1b45dd42c..1ad8edc24 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -10,7 +10,7 @@ from ...types.content import ContentBlock, Messages from ...types.exceptions import ContextWindowOverflowException from ...types.tools import ToolResultContent -from .conversation_manager import ConversationManager +from .conversation_manager import ConversationManager, ProactiveCompressionConfig logger = logging.getLogger(__name__) @@ -37,6 +37,7 @@ def __init__( should_truncate_results: bool = True, *, per_turn: bool | int = False, + proactive_compression: bool | ProactiveCompressionConfig | None = None, ): """Initialize the sliding window conversation manager. @@ -54,6 +55,10 @@ def __init__( manage message history and prevent the agent loop from slowing down. Start with per_turn=True and adjust to a specific frequency (e.g., per_turn=5) if needed for performance tuning. + proactive_compression: Enable proactive context compression before the model call. + - ``True``: compress when 70% of the context window is used (default threshold). + - ``{"compression_threshold": float}``: compress at the specified ratio (0, 1]. + - ``False`` or ``None``: disabled, only reactive overflow recovery is used. Raises: ValueError: If window_size is negative, or if per_turn is 0 or a negative integer. @@ -63,7 +68,7 @@ def __init__( if isinstance(per_turn, int) and not isinstance(per_turn, bool) and per_turn <= 0: raise ValueError(f"per_turn must be a positive integer, True, or False, got {per_turn}") - super().__init__() + super().__init__(proactive_compression=proactive_compression) self.window_size = window_size self.should_truncate_results = should_truncate_results @@ -158,6 +163,12 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None: def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: """Trim the oldest messages to reduce the conversation context size. + When ``e`` is set (reactive overflow recovery), attempts to truncate large tool results + first before falling back to message trimming. + + When ``e`` is None (proactive compression or routine management), only trims messages + without attempting tool result truncation. + The method handles special cases where trimming the messages leads to: - toolResult with no corresponding toolUse - toolUse with no corresponding toolResult @@ -166,12 +177,14 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: A agent: The agent whose messages will be reduce. This list is modified in-place. e: The exception that triggered the context reduction, if any. + When set, this is a reactive overflow recovery call. + When None, this is a proactive or routine management call. **kwargs: Additional keyword arguments for future extensibility. Raises: ContextWindowOverflowException: If the context cannot be reduced further and a context overflow - error was provided (e is not None). When called during routine window management (e is None), - logs a warning and returns without modification. + error was provided (e is not None). When called during routine window management or + proactive compression (e is None), logs a warning and returns without modification. """ messages = agent.messages @@ -181,16 +194,18 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: A messages[:] = [] return - # Try to truncate the tool result first - oldest_message_idx_with_tool_results = self._find_oldest_message_with_tool_results(messages) - if oldest_message_idx_with_tool_results is not None and self.should_truncate_results: - logger.debug( - "message_index=<%s> | found message with tool results at index", oldest_message_idx_with_tool_results - ) - results_truncated = self._truncate_tool_results(messages, oldest_message_idx_with_tool_results) - if results_truncated: - logger.debug("message_index=<%s> | tool results truncated", oldest_message_idx_with_tool_results) - return + # Try to truncate the tool result first (only for reactive overflow, not proactive compression) + if e is not None: + oldest_message_idx_with_tool_results = self._find_oldest_message_with_tool_results(messages) + if oldest_message_idx_with_tool_results is not None and self.should_truncate_results: + logger.debug( + "message_index=<%s> | found message with tool results at index", + oldest_message_idx_with_tool_results, + ) + results_truncated = self._truncate_tool_results(messages, oldest_message_idx_with_tool_results) + if results_truncated: + logger.debug("message_index=<%s> | tool results truncated", oldest_message_idx_with_tool_results) + return # Try to trim index id when tool result cannot be truncated anymore # If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size diff --git a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py index abd4d08b5..2030e1d3b 100644 --- a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py +++ b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py @@ -12,7 +12,7 @@ from ...types.content import Message from ...types.exceptions import ContextWindowOverflowException from ...types.tools import AgentTool -from .conversation_manager import ConversationManager +from .conversation_manager import ConversationManager, ProactiveCompressionConfig if TYPE_CHECKING: from ..agent import Agent @@ -65,6 +65,8 @@ def __init__( preserve_recent_messages: int = 10, summarization_agent: Optional["Agent"] = None, summarization_system_prompt: str | None = None, + *, + proactive_compression: bool | ProactiveCompressionConfig | None = None, ): """Initialize the summarizing conversation manager. @@ -77,8 +79,12 @@ def __init__( If provided, this agent can use tools as part of the summarization process. summarization_system_prompt: Optional system prompt override for summarization. If None, uses the default summarization prompt. + proactive_compression: Enable proactive context compression before the model call. + - ``True``: compress when 70% of the context window is used (default threshold). + - ``{"compression_threshold": float}``: compress at the specified ratio (0, 1]. + - ``False`` or ``None``: disabled, only reactive overflow recovery is used. """ - super().__init__() + super().__init__(proactive_compression=proactive_compression) if summarization_agent is not None and summarization_system_prompt is not None: raise ValueError( "Cannot provide both summarization_agent and summarization_system_prompt. " @@ -126,54 +132,76 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None: def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: """Reduce context using summarization. + When ``e`` is set (reactive overflow recovery), summarization failure is re-raised — + the agent loop must not proceed with an overflow. + + When ``e`` is None (proactive compression), summarization failure is logged and + returns silently — the model call proceeds regardless. + Args: agent: The agent whose conversation history will be reduced. The agent's messages list is modified in-place. e: The exception that triggered the context reduction, if any. + When set, this is a reactive overflow recovery call. + When None, this is a proactive compression call (best-effort). **kwargs: Additional keyword arguments for future extensibility. Raises: - ContextWindowOverflowException: If the context cannot be summarized. + Exception: If summarization fails during reactive overflow recovery (e is set). """ try: - # Calculate how many messages to summarize - messages_to_summarize_count = max(1, int(len(agent.messages) * self.summary_ratio)) + self._summarize_oldest(agent) + except Exception as summarization_error: + if e is not None: + # Reactive: rethrow so the ContextWindowOverflowException propagates + logger.error("Summarization failed: %s", summarization_error) + raise summarization_error from e + # Proactive: best-effort, swallow errors so the model call can still proceed. + logger.warning("Proactive summarization failed, continuing: %s", summarization_error) - # Ensure we don't summarize recent messages - messages_to_summarize_count = min( - messages_to_summarize_count, len(agent.messages) - self.preserve_recent_messages - ) + def _summarize_oldest(self, agent: "Agent") -> None: + """Summarize the oldest messages and replace them with a summary. - if messages_to_summarize_count <= 0: - raise ContextWindowOverflowException("Cannot summarize: insufficient messages for summarization") + Args: + agent: The agent instance. - # Adjust split point to avoid breaking ToolUse/ToolResult pairs - messages_to_summarize_count = self._adjust_split_point_for_tool_pairs( - agent.messages, messages_to_summarize_count - ) + Raises: + ContextWindowOverflowException: If there are insufficient messages for summarization. + """ + # Calculate how many messages to summarize + messages_to_summarize_count = max(1, int(len(agent.messages) * self.summary_ratio)) + + # Ensure we don't summarize recent messages + messages_to_summarize_count = min( + messages_to_summarize_count, len(agent.messages) - self.preserve_recent_messages + ) - if messages_to_summarize_count <= 0: - raise ContextWindowOverflowException("Cannot summarize: insufficient messages for summarization") + if messages_to_summarize_count <= 0: + raise ContextWindowOverflowException("Cannot summarize: insufficient messages for summarization") - # Extract messages to summarize - messages_to_summarize = agent.messages[:messages_to_summarize_count] - remaining_messages = agent.messages[messages_to_summarize_count:] + # Adjust split point to avoid breaking ToolUse/ToolResult pairs + messages_to_summarize_count = self._adjust_split_point_for_tool_pairs( + agent.messages, messages_to_summarize_count + ) - # Keep track of the number of messages that have been summarized thus far. - self.removed_message_count += len(messages_to_summarize) - # If there is a summary message, don't count it in the removed_message_count. - if self._summary_message: - self.removed_message_count -= 1 + if messages_to_summarize_count <= 0: + raise ContextWindowOverflowException("Cannot summarize: insufficient messages for summarization") - # Generate summary - self._summary_message = self._generate_summary(messages_to_summarize, agent) + # Extract messages to summarize + messages_to_summarize = agent.messages[:messages_to_summarize_count] + remaining_messages = agent.messages[messages_to_summarize_count:] - # Replace the summarized messages with the summary - agent.messages[:] = [self._summary_message] + remaining_messages + # Keep track of the number of messages that have been summarized thus far. + self.removed_message_count += len(messages_to_summarize) + # If there is a summary message, don't count it in the removed_message_count. + if self._summary_message: + self.removed_message_count -= 1 - except Exception as summarization_error: - logger.error("Summarization failed: %s", summarization_error) - raise summarization_error from e + # Generate summary + self._summary_message = self._generate_summary(messages_to_summarize, agent) + + # Replace the summarized messages with the summary + agent.messages[:] = [self._summary_message] + remaining_messages def _generate_summary(self, messages: list[Message], agent: "Agent") -> Message: """Generate a summary of the provided messages. diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index 8679e6fd7..df748241e 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -4,6 +4,7 @@ from strands import tool from strands.agent.agent import Agent +from strands.agent.conversation_manager.conversation_manager import ConversationManager from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.hooks.events import BeforeModelCallEvent @@ -300,7 +301,7 @@ def test_sliding_window_conversation_manager_with_tool_results_truncated(): ] test_agent = Agent(messages=messages) - manager.reduce_context(test_agent) + manager.reduce_context(test_agent, e=RuntimeError("context overflow")) result_text = messages[1]["content"][0]["toolResult"]["content"][0]["text"] assert result_text.startswith("A" * 200) @@ -310,8 +311,35 @@ def test_sliding_window_conversation_manager_with_tool_results_truncated(): assert messages[1]["content"][0]["toolResult"]["status"] == "success" -def test_null_conversation_manager_reduce_context_raises_context_window_overflow_exception(): - """Test that NullConversationManager doesn't modify messages.""" +def test_sliding_window_proactive_compression_skips_tool_result_truncation(): + """Proactive compression (e=None) should only trim messages, not truncate tool results.""" + large_text = "A" * 300 + "B" * 300 + "C" * 300 + manager = SlidingWindowConversationManager(window_size=2) + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "789", "content": [{"text": large_text}], "status": "success"}}], + }, + {"role": "assistant", "content": [{"text": "Done"}]}, + {"role": "user", "content": [{"text": "Next question"}]}, + ] + test_agent = Agent(messages=messages) + + manager.reduce_context(test_agent) # e=None (proactive) + + # Tool results should NOT be truncated during proactive compression + for msg in messages: + for content in msg.get("content", []): + if "toolResult" in content: + for item in content["toolResult"].get("content", []): + if "text" in item: + assert "... [truncated:" not in item["text"] + + +def test_null_conversation_manager_reduce_context_proactive_returns_silently(): + """Proactive compression (e=None) returns silently without raising.""" manager = NullConversationManager() messages = [ {"role": "user", "content": [{"text": "Hello"}]}, @@ -322,12 +350,25 @@ def test_null_conversation_manager_reduce_context_raises_context_window_overflow manager.apply_management(test_agent) - with pytest.raises(ContextWindowOverflowException): - manager.reduce_context(messages) + # Proactive call (e=None) should not raise + manager.reduce_context(test_agent) assert messages == original_messages +def test_null_conversation_manager_reduce_context_reactive_raises_overflow(): + """Reactive overflow (e is not None) re-raises the exception.""" + manager = NullConversationManager() + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + test_agent = Agent(messages=messages) + + with pytest.raises(ContextWindowOverflowException): + manager.reduce_context(test_agent, e=ContextWindowOverflowException("overflow")) + + def test_null_conversation_manager_reduce_context_with_exception_raises_same_exception(): """Test that NullConversationManager doesn't modify messages.""" manager = NullConversationManager() @@ -400,9 +441,10 @@ def reduce_context(self, agent, e=None, **kwargs): manager = MinimalConversationManager() registry = HookRegistry() - # Should work without error + # Should work without error — the base class always registers the hook manager.register_hooks(registry) - assert not registry.has_callbacks() + # Base class always registers the proactive compression hook + assert registry.has_callbacks() def test_per_turn_hooks_registration(): @@ -555,7 +597,7 @@ def test_truncation_targets_oldest_message_first(): ] test_agent = Agent(messages=messages) - manager.reduce_context(test_agent) + manager.reduce_context(test_agent, e=RuntimeError("context overflow")) # The oldest tool result (index 1) must be truncated oldest_text = messages[1]["content"][0]["toolResult"]["content"][0]["text"] @@ -755,3 +797,242 @@ def test_window_size_zero_clears_on_overflow(): manager.reduce_context(test_agent, e=Exception("overflow")) assert messages == [] + + +# ============================================================================== +# Proactive Compression Tests (proactive_compression parameter) +# ============================================================================== + + +class _MinimalManager(ConversationManager): + """Manager that only implements abstract methods.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.reduce_context_call_count = 0 + + def apply_management(self, agent, **kwargs): + pass + + def reduce_context(self, agent, e=None, **kwargs): + self.reduce_context_call_count += 1 + if agent.messages: + agent.messages.pop(0) + + +def _make_mock_agent(messages=None, context_window_limit=1000): + agent = MagicMock() + agent.messages = messages if messages is not None else [] + agent.model = MagicMock() + agent.model.context_window_limit = context_window_limit + return agent + + +def _make_threshold_event(agent, projected_input_tokens=None): + return BeforeModelCallEvent( + agent=agent, + invocation_state={}, + projected_input_tokens=projected_input_tokens, + ) + + +def test_proactive_compression_rejects_zero(): + with pytest.raises(ValueError, match="compression_threshold must be between 0"): + _MinimalManager(proactive_compression={"compression_threshold": 0}) + + +def test_proactive_compression_rejects_negative(): + with pytest.raises(ValueError, match="compression_threshold must be between 0"): + _MinimalManager(proactive_compression={"compression_threshold": -0.5}) + + +def test_proactive_compression_rejects_greater_than_one(): + with pytest.raises(ValueError, match="compression_threshold must be between 0"): + _MinimalManager(proactive_compression={"compression_threshold": 1.5}) + + +def test_proactive_compression_accepts_exactly_one(): + manager = _MinimalManager(proactive_compression={"compression_threshold": 1.0}) + assert manager._compression_threshold == 1.0 + + +def test_proactive_compression_none_by_default(): + manager = _MinimalManager() + assert manager._compression_threshold is None + + +def test_proactive_compression_true_uses_default_threshold(): + """proactive_compression=True uses default threshold of 0.7.""" + manager = _MinimalManager(proactive_compression=True) + assert manager._compression_threshold == 0.7 + + +def test_proactive_compression_false_disables(): + """proactive_compression=False means no compression.""" + manager = _MinimalManager(proactive_compression=False) + assert manager._compression_threshold is None + + +def test_proactive_compression_always_registers_hook(): + """Hook is always registered regardless of proactive_compression setting.""" + manager = _MinimalManager() + registry = HookRegistry() + manager.register_hooks(registry) + # Always registers the hook + assert registry.has_callbacks() + + +def test_proactive_compression_hook_is_noop_when_not_configured(): + """BeforeModelCallEvent handler is a no-op when proactive_compression is not set.""" + manager = _MinimalManager() + agent = _make_mock_agent(context_window_limit=1000) + registry = HookRegistry() + manager.register_hooks(registry) + + event = _make_threshold_event(agent, projected_input_tokens=900) + registry.invoke_callbacks(event) + + assert manager.reduce_context_call_count == 0 + + +def test_proactive_compression_calls_reduce_context_when_exceeded(): + manager = _MinimalManager(proactive_compression={"compression_threshold": 0.7}) + agent = _make_mock_agent(messages=[{"role": "user", "content": [{"text": "msg"}]}], context_window_limit=1000) + registry = HookRegistry() + manager.register_hooks(registry) + + event = _make_threshold_event(agent, projected_input_tokens=800) + registry.invoke_callbacks(event) + + assert manager.reduce_context_call_count == 1 + + +def test_proactive_compression_no_call_when_below(): + manager = _MinimalManager(proactive_compression={"compression_threshold": 0.7}) + agent = _make_mock_agent(context_window_limit=1000) + registry = HookRegistry() + manager.register_hooks(registry) + + event = _make_threshold_event(agent, projected_input_tokens=500) + registry.invoke_callbacks(event) + + assert manager.reduce_context_call_count == 0 + + +def test_proactive_compression_no_call_when_projected_tokens_none(): + manager = _MinimalManager(proactive_compression=True) + agent = _make_mock_agent(context_window_limit=1000) + registry = HookRegistry() + manager.register_hooks(registry) + + event = _make_threshold_event(agent, projected_input_tokens=None) + registry.invoke_callbacks(event) + + assert manager.reduce_context_call_count == 0 + + +def test_proactive_compression_uses_default_when_context_window_limit_not_set(): + manager = _MinimalManager(proactive_compression={"compression_threshold": 0.7}) + agent = _make_mock_agent(context_window_limit=None) + registry = HookRegistry() + manager.register_hooks(registry) + + # projected_input_tokens=150_000 is 75% of the 200k default, exceeding 0.7 threshold + event = _make_threshold_event(agent, projected_input_tokens=150_000) + with patch("strands.agent.conversation_manager.conversation_manager.logger") as mock_logger: + registry.invoke_callbacks(event) + mock_logger.warning.assert_called_once() + assert "using default" in mock_logger.warning.call_args[0][0] + + assert manager.reduce_context_call_count == 1 + + +def test_proactive_compression_warns_only_once_per_instance(): + """Second invocation on the same manager instance suppresses the context_window_limit warning.""" + manager = _MinimalManager(proactive_compression={"compression_threshold": 0.7}) + agent = _make_mock_agent(context_window_limit=None) + registry = HookRegistry() + manager.register_hooks(registry) + + event = _make_threshold_event(agent, projected_input_tokens=150_000) + with patch("strands.agent.conversation_manager.conversation_manager.logger") as mock_logger: + registry.invoke_callbacks(event) + registry.invoke_callbacks(event) + assert mock_logger.warning.call_count == 1 + + +def test_proactive_compression_exception_swallowed(): + """Exceptions in reduce_context during proactive compression should not propagate.""" + + class _FailingManager(ConversationManager): + def apply_management(self, agent, **kwargs): + pass + + def reduce_context(self, agent, e=None, **kwargs): + raise RuntimeError("boom") + + manager = _FailingManager(proactive_compression={"compression_threshold": 0.7}) + agent = _make_mock_agent(context_window_limit=1000) + registry = HookRegistry() + manager.register_hooks(registry) + + event = _make_threshold_event(agent, projected_input_tokens=800) + registry.invoke_callbacks(event) + + +def test_proactive_compression_true_default_threshold_behavior(): + """proactive_compression=True uses 0.7 — triggered at 0.7+ but not below.""" + manager = _MinimalManager(proactive_compression=True) + agent = _make_mock_agent( + messages=[{"role": "user", "content": [{"text": "msg"}]}], context_window_limit=1000 + ) + registry = HookRegistry() + manager.register_hooks(registry) + + # 650/1000 = 0.65 < 0.7 — should NOT trigger + event = _make_threshold_event(agent, projected_input_tokens=650) + registry.invoke_callbacks(event) + assert manager.reduce_context_call_count == 0 + + # 800/1000 = 0.8 >= 0.7 — should trigger + event2 = _make_threshold_event(agent, projected_input_tokens=800) + registry.invoke_callbacks(event2) + assert manager.reduce_context_call_count == 1 + + +def test_sliding_window_proactive_compression_trims(): + manager = SlidingWindowConversationManager( + window_size=4, should_truncate_results=False, proactive_compression={"compression_threshold": 0.7} + ) + messages = [ + {"role": "user", "content": [{"text": f"Message {i}"}]} + if i % 2 == 0 + else {"role": "assistant", "content": [{"text": f"Response {i}"}]} + for i in range(6) + ] + agent = _make_mock_agent(messages=messages, context_window_limit=1000) + registry = HookRegistry() + manager.register_hooks(registry) + + event = _make_threshold_event(agent, projected_input_tokens=800) + registry.invoke_callbacks(event) + + assert len(agent.messages) == 4 + + +def test_sliding_window_proactive_compression_no_trim_below(): + manager = SlidingWindowConversationManager( + window_size=4, should_truncate_results=False, proactive_compression={"compression_threshold": 0.7} + ) + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi"}]}, + ] + agent = _make_mock_agent(messages=messages, context_window_limit=1000) + registry = HookRegistry() + manager.register_hooks(registry) + + event = _make_threshold_event(agent, projected_input_tokens=500) + registry.invoke_callbacks(event) + + assert len(agent.messages) == 2 diff --git a/tests/strands/agent/test_summarizing_conversation_manager.py b/tests/strands/agent/test_summarizing_conversation_manager.py index c49c69de6..dbd225e9b 100644 --- a/tests/strands/agent/test_summarizing_conversation_manager.py +++ b/tests/strands/agent/test_summarizing_conversation_manager.py @@ -1,5 +1,5 @@ from typing import cast -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import pytest @@ -8,6 +8,8 @@ DEFAULT_SUMMARIZATION_PROMPT, SummarizingConversationManager, ) +from strands.hooks.events import BeforeModelCallEvent +from strands.hooks.registry import HookRegistry from strands.types.content import Messages from strands.types.exceptions import ContextWindowOverflowException from tests.fixtures.mocked_model_provider import MockedModelProvider @@ -101,7 +103,7 @@ def test_init_clamps_summary_ratio(): def test_reduce_context_raises_when_no_agent(): - """Test that reduce_context raises exception when agent has no messages.""" + """Test that reduce_context raises exception when agent has no messages (reactive mode).""" manager = SummarizingConversationManager() # Create a mock agent with no messages @@ -109,8 +111,9 @@ def test_reduce_context_raises_when_no_agent(): empty_messages: Messages = [] mock_agent.messages = empty_messages + # Reactive mode (e is set) should raise with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"): - manager.reduce_context(mock_agent) + manager.reduce_context(mock_agent, e=RuntimeError("overflow")) def test_reduce_context_with_summarization(summarizing_manager, mock_agent): @@ -155,8 +158,9 @@ def test_reduce_context_too_few_messages_raises_exception(summarizing_manager, m ] mock_agent.messages = insufficient_test_messages # 5 messages, preserve_recent_messages=5, so nothing to summarize + # Reactive mode (e is set) should raise with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"): - manager.reduce_context(mock_agent) + manager.reduce_context(mock_agent, e=RuntimeError("overflow")) def test_reduce_context_insufficient_messages_for_summarization(mock_agent): @@ -173,9 +177,9 @@ def test_reduce_context_insufficient_messages_for_summarization(mock_agent): ] mock_agent.messages = insufficient_messages - # This should raise an exception since there aren't enough messages to summarize + # Reactive mode (e is set) should raise with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"): - manager.reduce_context(mock_agent) + manager.reduce_context(mock_agent, e=RuntimeError("overflow")) def test_reduce_context_raises_on_summarization_failure(): @@ -197,8 +201,9 @@ def test_reduce_context_raises_on_summarization_failure(): ) with patch("strands.agent.conversation_manager.summarizing_conversation_manager.logger") as mock_logger: + # Reactive mode (e is set) should raise with pytest.raises(Exception, match="Agent failed"): - manager.reduce_context(failing_agent) + manager.reduce_context(failing_agent, e=RuntimeError("overflow")) # Should log the error mock_logger.error.assert_called_once() @@ -675,9 +680,10 @@ def mock_adjust(messages, split_point): ] mock_agent.messages = simple_messages - # The adjustment method will return 0, which should trigger line 122-123 + # The adjustment method will return 0, which should trigger the <= 0 check + # Reactive mode (e is set) should raise with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"): - manager.reduce_context(mock_agent) + manager.reduce_context(mock_agent, e=RuntimeError("overflow")) def test_summarizing_conversation_manager_properly_records_removed_message_count(): @@ -802,3 +808,86 @@ def tracking_call(self, prompt): assert observed_values == [None], "structured output should be disabled during summarization" assert summary_agent._default_structured_output_model is structured_output_model, "should be restored after" + + +# ============================================================================== +# Compression Threshold Tests +# ============================================================================== + + +def _make_summarizing_threshold_agent(messages, summary_response="Summary of conversation", context_window_limit=1000): + agent = MagicMock() + agent.messages = messages + agent.model = MagicMock() + agent.model.context_window_limit = context_window_limit + agent.model.stream = Mock(side_effect=lambda *a, **kw: _mock_model_stream(summary_response)) + return agent + + +def test_proactive_compression_summarizes_when_exceeded(): + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=2, + proactive_compression={"compression_threshold": 0.7}, + ) + messages = [ + {"role": "user", "content": [{"text": f"Message {i}"}]} + if i % 2 == 0 + else {"role": "assistant", "content": [{"text": f"Response {i}"}]} + for i in range(20) + ] + agent = _make_summarizing_threshold_agent(messages, context_window_limit=1000) + registry = HookRegistry() + manager.register_hooks(registry) + + event = BeforeModelCallEvent(agent=agent, invocation_state={}, projected_input_tokens=800) + registry.invoke_callbacks(event) + + # 20 * 0.5 = 10 summarized → 1 summary + 10 remaining = 11 + assert len(agent.messages) == 11 + assert agent.messages[0]["role"] == "user" + + +def test_proactive_compression_no_summarize_when_below(): + manager = SummarizingConversationManager(proactive_compression={"compression_threshold": 0.7}) + messages = [ + {"role": "user", "content": [{"text": f"Message {i}"}]} + if i % 2 == 0 + else {"role": "assistant", "content": [{"text": f"Response {i}"}]} + for i in range(20) + ] + agent = _make_summarizing_threshold_agent(messages, context_window_limit=1000) + registry = HookRegistry() + manager.register_hooks(registry) + + event = BeforeModelCallEvent(agent=agent, invocation_state={}, projected_input_tokens=500) + registry.invoke_callbacks(event) + + assert len(agent.messages) == 20 + + +def test_proactive_compression_swallows_errors(): + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=2, + proactive_compression={"compression_threshold": 0.7}, + ) + messages = [ + {"role": "user", "content": [{"text": f"Message {i}"}]} + if i % 2 == 0 + else {"role": "assistant", "content": [{"text": f"Response {i}"}]} + for i in range(20) + ] + agent = MagicMock() + agent.messages = messages + agent.model = MagicMock() + agent.model.context_window_limit = 1000 + agent.model.stream = Mock(side_effect=lambda *a, **kw: _mock_model_stream_error(RuntimeError("model failed"))) + + registry = HookRegistry() + manager.register_hooks(registry) + + event = BeforeModelCallEvent(agent=agent, invocation_state={}, projected_input_tokens=800) + # Should not throw — proactive compression is best-effort + registry.invoke_callbacks(event) + assert len(agent.messages) == 20