diff --git a/Cargo.lock b/Cargo.lock index 4fc1f75..16ae782 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -806,8 +806,7 @@ dependencies = [ [[package]] name = "restate-sdk-shared-core" version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f0417331c92f9ef4a14dcbd40759cb4008221a3538affda05f907923c574119" +source = "git+https://github.com/restatedev/sdk-shared-core.git?branch=awaiting_on#4b8fb2b900d5a14acc65218463c5c7578a16cabb" dependencies = [ "base64", "bs58", diff --git a/Cargo.toml b/Cargo.toml index 7e4b467..db0de3b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,4 +14,4 @@ doc = false [dependencies] pyo3 = { version = "0.25.1", features = ["extension-module"] } tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] } -restate-sdk-shared-core = { version = "=0.9.0", features = ["request_identity", "sha2_random_seed"] } +restate-sdk-shared-core = { git = "https://github.com/restatedev/sdk-shared-core.git", branch = "awaiting_on", features = ["request_identity", "sha2_random_seed"] } diff --git a/python/restate/asyncio.py b/python/restate/asyncio.py index 7c8d18f..8fa5bf5 100644 --- a/python/restate/asyncio.py +++ b/python/restate/asyncio.py @@ -16,6 +16,11 @@ from restate.exceptions import TerminalError from restate.context import RestateDurableFuture from restate.server_context import ServerDurableFuture, ServerInvocationContext +from restate.vm import ( + AllCompletedUnresolvedFuture, + FirstCompletedUnresolvedFuture, + SingleUnresolvedFuture, +) async def gather(*futures: RestateDurableFuture[Any]) -> List[RestateDurableFuture[Any]]: @@ -24,9 +29,28 @@ async def gather(*futures: RestateDurableFuture[Any]) -> List[RestateDurableFutu Returns a list of all futures. """ - async for _ in as_completed(*futures): - pass - return list(futures) + context: ServerInvocationContext | None = None + handles: List[int] = [] + futures_list = list(futures) + + if not futures_list: + return [] + for f in futures_list: + if not isinstance(f, ServerDurableFuture): + raise TerminalError("All futures must SDK created futures.") + if context is None: + context = f.context + elif context is not f.context: + raise TerminalError("All futures must be created by the same SDK context.") + if not f.is_completed(): + handles.append(f.handle) + + if handles: + assert context is not None + await context.create_poll_or_cancel_coroutine( + AllCompletedUnresolvedFuture([SingleUnresolvedFuture(h) for h in handles]) + ) + return futures_list async def select(**kws: RestateDurableFuture[Any]) -> List[Any]: @@ -118,7 +142,9 @@ async def wait_completed( completed = [] uncompleted = [] assert context is not None - await context.create_poll_or_cancel_coroutine(handles) + await context.create_poll_or_cancel_coroutine( + FirstCompletedUnresolvedFuture([SingleUnresolvedFuture(h) for h in handles]) + ) for index, handle in enumerate(handles): future = futures[index] diff --git a/python/restate/discovery.py b/python/restate/discovery.py index 2c90c65..b6a0d95 100644 --- a/python/restate/discovery.py +++ b/python/restate/discovery.py @@ -429,4 +429,4 @@ def compute_discovery(endpoint: RestateEndpoint, discovered_as: typing.Literal[" protocol_mode = PROTOCOL_MODES[endpoint.protocol] else: protocol_mode = PROTOCOL_MODES[discovered_as] - return Endpoint(protocolMode=protocol_mode, minProtocolVersion=5, maxProtocolVersion=6, services=services) + return Endpoint(protocolMode=protocol_mode, minProtocolVersion=5, maxProtocolVersion=7, services=services) diff --git a/python/restate/server_context.py b/python/restate/server_context.py index a7de250..721b29b 100644 --- a/python/restate/server_context.py +++ b/python/restate/server_context.py @@ -61,9 +61,10 @@ from restate.vm import ( DoProgressAnyCompleted, DoProgressCancelSignalReceived, - DoProgressReadFromInput, + DoProgressWaitExternalProgress, DoProgressExecuteRun, - DoWaitPendingRun, + SingleUnresolvedFuture, + UnresolvedFuture, ) logger = logging.getLogger(__name__) @@ -193,7 +194,7 @@ def __init__(self, context: "ServerInvocationContext", handle: int) -> None: async def coro() -> str: if not context.vm.is_completed(handle): - await context.create_poll_or_cancel_coroutine([handle]) + await context.create_poll_or_cancel_coroutine(SingleUnresolvedFuture(handle)) invocation_id = await context.must_take_notification(handle) return typing.cast(str, invocation_id) @@ -235,7 +236,7 @@ def resolve(self, value: Any) -> Awaitable[None]: async def await_point(): if not self.server_context.vm.is_completed(handle): - await self.server_context.create_poll_or_cancel_coroutine([handle]) + await self.server_context.create_poll_or_cancel_coroutine(SingleUnresolvedFuture(handle)) await self.server_context.must_take_notification(handle) return ServerDurableFuture(self.server_context, handle, await_point) @@ -248,7 +249,7 @@ def reject(self, message: str, code: int = 500) -> Awaitable[None]: async def await_point(): if not self.server_context.vm.is_completed(handle): - await self.server_context.create_poll_or_cancel_coroutine([handle]) + await self.server_context.create_poll_or_cancel_coroutine(SingleUnresolvedFuture(handle)) await self.server_context.must_take_notification(handle) return ServerDurableFuture(self.server_context, handle, await_point) @@ -527,11 +528,11 @@ async def must_take_notification(self, handle): raise TerminalError(res.message, res.code) return res - async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> None: - """Create a coroutine to poll the handle.""" + async def create_poll_or_cancel_coroutine(self, unresolved_future: UnresolvedFuture) -> None: + """Create a coroutine to poll the unresolved future.""" while True: await self.take_and_send_output() - do_progress_response = self.vm.do_progress(handles) + do_progress_response = self.vm.do_progress(unresolved_future) if isinstance(do_progress_response, BaseException): logger.exception("Exception in do_progress", exc_info=do_progress_response) raise SdkInternalException() from do_progress_response @@ -556,7 +557,7 @@ async def wrapper(f): task = asyncio.create_task(wrapper(fn)) self.tasks.add(task) continue - if isinstance(do_progress_response, (DoWaitPendingRun, DoProgressReadFromInput)): + if isinstance(do_progress_response, DoProgressWaitExternalProgress): chunk = await self.receive() if chunk.get("type") == "restate.run_completed": continue @@ -574,7 +575,7 @@ def _create_fetch_result_coroutine(self, handle: int, serde: Serde[T] | None = N async def fetch_result(): if not self.vm.is_completed(handle): - await self.create_poll_or_cancel_coroutine([handle]) + await self.create_poll_or_cancel_coroutine(SingleUnresolvedFuture(handle)) res = await self.must_take_notification(handle) if res is None or serde is None: return res @@ -593,7 +594,7 @@ def create_sleep_future(self, handle: int) -> ServerDurableSleepFuture: async def transform(): if not self.vm.is_completed(handle): - await self.create_poll_or_cancel_coroutine([handle]) + await self.create_poll_or_cancel_coroutine(SingleUnresolvedFuture(handle)) await self.must_take_notification(handle) return ServerDurableSleepFuture(self, handle, transform) @@ -605,7 +606,7 @@ def create_call_future( async def inv_id_factory(): if not self.vm.is_completed(invocation_id_handle): - await self.create_poll_or_cancel_coroutine([invocation_id_handle]) + await self.create_poll_or_cancel_coroutine(SingleUnresolvedFuture(invocation_id_handle)) return await self.must_take_notification(invocation_id_handle) return ServerCallDurableFuture(self, handle, self._create_fetch_result_coroutine(handle, serde), inv_id_factory) diff --git a/python/restate/vm.py b/python/restate/vm.py index 3250a0a..2efaa8b 100644 --- a/python/restate/vm.py +++ b/python/restate/vm.py @@ -15,10 +15,10 @@ # pylint: disable=E1101,R0917 # pylint: disable=too-many-arguments # pylint: disable=too-few-public-methods -from typing import Optional +from typing import List, Optional, Union from datetime import timedelta -from dataclasses import dataclass +from dataclasses import dataclass, field import typing from restate._internal import ( PyVM, @@ -30,10 +30,10 @@ PyStateKeys, PyExponentialRetryConfig, PyDoProgressAnyCompleted, - PyDoProgressReadFromInput, + PyDoProgressWaitExternalProgress, PyDoProgressExecuteRun, - PyDoWaitForPendingRun, PyDoProgressCancelSignalReceived, + PyUnresolvedFuture, CANCEL_NOTIFICATION_HANDLE, ) # pylint: disable=import-error,no-name-in-module,line-too-long @@ -105,9 +105,10 @@ class DoProgressAnyCompleted: """ -class DoProgressReadFromInput: +class DoProgressWaitExternalProgress: """ - Represents a notification that the input needs to be read. + Represents a notification that external progress is required + (either new input from the server or a pending run proposal). """ @@ -128,26 +129,57 @@ class DoProgressCancelSignalReceived: """ -class DoWaitPendingRun: - """ - Represents a notification that a run is pending - """ - - DO_PROGRESS_ANY_COMPLETED = DoProgressAnyCompleted() -DO_PROGRESS_READ_FROM_INPUT = DoProgressReadFromInput() +DO_PROGRESS_WAIT_EXTERNAL_PROGRESS = DoProgressWaitExternalProgress() DO_PROGRESS_CANCEL_SIGNAL_RECEIVED = DoProgressCancelSignalReceived() -DO_WAIT_PENDING_RUN = DoWaitPendingRun() DoProgressResult = typing.Union[ DoProgressAnyCompleted, - DoProgressReadFromInput, + DoProgressWaitExternalProgress, DoProgressExecuteRun, DoProgressCancelSignalReceived, - DoWaitPendingRun, ] +@dataclass(frozen=True) +class SingleUnresolvedFuture: + """A single leaf handle.""" + + handle: int + + +@dataclass(frozen=True) +class FirstCompletedUnresolvedFuture: + """first child to complete (success or failure) wins.""" + + children: List["UnresolvedFuture"] = field(default_factory=list) + + +@dataclass(frozen=True) +class AllCompletedUnresolvedFuture: + """wait for all children to complete.""" + + children: List["UnresolvedFuture"] = field(default_factory=list) + + +UnresolvedFuture = Union[ + SingleUnresolvedFuture, + FirstCompletedUnresolvedFuture, + AllCompletedUnresolvedFuture, +] + + +def _unresolved_future_to_pyo3(uf: UnresolvedFuture) -> PyUnresolvedFuture: + """Recursively convert a Python-side UnresolvedFuture dataclass to its PyO3 pyclass.""" + if isinstance(uf, SingleUnresolvedFuture): + return PyUnresolvedFuture.single(uf.handle) + if isinstance(uf, FirstCompletedUnresolvedFuture): + return PyUnresolvedFuture.first_completed([_unresolved_future_to_pyo3(c) for c in uf.children]) + if isinstance(uf, AllCompletedUnresolvedFuture): + return PyUnresolvedFuture.all_completed([_unresolved_future_to_pyo3(c) for c in uf.children]) + raise TypeError(f"Unknown UnresolvedFuture variant: {type(uf).__name__}") + + # pylint: disable=too-many-public-methods class VMWrapper: """ @@ -195,24 +227,22 @@ def is_completed(self, handle: int) -> bool: return self.vm.is_completed(handle) # pylint: disable=R0911 - def do_progress(self, handles: list[int]) -> typing.Union[DoProgressResult, Exception, Suspended]: + def do_progress(self, unresolved_future: UnresolvedFuture) -> typing.Union[DoProgressResult, Exception, Suspended]: """Do progress with notifications.""" try: - result = self.vm.do_progress(handles) + result = self.vm.do_progress(_unresolved_future_to_pyo3(unresolved_future)) except VMException as e: return e if isinstance(result, PySuspended): return SUSPENDED if isinstance(result, PyDoProgressAnyCompleted): return DO_PROGRESS_ANY_COMPLETED - if isinstance(result, PyDoProgressReadFromInput): - return DO_PROGRESS_READ_FROM_INPUT + if isinstance(result, PyDoProgressWaitExternalProgress): + return DO_PROGRESS_WAIT_EXTERNAL_PROGRESS if isinstance(result, PyDoProgressExecuteRun): return DoProgressExecuteRun(result.handle) if isinstance(result, PyDoProgressCancelSignalReceived): return DO_PROGRESS_CANCEL_SIGNAL_RECEIVED - if isinstance(result, PyDoWaitForPendingRun): - return DO_WAIT_PENDING_RUN return ValueError(f"Unknown progress type: {result}") def take_notification(self, handle: int) -> typing.Union[NotificationType, Exception, Suspended]: @@ -343,9 +373,8 @@ def sys_call( headers: typing.Optional[typing.List[typing.Tuple[str, str]]] = None, ): """Call a service""" - if headers: - headers = [PyHeader(key=h[0], value=h[1]) for h in headers] - return self.vm.sys_call(service, handler, parameter, key, idempotency_key, headers) + py_headers = [PyHeader(key=h[0], value=h[1]) for h in headers] if headers else None + return self.vm.sys_call(service, handler, parameter, key, idempotency_key, py_headers) # pylint: disable=too-many-arguments def sys_send( @@ -362,9 +391,8 @@ def sys_send( send an invocation to a service, and return the handle to the promise that will resolve with the invocation id """ - if headers: - headers = [PyHeader(key=h[0], value=h[1]) for h in headers] - return self.vm.sys_send(service, handler, parameter, key, delay, idempotency_key, headers) + py_headers = [PyHeader(key=h[0], value=h[1]) for h in headers] if headers else None + return self.vm.sys_send(service, handler, parameter, key, delay, idempotency_key, py_headers) def sys_run(self, name: str) -> int: """ @@ -391,17 +419,11 @@ def sys_reject_awakeable(self, name: str, failure: Failure): py_failure = PyFailure(failure.code, failure.message) self.vm.sys_complete_awakeable_failure(name, py_failure) - def propose_run_completion_success(self, handle: int, output: bytes) -> int: + def propose_run_completion_success(self, handle: int, output: bytes) -> None: """ - Exit a side effect - - Args: - output: The output of the side effect. - - Returns: - handle + Exit a side effect with a success value. """ - return self.vm.propose_run_completion_success(handle, output) + self.vm.propose_run_completion_success(handle, output) def sys_get_promise(self, name: str) -> int: """Returns the promise handle""" @@ -420,16 +442,12 @@ def sys_complete_promise_failure(self, name: str, failure: Failure) -> int: res = PyFailure(failure.code, failure.message) return self.vm.sys_complete_promise_failure(name, res) - def propose_run_completion_failure(self, handle: int, output: Failure) -> int: + def propose_run_completion_failure(self, handle: int, output: Failure) -> None: """ - Exit a side effect - - Args: - name: The name of the side effect. - output: The output of the side effect. + Exit a side effect with a terminal failure. """ res = PyFailure(output.code, output.message) - return self.vm.propose_run_completion_failure(handle, res) + self.vm.propose_run_completion_failure(handle, res) # pylint: disable=line-too-long def propose_run_completion_transient( diff --git a/src/lib.rs b/src/lib.rs index 78b4dd4..e582919 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,9 +3,9 @@ use pyo3::prelude::*; use pyo3::types::{PyBytes, PyNone, PyString}; use restate_sdk_shared_core::fmt::{set_error_formatter, ErrorFormatter}; use restate_sdk_shared_core::{ - CallHandle, CoreVM, DoProgressResponse, Error, Header, IdentityVerifier, Input, NonEmptyValue, + AwaitResponse, CallHandle, CoreVM, Error, Header, IdentityVerifier, Input, NonEmptyValue, NotificationHandle, ResponseHead, RetryPolicy, RunExitResult, TakeOutputResult, Target, - TerminalFailure, VMOptions, Value, CANCEL_NOTIFICATION_HANDLE, VM, + TerminalFailure, UnresolvedFuture, VMOptions, Value, CANCEL_NOTIFICATION_HANDLE, VM, }; use std::fmt; use std::time::{Duration, SystemTime}; @@ -245,7 +245,7 @@ impl From for PyInput { } #[pyclass] -struct PyDoProgressReadFromInput; +struct PyDoProgressWaitExternalProgress; #[pyclass] struct PyDoProgressAnyCompleted; @@ -260,7 +260,38 @@ struct PyDoProgressExecuteRun { struct PyDoProgressCancelSignalReceived; #[pyclass] -struct PyDoWaitForPendingRun; +#[derive(Clone)] +pub struct PyUnresolvedFuture { + inner: UnresolvedFuture, +} + +#[pymethods] +impl PyUnresolvedFuture { + #[staticmethod] + fn single(handle: PyNotificationHandle) -> Self { + PyUnresolvedFuture { + inner: UnresolvedFuture::Single(NotificationHandle::from(handle)), + } + } + + #[staticmethod] + fn first_completed(children: Vec>) -> Self { + PyUnresolvedFuture { + inner: UnresolvedFuture::FirstCompleted( + children.into_iter().map(|c| c.inner.clone()).collect(), + ), + } + } + + #[staticmethod] + fn all_completed(children: Vec>) -> Self { + PyUnresolvedFuture { + inner: UnresolvedFuture::AllCompleted( + children.into_iter().map(|c| c.inner.clone()).collect(), + ), + } + } +} #[pyclass] pub struct PyCallHandle { @@ -362,41 +393,33 @@ impl PyVM { self_.vm.is_completed(handle.into()) } - fn do_progress( - mut self_: PyRefMut<'_, Self>, - any_handle: Vec, - ) -> PyResult> { - let res = self_.vm.do_progress( - any_handle - .into_iter() - .map(NotificationHandle::from) - .collect(), - ); + fn do_progress<'py>( + mut self_: PyRefMut<'py, Self>, + unresolved_future: PyRef<'py, PyUnresolvedFuture>, + ) -> PyResult> { + let res = self_.vm.do_await(unresolved_future.inner.clone()); let py = self_.py(); match res { Err(e) if e.is_suspended_error() => Ok(Bound::new(py, PySuspended)?.into_any()), Err(e) => Err(PyVMError::from(e))?, - Ok(DoProgressResponse::AnyCompleted) => { + Ok(AwaitResponse::AnyCompleted) => { Ok(Bound::new(py, PyDoProgressAnyCompleted)?.into_any()) } - Ok(DoProgressResponse::ReadFromInput) => { - Ok(Bound::new(py, PyDoProgressReadFromInput)?.into_any()) + Ok(AwaitResponse::WaitingExternalProgress { .. }) => { + Ok(Bound::new(py, PyDoProgressWaitExternalProgress)?.into_any()) } - Ok(DoProgressResponse::ExecuteRun(handle)) => Ok(Bound::new( + Ok(AwaitResponse::ExecuteRun(handle)) => Ok(Bound::new( py, PyDoProgressExecuteRun { handle: handle.into(), }, )? .into_any()), - Ok(DoProgressResponse::CancelSignalReceived) => { + Ok(AwaitResponse::CancelSignalReceived) => { Ok(Bound::new(py, PyDoProgressCancelSignalReceived)?.into_any()) } - Ok(DoProgressResponse::WaitingPendingRun) => { - Ok(Bound::new(py, PyDoWaitForPendingRun)?.into_any()) - } } } @@ -890,10 +913,10 @@ fn _internal(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; - m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; - m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add("VMException", m.py().get_type::())?; diff --git a/tests/disconnect_hotloop.py b/tests/disconnect_hotloop.py index f81a724..46a7b57 100644 --- a/tests/disconnect_hotloop.py +++ b/tests/disconnect_hotloop.py @@ -105,17 +105,17 @@ async def mock_receive() -> ASGIReceiveEvent: async def test_empty_body_frames_do_not_cause_hotloop(): """ - When the VM returns DoProgressReadFromInput and the chunk has body=b'', + When the VM returns DoProgressWaitExternalProgress and the chunk has body=b'', notify_input should NOT be called (it would cause a tight loop). The loop should exit via DisconnectedException when http.disconnect arrives. """ from restate.server_context import ServerInvocationContext, DisconnectedException - from restate.vm import DoProgressReadFromInput + from restate.vm import DoProgressWaitExternalProgress, SingleUnresolvedFuture # Build a minimal mock context vm = MagicMock() vm.take_output.return_value = None - vm.do_progress.return_value = DoProgressReadFromInput() + vm.do_progress.return_value = DoProgressWaitExternalProgress() handler = MagicMock() invocation = MagicMock() @@ -149,7 +149,7 @@ async def mock_receive() -> ASGIReceiveEvent: try: with pytest.raises(DisconnectedException): await asyncio.wait_for( - ctx.create_poll_or_cancel_coroutine([0]), + ctx.create_poll_or_cancel_coroutine(SingleUnresolvedFuture(0)), timeout=2.0, )