Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 33 additions & 55 deletions src/conductor/client/automator/async_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from conductor.client.worker.worker_config import resolve_worker_config, get_worker_config_oneline
from conductor.client.worker.exception import NonRetryableException
from conductor.client.automator.json_schema_generator import generate_json_schema_from_function
from conductor.client.automator.lease_tracker import LeaseInfo, LEASE_EXTEND_RETRY_COUNT, LEASE_EXTEND_DURATION_FACTOR
from conductor.client.automator.lease_tracker import LeaseManager

logger = logging.getLogger(
Configuration.get_logging_formatted_name(
Expand Down Expand Up @@ -113,7 +113,9 @@ def __init__(
self._semaphore = None
self._shutdown = False # Flag to indicate graceful shutdown
self._use_update_v2 = True # Will be set to False if server doesn't support v2 endpoint
self._lease_info = {} # task_id -> LeaseInfo for lease extension heartbeats
self._lease_manager = LeaseManager.get_instance()
self._tracked_task_ids = set() # Local set for cleanup on shutdown
self._sync_task_client = None # Created after fork for LeaseManager heartbeats

async def run(self) -> None:
"""Main async loop - runs continuously in single event loop."""
Expand All @@ -133,6 +135,17 @@ async def run(self) -> None:
api_client=self.async_api_client
)

# Create a sync TaskResourceApi for LeaseManager heartbeats
# (LeaseManager sends heartbeats from its own ThreadPoolExecutor)
from conductor.client.http.api.task_resource_api import TaskResourceApi
from conductor.client.http.api_client import ApiClient
self._sync_task_client = TaskResourceApi(
ApiClient(
configuration=self.configuration,
metrics_collector=self.metrics_collector
)
)

# Create semaphore in the event loop (must be created within the loop)
self._semaphore = asyncio.Semaphore(self._max_workers)

Expand Down Expand Up @@ -168,8 +181,10 @@ async def _cleanup(self) -> None:
"""Clean up async resources."""
logger.debug("Cleaning up AsyncTaskRunner resources...")

# Stop all lease extension tracking
self._lease_info.clear()
# Untrack all tasks this runner was tracking from the shared LeaseManager
for task_id in list(self._tracked_task_ids):
self._lease_manager.untrack(task_id)
self._tracked_task_ids.clear()

# Cancel any running tasks (EAFP style)
try:
Expand All @@ -187,6 +202,13 @@ async def _cleanup(self) -> None:
except (IOError, OSError) as e:
logger.warning(f"Error closing async client: {e}")

# Close sync HTTP client used for lease heartbeats
if self._sync_task_client:
try:
self._sync_task_client.api_client.rest_client.connection.close()
except Exception:
pass

# Clear event listeners
self.event_dispatcher = None

Expand Down Expand Up @@ -441,9 +463,6 @@ async def __async_register_task_definition(self) -> None:
async def run_once(self) -> None:
"""Execute one iteration of the polling loop (async version)."""
try:
# Send lease extension heartbeats for any tasks that are due
await self._send_due_heartbeats()

# No need for manual cleanup - tasks remove themselves via add_done_callback
# Just check capacity directly
current_capacity = len(self._running_tasks)
Expand Down Expand Up @@ -932,68 +951,27 @@ async def __async_update_task(self, task_result: TaskResult):

return None

# -- Lease extension (heartbeat) methods ----------------------------------
# -- Lease extension (heartbeat) delegation to LeaseManager ----------------

def _track_lease(self, task) -> None:
"""Start tracking a task for lease extension heartbeat."""
"""Start tracking a task for lease extension via the shared LeaseManager."""
if not getattr(self.worker, 'lease_extend_enabled', False):
return
timeout = getattr(task, 'response_timeout_seconds', None) or 0
if timeout <= 0:
return
interval = timeout * LEASE_EXTEND_DURATION_FACTOR
if interval < 1:
return
self._lease_info[task.task_id] = LeaseInfo(
self._lease_manager.track(
task_id=task.task_id,
workflow_instance_id=task.workflow_instance_id,
response_timeout_seconds=timeout,
last_heartbeat_time=time.monotonic(),
interval_seconds=interval,
)
logger.debug(
"Tracking lease for task %s (timeout=%ss, heartbeat every %ss)",
task.task_id, timeout, interval,
task_client=self._sync_task_client,
)
self._tracked_task_ids.add(task.task_id)

def _untrack_lease(self, task_id: str) -> None:
"""Stop tracking a task for lease extension."""
removed = self._lease_info.pop(task_id, None)
if removed is not None:
logger.debug("Untracked lease for task %s", task_id)

async def _send_due_heartbeats(self) -> None:
"""Check all tracked tasks and send heartbeats for any that are due."""
if not self._lease_info:
return
now = time.monotonic()
for info in list(self._lease_info.values()):
elapsed = now - info.last_heartbeat_time
if elapsed < info.interval_seconds:
continue
await self._send_heartbeat(info)
info.last_heartbeat_time = time.monotonic()

async def _send_heartbeat(self, info: LeaseInfo) -> None:
"""Send a single lease extension heartbeat with retry (async)."""
result = TaskResult(
task_id=info.task_id,
workflow_instance_id=info.workflow_instance_id,
extend_lease=True,
)
for attempt in range(LEASE_EXTEND_RETRY_COUNT):
try:
await self.async_task_client.update_task(body=result)
logger.debug("Extended lease for task %s", info.task_id)
return
except Exception as e:
if attempt < LEASE_EXTEND_RETRY_COUNT - 1:
await asyncio.sleep(0.5 * (attempt + 2))
else:
logger.error(
"Failed to extend lease for task %s after %d attempts: %s",
info.task_id, LEASE_EXTEND_RETRY_COUNT, e,
)
self._lease_manager.untrack(task_id)
self._tracked_task_ids.discard(task_id)

# --------------------------------------------------------------------------

Expand Down
209 changes: 208 additions & 1 deletion src/conductor/client/automator/lease_tracker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,27 @@
"""Shared lease extension (heartbeat) tracking for TaskRunner and AsyncTaskRunner."""
"""Centralized lease extension (heartbeat) management for Conductor task runners.

Architecture:
LeaseManager runs a single background daemon thread that periodically checks
for tasks needing lease extension heartbeats. Due heartbeats are dispatched
to a small fixed ThreadPoolExecutor for parallel, non-blocking API calls.

This decouples heartbeat work entirely from worker poll loops, preventing
heartbeat API calls (and their retries) from blocking task polling.

Thread-safe: track() and untrack() can be called from any thread or event loop.
"""

import logging
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Any, Dict, Optional

from conductor.client.http.models.task_result import TaskResult

logger = logging.getLogger(__name__)

# Lease extension constants (matches Java SDK)
LEASE_EXTEND_RETRY_COUNT = 3
Expand All @@ -15,3 +36,189 @@ class LeaseInfo:
response_timeout_seconds: float
last_heartbeat_time: float # time.monotonic() of last heartbeat (or task start)
interval_seconds: float # 80% of responseTimeoutSeconds
task_client: Any = None # Sync TaskResourceApi for sending heartbeats


class LeaseManager:
    """Centralized lease extension manager for all workers in a process.

    One background daemon thread checks for due heartbeats at a fixed interval.
    A small ThreadPoolExecutor sends heartbeat API calls in parallel.
    Poll loops are never blocked by heartbeat work.

    Usage:
        manager = LeaseManager.get_instance()
        manager.track(task_id, workflow_id, timeout, task_client)
        # ... task completes ...
        manager.untrack(task_id)
    """

    _instance: Optional['LeaseManager'] = None
    _instance_lock = threading.Lock()
    _instance_pid: Optional[int] = None  # PID that created _instance, for fork detection

    @classmethod
    def get_instance(cls, check_interval: float = 1.0,
                     max_heartbeat_workers: int = 4) -> 'LeaseManager':
        """Get or create the process-wide LeaseManager singleton.

        Fork-safe: a new instance is created after fork (threads don't survive fork).

        Args:
            check_interval: Seconds between due-heartbeat scans (used only when
                the singleton is first created).
            max_heartbeat_workers: Size of the heartbeat thread pool (used only
                when the singleton is first created).
        """
        current_pid = os.getpid()
        # Double-checked locking: cheap unlocked read first, re-check under the lock.
        if cls._instance is None or cls._instance_pid != current_pid:
            with cls._instance_lock:
                if cls._instance is None or cls._instance_pid != current_pid:
                    cls._instance = cls(
                        check_interval=check_interval,
                        max_heartbeat_workers=max_heartbeat_workers,
                    )
                    cls._instance_pid = current_pid
        return cls._instance

    @classmethod
    def _reset_instance(cls):
        """Reset the singleton. For testing only."""
        with cls._instance_lock:
            if cls._instance is not None:
                cls._instance.shutdown()
            cls._instance = None
            cls._instance_pid = None

    def __init__(self, check_interval: float = 1.0, max_heartbeat_workers: int = 4):
        """Initialize an idle manager; the background thread starts lazily on first track()."""
        self._tracked: Dict[str, LeaseInfo] = {}
        self._lock = threading.Lock()  # guards _tracked
        self._executor = ThreadPoolExecutor(
            max_workers=max_heartbeat_workers,
            thread_name_prefix="lease-heartbeat",
        )
        self._stop_event = threading.Event()
        self._check_interval = check_interval
        self._thread: Optional[threading.Thread] = None
        self._started = False
        self._start_lock = threading.Lock()  # guards _started/_thread startup

    def _ensure_started(self) -> None:
        """Lazily start the background thread on first track() call."""
        if self._started:
            return
        with self._start_lock:
            if not self._started:
                # Daemon thread: never blocks interpreter exit.
                self._thread = threading.Thread(
                    target=self._run, daemon=True, name="lease-manager",
                )
                self._thread.start()
                self._started = True
                logger.debug(
                    "LeaseManager started (check_interval=%.1fs)", self._check_interval,
                )

    def track(self, task_id: str, workflow_instance_id: str,
              response_timeout_seconds: float, task_client: Any) -> None:
        """Start tracking a task for lease extension heartbeats.

        Thread-safe. Can be called from any worker thread or event loop.

        Args:
            task_id: Conductor task ID.
            workflow_instance_id: Workflow instance this task belongs to.
            response_timeout_seconds: The task's server-side response timeout.
                Tasks with no timeout (None or <= 0) are not tracked.
            task_client: A **sync** TaskResourceApi for sending heartbeat API calls.
        """
        # Guard: task definitions without a response timeout (None or <= 0) need
        # no lease extension; without this check None * factor raises TypeError.
        if not response_timeout_seconds or response_timeout_seconds <= 0:
            logger.debug(
                "Skipping lease tracking for task %s (no response timeout)", task_id,
            )
            return

        interval = response_timeout_seconds * LEASE_EXTEND_DURATION_FACTOR
        if interval < 1:
            logger.debug(
                "Skipping lease tracking for task %s (interval %.1fs too short)",
                task_id, interval,
            )
            return

        info = LeaseInfo(
            task_id=task_id,
            workflow_instance_id=workflow_instance_id,
            response_timeout_seconds=response_timeout_seconds,
            last_heartbeat_time=time.monotonic(),
            interval_seconds=interval,
            task_client=task_client,
        )
        with self._lock:
            self._tracked[task_id] = info
        self._ensure_started()
        logger.debug(
            "Tracking lease for task %s (timeout=%ss, heartbeat every %ss)",
            task_id, response_timeout_seconds, interval,
        )

    def untrack(self, task_id: str) -> None:
        """Stop tracking a task. Thread-safe; a no-op for unknown task IDs."""
        with self._lock:
            removed = self._tracked.pop(task_id, None)
        if removed is not None:
            logger.debug("Untracked lease for task %s", task_id)

    @property
    def tracked_count(self) -> int:
        """Number of currently tracked tasks."""
        with self._lock:
            return len(self._tracked)

    # -- Background thread -----------------------------------------------------

    def _run(self) -> None:
        """Background loop — checks for due heartbeats at fixed intervals."""
        while not self._stop_event.is_set():
            try:
                self._check_and_send()
            except Exception as e:
                # Keep the heartbeat loop alive no matter what; log and continue.
                logger.error("LeaseManager error: %s", e)
            self._stop_event.wait(self._check_interval)

    def _check_and_send(self) -> None:
        """Find tasks with due heartbeats and dispatch to the thread pool."""
        now = time.monotonic()
        with self._lock:
            due = [
                info for info in self._tracked.values()
                if now - info.last_heartbeat_time >= info.interval_seconds
            ]
        for info in due:
            # Update timestamp immediately to prevent double-dispatch on next tick
            info.last_heartbeat_time = time.monotonic()
            self._executor.submit(self._send_heartbeat, info)

    @staticmethod
    def _send_heartbeat(info: LeaseInfo) -> None:
        """Send a single lease extension heartbeat with retry.

        Runs in a pool thread — blocking retries only block the pool thread,
        never a poll loop.
        """
        result = TaskResult(
            task_id=info.task_id,
            workflow_instance_id=info.workflow_instance_id,
            extend_lease=True,
        )
        for attempt in range(LEASE_EXTEND_RETRY_COUNT):
            try:
                info.task_client.update_task(body=result)
                logger.debug("Extended lease for task %s", info.task_id)
                return
            except Exception as e:
                if attempt < LEASE_EXTEND_RETRY_COUNT - 1:
                    # Linear backoff: 1.0s, 1.5s, ... before the next retry.
                    time.sleep(0.5 * (attempt + 2))
                else:
                    logger.error(
                        "Failed to extend lease for task %s after %d attempts: %s",
                        info.task_id, LEASE_EXTEND_RETRY_COUNT, e,
                    )

    # -- Lifecycle -------------------------------------------------------------

    def shutdown(self) -> None:
        """Stop the background thread and thread pool. Pending heartbeats are abandoned."""
        self._stop_event.set()
        if self._started and self._thread is not None:
            self._thread.join(timeout=5)
        self._executor.shutdown(wait=False)
        with self._lock:
            self._tracked.clear()
        logger.debug("LeaseManager shut down")
Loading
Loading