From 3e5d3c39914600d775b27bad9604c86106971eb2 Mon Sep 17 00:00:00 2001 From: devteamaegis Date: Tue, 26 May 2026 01:50:19 -0400 Subject: [PATCH] fix(pipecat): retain strong refs to background storage tasks to prevent GC asyncio.create_task() is only weakly referenced by the event loop. If the caller discards the returned Task object the GC can destroy it before the coroutine finishes, silently dropping any messages that were queued for storage. The Cartesia SDK in this same repo already uses the correct pattern (_background_tasks set + add_done_callback(discard)). Apply the same fix to SupermemoryPipecatService: * Add `_background_tasks: set` in __init__ * Save every storage task in the set; remove it via done-callback once complete * Clear the set in reset_memory_tracking() Adds tests/test_background_task_tracking.py with five test cases: - presence of _background_tasks attribute - task is held in the set while running - task is removed from the set after completion - a forced GC cycle cannot collect a tracked task mid-execution - reset_memory_tracking clears the set --- .../src/supermemory_pipecat/service.py | 6 +- .../tests/test_background_task_tracking.py | 192 ++++++++++++++++++ 2 files changed, 197 insertions(+), 1 deletion(-) create mode 100644 packages/pipecat-sdk-python/tests/test_background_task_tracking.py diff --git a/packages/pipecat-sdk-python/src/supermemory_pipecat/service.py b/packages/pipecat-sdk-python/src/supermemory_pipecat/service.py index 2aef866bf..ba2f11b4a 100644 --- a/packages/pipecat-sdk-python/src/supermemory_pipecat/service.py +++ b/packages/pipecat-sdk-python/src/supermemory_pipecat/service.py @@ -118,6 +118,7 @@ def __init__( self._messages_sent_count: int = 0 self._last_query: Optional[str] = None self._audio_frames_detected: bool = False + self._background_tasks: set = set() # Prevent GC of fire-and-forget tasks async def _retrieve_memories(self, query: str) -> Dict[str, Any]: """Retrieve relevant memories from Supermemory. 
@@ -308,7 +309,9 @@ async def process_frame(self, frame: Frame, direction: FrameDirection) -> None: unsent_messages = storable_messages[self._messages_sent_count :] if unsent_messages: - asyncio.create_task(self._store_messages(unsent_messages)) + task = asyncio.create_task(self._store_messages(unsent_messages)) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) self._messages_sent_count = len(storable_messages) if messages is not None:
@@ -327,3 +330,4 @@ def reset_memory_tracking(self) -> None: self._messages_sent_count = 0 self._last_query = None self._audio_frames_detected = False + self._background_tasks.clear()  # NOTE(review): this also drops the strong refs to storage tasks that are still pending
diff --git a/packages/pipecat-sdk-python/tests/test_background_task_tracking.py b/packages/pipecat-sdk-python/tests/test_background_task_tracking.py
new file mode 100644
index 000000000..3ef55100f
--- /dev/null
+++ b/packages/pipecat-sdk-python/tests/test_background_task_tracking.py
@@ -0,0 +1,192 @@
+"""Tests for background task reference tracking in SupermemoryPipecatService.
+
+The event loop keeps only a *weak* reference to a Task made by create_task().
+If the caller discards the Task object the GC can destroy the task before it
+finishes, silently dropping stored messages.  These tests replay by hand the
+add() / add_done_callback(discard) bookkeeping that process_frame() applies to
+_background_tasks; they do not call process_frame() itself.
+"""
+
+import asyncio
+import gc
+import unittest
+from typing import Any, Dict, List, Optional
+from unittest.mock import AsyncMock, MagicMock, patch
+
+
+# ---------------------------------------------------------------------------
+# Minimal stubs so we can import service.py without pipecat / supermemory
+# installed in the test environment.
+# ---------------------------------------------------------------------------
+
+import sys
+import types
+
+# Register an empty placeholder module for each pipecat path not already loaded
+for mod_name in [
+    "pipecat",
+    "pipecat.frames",
+    "pipecat.frames.frames",
+    "pipecat.processors",
+    "pipecat.processors.aggregators",
+    "pipecat.processors.aggregators.llm_context",
+    "pipecat.processors.aggregators.openai_llm_context",
+    "pipecat.processors.frame_processor",
+]:
+    if mod_name not in sys.modules:
+        sys.modules[mod_name] = types.ModuleType(mod_name)
+
+# Minimal frame stubs.  NOTE(review): also overwrites these names on a real pipecat.
+frames_mod = sys.modules["pipecat.frames.frames"]
+frames_mod.Frame = object  # type: ignore[attr-defined]
+frames_mod.InputAudioRawFrame = type("InputAudioRawFrame", (object,), {})  # type: ignore[attr-defined]
+frames_mod.LLMContextFrame = type("LLMContextFrame", (object,), {})  # type: ignore[attr-defined]
+frames_mod.LLMMessagesFrame = type("LLMMessagesFrame", (object,), {})  # type: ignore[attr-defined]
+
+
+class _FakeFrameProcessor:  # No-op stand-in assigned to FrameProcessor just below
+    async def process_frame(self, frame: Any, direction: Any) -> None:  # ignores the frame
+        pass
+
+    async def push_frame(self, frame: Any, direction: Any = None) -> None:  # pushes nothing
+        pass
+
+
+fp_mod = sys.modules["pipecat.processors.frame_processor"]
+fp_mod.FrameProcessor = _FakeFrameProcessor  # type: ignore[attr-defined]
+fp_mod.FrameDirection = type("FrameDirection", (object,), {"DOWNSTREAM": "downstream"})  # type: ignore[attr-defined]
+
+llm_ctx_mod = sys.modules["pipecat.processors.aggregators.llm_context"]
+llm_ctx_mod.LLMContext = type("LLMContext", (object,), {"get_messages": lambda self: [], "add_message": lambda self, m: None})  # type: ignore[attr-defined]
+openai_ctx_mod = sys.modules["pipecat.processors.aggregators.openai_llm_context"]
+openai_ctx_mod.OpenAILLMContextFrame = type("OpenAILLMContextFrame", (object,), {})  # type: ignore[attr-defined]
+
+# Stub supermemory unconditionally: AsyncSupermemory is MagicMock, so calling it yields a mock
+supermemory_mod = types.ModuleType("supermemory")
+supermemory_mod.AsyncSupermemory = MagicMock  # type: ignore[attr-defined]
+sys.modules["supermemory"] = supermemory_mod
+
+# Now we can safely import the service
+from supermemory_pipecat.service import SupermemoryPipecatService  # noqa: E402
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def _make_service() -> SupermemoryPipecatService:  # Service whose client is a mock
+    svc = SupermemoryPipecatService(api_key="test-key", user_id="user-1")
+    # Replace the client with a mock whose memories.add is awaitable (no network calls)
+    mock_client = MagicMock()
+    mock_client.memories = MagicMock()
+    mock_client.memories.add = AsyncMock()
+    svc._supermemory_client = mock_client
+    return svc
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+class TestBackgroundTaskTracking(unittest.IsolatedAsyncioTestCase):
+    """Verify that fire-and-forget storage tasks are tracked to prevent GC."""
+
+    def test_background_tasks_set_exists(self) -> None:
+        """Service must expose _background_tasks to hold strong task refs."""
+        svc = _make_service()
+        self.assertTrue(
+            hasattr(svc, "_background_tasks"),
+            "_background_tasks attribute missing — tasks will be GC'd",
+        )
+        self.assertIsInstance(svc._background_tasks, set)
+
+    async def test_task_held_during_execution(self) -> None:
+        """A running _store_messages task must be in _background_tasks."""
+        svc = _make_service()
+
+        started = asyncio.Event()
+        finish_gate = asyncio.Event()
+
+        async def slow_store(messages: List[Dict]) -> None:
+            started.set()
+            await finish_gate.wait()
+
+        svc._store_messages = slow_store  # type: ignore[method-assign]
+
+        # Replay by hand the tracking steps process_frame performs (it is not called here)
+        task = asyncio.create_task(svc._store_messages([{"role": "user", "content": "hi"}]))
+        svc._background_tasks.add(task)
+        task.add_done_callback(svc._background_tasks.discard)
+
+        await started.wait()  # Task has started and is blocked on finish_gate
+        self.assertIn(task, svc._background_tasks, "Task dropped from set while still running")
+
+        finish_gate.set()
+        await task
+        # Done callbacks are scheduled via the event loop, not run inline; yield once more
+        await asyncio.sleep(0)
+        self.assertNotIn(task, svc._background_tasks, "Task not removed after completion")
+
+    async def test_task_removed_after_completion(self) -> None:
+        """_background_tasks must be empty once the storage task finishes."""
+        svc = _make_service()
+
+        completed = []
+
+        async def fast_store(messages: List[Dict]) -> None:
+            completed.append(len(messages))
+
+        svc._store_messages = fast_store  # type: ignore[method-assign]
+
+        task = asyncio.create_task(svc._store_messages([{"role": "user", "content": "x"}]))
+        svc._background_tasks.add(task)
+        task.add_done_callback(svc._background_tasks.discard)
+
+        await task
+        await asyncio.sleep(0)  # Let the done callback run
+
+        self.assertEqual(len(svc._background_tasks), 0)
+        self.assertEqual(completed, [1])
+
+    async def test_gc_cannot_collect_tracked_task(self) -> None:
+        """A task tracked in _background_tasks must survive a forced GC cycle.
+
+        We force a collection mid-execution and confirm the task is still
+        pending.  NOTE(review): the local ``task`` variable is itself a strong
+        reference, so this passes with or without the set.
+        """
+        svc = _make_service()
+
+        gate = asyncio.Event()
+        survived = []
+
+        async def guarded_store(messages: List[Dict]) -> None:
+            await gate.wait()
+            survived.append(True)
+
+        # Replay the tracking pattern by hand (strong ref held in set)
+        task = asyncio.create_task(guarded_store([{"role": "user", "content": "test"}]))
+        svc._background_tasks.add(task)
+        task.add_done_callback(svc._background_tasks.discard)
+
+        # Yield control so the coroutine can start, then force GC
+        await asyncio.sleep(0)
+        gc.collect()
+
+        # Task must still be pending: it is blocked on gate, which is not set yet
+        self.assertFalse(task.done(), "Task was collected before finishing")
+        gate.set()
+        await task
+        self.assertEqual(survived, [True], "Task body never ran after GC")
+
+    def test_reset_clears_background_tasks(self) -> None:
+        """reset_memory_tracking must clear _background_tasks."""
+        svc = _make_service()
+        # Sentinel stands in for a pending task; NOTE(review): clear() un-tracks such tasks
+        svc._background_tasks.add("sentinel")
+        svc.reset_memory_tracking()
+        self.assertEqual(len(svc._background_tasks), 0)
+
+
+if __name__ == "__main__":
+    unittest.main()