Skip to content
Draft

Hooks #371

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions src/aws_durable_execution_sdk_python/hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import datetime
from abc import ABC
from dataclasses import dataclass

from aws_durable_execution_sdk_python.execution import InvocationStatus
from aws_durable_execution_sdk_python.lambda_service import (
OperationType,
OperationStatus,
OperationAction,
OperationSubType,
ErrorObject,
)


@dataclass
class OperationStartInfo:
operation_id: str
operation_type: OperationType
sub_type: OperationSubType | None = None
name: str | None = None
parent_id: str | None = None
start_timestamp: datetime.datetime | None = None


@dataclass
class OperationEndInfo(OperationStartInfo):
status: OperationStatus = OperationStatus.SUCCEEDED
end_timestamp: datetime.datetime | None = None
attempt: int = 1
error: ErrorObject | None = None


@dataclass
class AttemptStartInfo(OperationStartInfo):
attempt: int = 1


@dataclass
class AttemptEndInfo(AttemptStartInfo):
outcome: OperationAction = OperationAction.SUCCEED
error: ErrorObject | None = None
next_attempt_delay_seconds: int | None = None


@dataclass
class InvocationStartInfo:
request_id: str
execution_arn: str
start_time: datetime.datetime


@dataclass
class InvocationEndInfo(InvocationStartInfo):
status: InvocationStatus = InvocationStatus.SUCCEEDED
error: ErrorObject | None = None


@dataclass
class ExecutionStartInfo(InvocationStartInfo):
pass


@dataclass
class ExecutionEndInfo(InvocationEndInfo):
pass


class DurableExecutionPlugin(ABC):
"""Base class for plugins. Override only the methods you need."""

def on_execution_start(self, info: ExecutionStartInfo) -> None:
pass

def on_execution_end(self, info: ExecutionEndInfo) -> None:
pass

def on_invocation_start(self, info: InvocationStartInfo) -> None:
pass

def on_invocation_end(self, info: InvocationEndInfo) -> None:
pass

def on_operation_start(self, info: OperationStartInfo) -> None:
pass

def on_operation_end(self, info: OperationEndInfo) -> None:
pass

def on_operation_attempt_start(self, info: AttemptStartInfo) -> None:
pass

def on_operation_attempt_end(self, info: AttemptEndInfo) -> None:
pass

# Todo: further discussions required to finalize the following interface
# def enrich_log_context(self, info: OperationStartInfo | None) -> Dict[str, Any] | None: pass
243 changes: 243 additions & 0 deletions tests/hook_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
import datetime
import unittest

from aws_durable_execution_sdk_python.hook import (
AttemptEndInfo,
AttemptStartInfo,
DurableExecutionPlugin,
ExecutionEndInfo,
ExecutionStartInfo,
InvocationEndInfo,
InvocationStartInfo,
OperationEndInfo,
OperationStartInfo,
)
from aws_durable_execution_sdk_python.execution import InvocationStatus
from aws_durable_execution_sdk_python.lambda_service import (
ErrorObject,
OperationAction,
OperationStatus,
OperationSubType,
OperationType,
)


class TestOperationStartInfo(unittest.TestCase):
def test_required_fields(self):
info = OperationStartInfo(
operation_id="op-1", operation_type=OperationType.STEP
)
self.assertEqual(info.operation_id, "op-1")
self.assertEqual(info.operation_type, OperationType.STEP)
self.assertIsNone(info.sub_type)
self.assertIsNone(info.name)
self.assertIsNone(info.parent_id)
self.assertIsNone(info.start_timestamp)

def test_all_fields(self):
ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC)
info = OperationStartInfo(
operation_id="op-2",
operation_type=OperationType.CALLBACK,
sub_type=OperationSubType.CALLBACK,
name="my-op",
parent_id="parent-1",
start_timestamp=ts,
)
self.assertEqual(info.sub_type, OperationSubType.CALLBACK)
self.assertEqual(info.name, "my-op")
self.assertEqual(info.parent_id, "parent-1")
self.assertEqual(info.start_timestamp, ts)


class TestOperationEndInfo(unittest.TestCase):
def test_inherits_operation_start_info(self):
self.assertTrue(issubclass(OperationEndInfo, OperationStartInfo))

def test_defaults(self):
info = OperationEndInfo(
operation_id="op-1", operation_type=OperationType.STEP
)
self.assertEqual(info.status, OperationStatus.SUCCEEDED)
self.assertIsNone(info.end_timestamp)
self.assertEqual(info.attempt, 1)
self.assertIsNone(info.error)

def test_with_error(self):
err = ErrorObject(message="fail", type="RuntimeError", data=None, stack_trace=None)
info = OperationEndInfo(
operation_id="op-1",
operation_type=OperationType.STEP,
status=OperationStatus.FAILED,
error=err,
attempt=3,
)
self.assertEqual(info.status, OperationStatus.FAILED)
self.assertEqual(info.attempt, 3)
self.assertEqual(info.error.message, "fail")


class TestAttemptStartInfo(unittest.TestCase):
def test_inherits_operation_start_info(self):
self.assertTrue(issubclass(AttemptStartInfo, OperationStartInfo))

def test_default_attempt(self):
info = AttemptStartInfo(
operation_id="op-1", operation_type=OperationType.STEP
)
self.assertEqual(info.attempt, 1)

def test_custom_attempt(self):
info = AttemptStartInfo(
operation_id="op-1", operation_type=OperationType.STEP, attempt=5
)
self.assertEqual(info.attempt, 5)


class TestAttemptEndInfo(unittest.TestCase):
def test_inherits_attempt_start_info(self):
self.assertTrue(issubclass(AttemptEndInfo, AttemptStartInfo))

def test_defaults(self):
info = AttemptEndInfo(
operation_id="op-1", operation_type=OperationType.STEP
)
self.assertEqual(info.outcome, OperationAction.SUCCEED)
self.assertIsNone(info.error)
self.assertIsNone(info.next_attempt_delay_seconds)

def test_retry_with_delay(self):
err = ErrorObject(message="timeout", type="TimeoutError", data=None, stack_trace=None)
info = AttemptEndInfo(
operation_id="op-1",
operation_type=OperationType.STEP,
outcome=OperationAction.RETRY,
error=err,
next_attempt_delay_seconds=30,
)
self.assertEqual(info.outcome, OperationAction.RETRY)
self.assertEqual(info.next_attempt_delay_seconds, 30)
self.assertEqual(info.error.type, "TimeoutError")


class TestInvocationStartInfo(unittest.TestCase):
def test_fields(self):
ts = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC)
info = InvocationStartInfo(
request_id="req-1", execution_arn="arn:aws:lambda:us-east-1:123:durable:abc", start_time=ts
)
self.assertEqual(info.request_id, "req-1")
self.assertEqual(info.execution_arn, "arn:aws:lambda:us-east-1:123:durable:abc")
self.assertEqual(info.start_time, ts)


class TestInvocationEndInfo(unittest.TestCase):
def test_inherits_invocation_start_info(self):
self.assertTrue(issubclass(InvocationEndInfo, InvocationStartInfo))

def test_defaults(self):
ts = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC)
info = InvocationEndInfo(
request_id="req-1", execution_arn="arn:test", start_time=ts
)
self.assertEqual(info.status, InvocationStatus.SUCCEEDED)
self.assertIsNone(info.error)

def test_failed(self):
ts = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC)
err = ErrorObject(message="boom", type="Error", data=None, stack_trace=None)
info = InvocationEndInfo(
request_id="req-1",
execution_arn="arn:test",
start_time=ts,
status=InvocationStatus.FAILED,
error=err,
)
self.assertEqual(info.status, InvocationStatus.FAILED)
self.assertEqual(info.error.message, "boom")


class TestExecutionStartInfo(unittest.TestCase):
def test_inherits_invocation_start_info(self):
self.assertTrue(issubclass(ExecutionStartInfo, InvocationStartInfo))

def test_construction(self):
ts = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC)
info = ExecutionStartInfo(
request_id="req-1", execution_arn="arn:test", start_time=ts
)
self.assertEqual(info.request_id, "req-1")


class TestExecutionEndInfo(unittest.TestCase):
def test_inherits_invocation_end_info(self):
self.assertTrue(issubclass(ExecutionEndInfo, InvocationEndInfo))

def test_defaults(self):
ts = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC)
info = ExecutionEndInfo(
request_id="req-1", execution_arn="arn:test", start_time=ts
)
self.assertEqual(info.status, InvocationStatus.SUCCEEDED)
self.assertIsNone(info.error)


class TestDurableExecutionPlugin(unittest.TestCase):
def test_default_methods_are_noop(self):
"""All default hook methods should be callable and return None."""
plugin = _NoOpPlugin()
ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC)

exec_start = ExecutionStartInfo(request_id="r", execution_arn="a", start_time=ts)
exec_end = ExecutionEndInfo(request_id="r", execution_arn="a", start_time=ts)
inv_start = InvocationStartInfo(request_id="r", execution_arn="a", start_time=ts)
inv_end = InvocationEndInfo(request_id="r", execution_arn="a", start_time=ts)
op_start = OperationStartInfo(operation_id="o", operation_type=OperationType.STEP)
op_end = OperationEndInfo(operation_id="o", operation_type=OperationType.STEP)
att_start = AttemptStartInfo(operation_id="o", operation_type=OperationType.STEP)
att_end = AttemptEndInfo(operation_id="o", operation_type=OperationType.STEP)

self.assertIsNone(plugin.on_execution_start(exec_start))
self.assertIsNone(plugin.on_execution_end(exec_end))
self.assertIsNone(plugin.on_invocation_start(inv_start))
self.assertIsNone(plugin.on_invocation_end(inv_end))
self.assertIsNone(plugin.on_operation_start(op_start))
self.assertIsNone(plugin.on_operation_end(op_end))
self.assertIsNone(plugin.on_operation_attempt_start(att_start))
self.assertIsNone(plugin.on_operation_attempt_end(att_end))

def test_subclass_override(self):
"""A subclass can override specific hooks."""
plugin = _TrackingPlugin()
ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC)

plugin.on_execution_start(ExecutionStartInfo(request_id="r", execution_arn="a", start_time=ts))
plugin.on_operation_start(OperationStartInfo(operation_id="o", operation_type=OperationType.WAIT))

self.assertEqual(plugin.calls, ["execution_start:r", "operation_start:o"])

def test_cannot_instantiate_abc_directly(self):
"""DurableExecutionPlugin is abstract but has no abstract methods, so it can be instantiated via a subclass."""
self.assertTrue(issubclass(DurableExecutionPlugin, object))


class _NoOpPlugin(DurableExecutionPlugin):
"""Concrete subclass that inherits all default no-op methods."""
pass


class _TrackingPlugin(DurableExecutionPlugin):
"""Concrete subclass that tracks calls to specific hooks."""

def __init__(self):
self.calls: list[str] = []

def on_execution_start(self, info: ExecutionStartInfo) -> None:
self.calls.append(f"execution_start:{info.request_id}")

def on_operation_start(self, info: OperationStartInfo) -> None:
self.calls.append(f"operation_start:{info.operation_id}")


if __name__ == "__main__":
unittest.main()
Loading