diff --git a/src/aws_durable_execution_sdk_python/hook.py b/src/aws_durable_execution_sdk_python/hook.py new file mode 100644 index 0000000..9fde017 --- /dev/null +++ b/src/aws_durable_execution_sdk_python/hook.py @@ -0,0 +1,96 @@ +import datetime +from abc import ABC +from dataclasses import dataclass + +from aws_durable_execution_sdk_python.execution import InvocationStatus +from aws_durable_execution_sdk_python.lambda_service import ( + OperationType, + OperationStatus, + OperationAction, + OperationSubType, + ErrorObject, +) + + +@dataclass +class OperationStartInfo: + operation_id: str + operation_type: OperationType + sub_type: OperationSubType | None = None + name: str | None = None + parent_id: str | None = None + start_timestamp: datetime.datetime | None = None + + +@dataclass +class OperationEndInfo(OperationStartInfo): + status: OperationStatus = OperationStatus.SUCCEEDED + end_timestamp: datetime.datetime | None = None + attempt: int = 1 + error: ErrorObject | None = None + + +@dataclass +class AttemptStartInfo(OperationStartInfo): + attempt: int = 1 + + +@dataclass +class AttemptEndInfo(AttemptStartInfo): + outcome: OperationAction = OperationAction.SUCCEED + error: ErrorObject | None = None + next_attempt_delay_seconds: int | None = None + + +@dataclass +class InvocationStartInfo: + request_id: str + execution_arn: str + start_time: datetime.datetime + + +@dataclass +class InvocationEndInfo(InvocationStartInfo): + status: InvocationStatus = InvocationStatus.SUCCEEDED + error: ErrorObject | None = None + + +@dataclass +class ExecutionStartInfo(InvocationStartInfo): + pass + + +@dataclass +class ExecutionEndInfo(InvocationEndInfo): + pass + + +class DurableExecutionPlugin(ABC): + """Base class for plugins. Override only the methods you need.""" + + def on_execution_start(self, info: ExecutionStartInfo) -> None: + pass + + def on_execution_end(self, info: ExecutionEndInfo) -> None: + pass + + def on_invocation_start(self, info: InvocationStartInfo) -> None: + pass + + def on_invocation_end(self, info: InvocationEndInfo) -> None: + pass + + def on_operation_start(self, info: OperationStartInfo) -> None: + pass + + def on_operation_end(self, info: OperationEndInfo) -> None: + pass + + def on_operation_attempt_start(self, info: AttemptStartInfo) -> None: + pass + + def on_operation_attempt_end(self, info: AttemptEndInfo) -> None: + pass + + # Todo: further discussions required to finalize the following interface + # def enrich_log_context(self, info: OperationStartInfo | None) -> Dict[str, Any] | None: pass diff --git a/tests/hook_test.py b/tests/hook_test.py new file mode 100644 index 0000000..edacff4 --- /dev/null +++ b/tests/hook_test.py @@ -0,0 +1,243 @@ +import datetime +import unittest + +from aws_durable_execution_sdk_python.hook import ( + AttemptEndInfo, + AttemptStartInfo, + DurableExecutionPlugin, + ExecutionEndInfo, + ExecutionStartInfo, + InvocationEndInfo, + InvocationStartInfo, + OperationEndInfo, + OperationStartInfo, +) +from aws_durable_execution_sdk_python.execution import InvocationStatus +from aws_durable_execution_sdk_python.lambda_service import ( + ErrorObject, + OperationAction, + OperationStatus, + OperationSubType, + OperationType, +) + + +class TestOperationStartInfo(unittest.TestCase): + def test_required_fields(self): + info = OperationStartInfo( + operation_id="op-1", operation_type=OperationType.STEP + ) + self.assertEqual(info.operation_id, "op-1") + self.assertEqual(info.operation_type, OperationType.STEP) + self.assertIsNone(info.sub_type) + self.assertIsNone(info.name) + self.assertIsNone(info.parent_id) + self.assertIsNone(info.start_timestamp) + + def test_all_fields(self): + ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) + info = OperationStartInfo( + operation_id="op-2", + operation_type=OperationType.CALLBACK, + sub_type=OperationSubType.CALLBACK, + name="my-op", + parent_id="parent-1", + start_timestamp=ts, + ) + self.assertEqual(info.sub_type, OperationSubType.CALLBACK) + self.assertEqual(info.name, "my-op") + self.assertEqual(info.parent_id, "parent-1") + self.assertEqual(info.start_timestamp, ts) + + +class TestOperationEndInfo(unittest.TestCase): + def test_inherits_operation_start_info(self): + self.assertTrue(issubclass(OperationEndInfo, OperationStartInfo)) + + def test_defaults(self): + info = OperationEndInfo( + operation_id="op-1", operation_type=OperationType.STEP + ) + self.assertEqual(info.status, OperationStatus.SUCCEEDED) + self.assertIsNone(info.end_timestamp) + self.assertEqual(info.attempt, 1) + self.assertIsNone(info.error) + + def test_with_error(self): + err = ErrorObject(message="fail", type="RuntimeError", data=None, stack_trace=None) + info = OperationEndInfo( + operation_id="op-1", + operation_type=OperationType.STEP, + status=OperationStatus.FAILED, + error=err, + attempt=3, + ) + self.assertEqual(info.status, OperationStatus.FAILED) + self.assertEqual(info.attempt, 3) + self.assertEqual(info.error.message, "fail") + + +class TestAttemptStartInfo(unittest.TestCase): + def test_inherits_operation_start_info(self): + self.assertTrue(issubclass(AttemptStartInfo, OperationStartInfo)) + + def test_default_attempt(self): + info = AttemptStartInfo( + operation_id="op-1", operation_type=OperationType.STEP + ) + self.assertEqual(info.attempt, 1) + + def test_custom_attempt(self): + info = AttemptStartInfo( + operation_id="op-1", operation_type=OperationType.STEP, attempt=5 + ) + self.assertEqual(info.attempt, 5) + + +class TestAttemptEndInfo(unittest.TestCase): + def test_inherits_attempt_start_info(self): + self.assertTrue(issubclass(AttemptEndInfo, AttemptStartInfo)) + + def test_defaults(self): + info = AttemptEndInfo( + operation_id="op-1", operation_type=OperationType.STEP + ) + self.assertEqual(info.outcome, OperationAction.SUCCEED) + self.assertIsNone(info.error) + self.assertIsNone(info.next_attempt_delay_seconds) + + def test_retry_with_delay(self): + err = ErrorObject(message="timeout", type="TimeoutError", data=None, stack_trace=None) + info = AttemptEndInfo( + operation_id="op-1", + operation_type=OperationType.STEP, + outcome=OperationAction.RETRY, + error=err, + next_attempt_delay_seconds=30, + ) + self.assertEqual(info.outcome, OperationAction.RETRY) + self.assertEqual(info.next_attempt_delay_seconds, 30) + self.assertEqual(info.error.type, "TimeoutError") + + +class TestInvocationStartInfo(unittest.TestCase): + def test_fields(self): + ts = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC) + info = InvocationStartInfo( + request_id="req-1", execution_arn="arn:aws:lambda:us-east-1:123:durable:abc", start_time=ts + ) + self.assertEqual(info.request_id, "req-1") + self.assertEqual(info.execution_arn, "arn:aws:lambda:us-east-1:123:durable:abc") + self.assertEqual(info.start_time, ts) + + +class TestInvocationEndInfo(unittest.TestCase): + def test_inherits_invocation_start_info(self): + self.assertTrue(issubclass(InvocationEndInfo, InvocationStartInfo)) + + def test_defaults(self): + ts = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC) + info = InvocationEndInfo( + request_id="req-1", execution_arn="arn:test", start_time=ts + ) + self.assertEqual(info.status, InvocationStatus.SUCCEEDED) + self.assertIsNone(info.error) + + def test_failed(self): + ts = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC) + err = ErrorObject(message="boom", type="Error", data=None, stack_trace=None) + info = InvocationEndInfo( + request_id="req-1", + execution_arn="arn:test", + start_time=ts, + status=InvocationStatus.FAILED, + error=err, + ) + self.assertEqual(info.status, InvocationStatus.FAILED) + self.assertEqual(info.error.message, "boom") + + +class TestExecutionStartInfo(unittest.TestCase): + def test_inherits_invocation_start_info(self): + self.assertTrue(issubclass(ExecutionStartInfo, InvocationStartInfo)) + + def test_construction(self): + ts = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC) + info = ExecutionStartInfo( + request_id="req-1", execution_arn="arn:test", start_time=ts + ) + self.assertEqual(info.request_id, "req-1") + + +class TestExecutionEndInfo(unittest.TestCase): + def test_inherits_invocation_end_info(self): + self.assertTrue(issubclass(ExecutionEndInfo, InvocationEndInfo)) + + def test_defaults(self): + ts = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC) + info = ExecutionEndInfo( + request_id="req-1", execution_arn="arn:test", start_time=ts + ) + self.assertEqual(info.status, InvocationStatus.SUCCEEDED) + self.assertIsNone(info.error) + + +class TestDurableExecutionPlugin(unittest.TestCase): + def test_default_methods_are_noop(self): + """All default hook methods should be callable and return None.""" + plugin = _NoOpPlugin() + ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) + + exec_start = ExecutionStartInfo(request_id="r", execution_arn="a", start_time=ts) + exec_end = ExecutionEndInfo(request_id="r", execution_arn="a", start_time=ts) + inv_start = InvocationStartInfo(request_id="r", execution_arn="a", start_time=ts) + inv_end = InvocationEndInfo(request_id="r", execution_arn="a", start_time=ts) + op_start = OperationStartInfo(operation_id="o", operation_type=OperationType.STEP) + op_end = OperationEndInfo(operation_id="o", operation_type=OperationType.STEP) + att_start = AttemptStartInfo(operation_id="o", operation_type=OperationType.STEP) + att_end = AttemptEndInfo(operation_id="o", operation_type=OperationType.STEP) + + self.assertIsNone(plugin.on_execution_start(exec_start)) + self.assertIsNone(plugin.on_execution_end(exec_end)) + self.assertIsNone(plugin.on_invocation_start(inv_start)) + self.assertIsNone(plugin.on_invocation_end(inv_end)) + self.assertIsNone(plugin.on_operation_start(op_start)) + self.assertIsNone(plugin.on_operation_end(op_end)) + self.assertIsNone(plugin.on_operation_attempt_start(att_start)) + self.assertIsNone(plugin.on_operation_attempt_end(att_end)) + + def test_subclass_override(self): + """A subclass can override specific hooks.""" + plugin = _TrackingPlugin() + ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) + + plugin.on_execution_start(ExecutionStartInfo(request_id="r", execution_arn="a", start_time=ts)) + plugin.on_operation_start(OperationStartInfo(operation_id="o", operation_type=OperationType.WAIT)) + + self.assertEqual(plugin.calls, ["execution_start:r", "operation_start:o"]) + + def test_cannot_instantiate_abc_directly(self): + """DurableExecutionPlugin is abstract but has no abstract methods, so it can be instantiated via a subclass.""" + self.assertTrue(issubclass(DurableExecutionPlugin, object)) + + +class _NoOpPlugin(DurableExecutionPlugin): + """Concrete subclass that inherits all default no-op methods.""" + pass + + +class _TrackingPlugin(DurableExecutionPlugin): + """Concrete subclass that tracks calls to specific hooks.""" + + def __init__(self): + self.calls: list[str] = [] + + def on_execution_start(self, info: ExecutionStartInfo) -> None: + self.calls.append(f"execution_start:{info.request_id}") + + def on_operation_start(self, info: OperationStartInfo) -> None: + self.calls.append(f"operation_start:{info.operation_id}") + + +if __name__ == "__main__": + unittest.main()