diff --git a/Cargo.lock b/Cargo.lock
index 4fc1f75..16ae782 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -806,8 +806,7 @@ dependencies = [
[[package]]
name = "restate-sdk-shared-core"
version = "0.9.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5f0417331c92f9ef4a14dcbd40759cb4008221a3538affda05f907923c574119"
+source = "git+https://github.com/restatedev/sdk-shared-core.git?branch=awaiting_on#4b8fb2b900d5a14acc65218463c5c7578a16cabb"
dependencies = [
"base64",
"bs58",
diff --git a/Cargo.toml b/Cargo.toml
index 7e4b467..db0de3b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -14,4 +14,4 @@ doc = false
[dependencies]
pyo3 = { version = "0.25.1", features = ["extension-module"] }
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }
-restate-sdk-shared-core = { version = "=0.9.0", features = ["request_identity", "sha2_random_seed"] }
+restate-sdk-shared-core = { git = "https://github.com/restatedev/sdk-shared-core.git", branch = "awaiting_on", features = ["request_identity", "sha2_random_seed"] }
diff --git a/python/restate/asyncio.py b/python/restate/asyncio.py
index 7c8d18f..8fa5bf5 100644
--- a/python/restate/asyncio.py
+++ b/python/restate/asyncio.py
@@ -16,6 +16,11 @@
from restate.exceptions import TerminalError
from restate.context import RestateDurableFuture
from restate.server_context import ServerDurableFuture, ServerInvocationContext
+from restate.vm import (
+ AllCompletedUnresolvedFuture,
+ FirstCompletedUnresolvedFuture,
+ SingleUnresolvedFuture,
+)
async def gather(*futures: RestateDurableFuture[Any]) -> List[RestateDurableFuture[Any]]:
@@ -24,9 +29,28 @@ async def gather(*futures: RestateDurableFuture[Any]) -> List[RestateDurableFutu
Returns a list of all futures.
"""
- async for _ in as_completed(*futures):
- pass
- return list(futures)
+ context: ServerInvocationContext | None = None
+ handles: List[int] = []
+ futures_list = list(futures)
+
+ if not futures_list:
+ return []
+ for f in futures_list:
+ if not isinstance(f, ServerDurableFuture):
+ raise TerminalError("All futures must SDK created futures.")
+ if context is None:
+ context = f.context
+ elif context is not f.context:
+ raise TerminalError("All futures must be created by the same SDK context.")
+ if not f.is_completed():
+ handles.append(f.handle)
+
+ if handles:
+ assert context is not None
+ await context.create_poll_or_cancel_coroutine(
+ AllCompletedUnresolvedFuture([SingleUnresolvedFuture(h) for h in handles])
+ )
+ return futures_list
async def select(**kws: RestateDurableFuture[Any]) -> List[Any]:
@@ -118,7 +142,9 @@ async def wait_completed(
completed = []
uncompleted = []
assert context is not None
- await context.create_poll_or_cancel_coroutine(handles)
+ await context.create_poll_or_cancel_coroutine(
+ FirstCompletedUnresolvedFuture([SingleUnresolvedFuture(h) for h in handles])
+ )
for index, handle in enumerate(handles):
future = futures[index]
diff --git a/python/restate/discovery.py b/python/restate/discovery.py
index 2c90c65..b6a0d95 100644
--- a/python/restate/discovery.py
+++ b/python/restate/discovery.py
@@ -429,4 +429,4 @@ def compute_discovery(endpoint: RestateEndpoint, discovered_as: typing.Literal["
protocol_mode = PROTOCOL_MODES[endpoint.protocol]
else:
protocol_mode = PROTOCOL_MODES[discovered_as]
- return Endpoint(protocolMode=protocol_mode, minProtocolVersion=5, maxProtocolVersion=6, services=services)
+ return Endpoint(protocolMode=protocol_mode, minProtocolVersion=5, maxProtocolVersion=7, services=services)
diff --git a/python/restate/server_context.py b/python/restate/server_context.py
index a7de250..721b29b 100644
--- a/python/restate/server_context.py
+++ b/python/restate/server_context.py
@@ -61,9 +61,10 @@
from restate.vm import (
DoProgressAnyCompleted,
DoProgressCancelSignalReceived,
- DoProgressReadFromInput,
+ DoProgressWaitExternalProgress,
DoProgressExecuteRun,
- DoWaitPendingRun,
+ SingleUnresolvedFuture,
+ UnresolvedFuture,
)
logger = logging.getLogger(__name__)
@@ -193,7 +194,7 @@ def __init__(self, context: "ServerInvocationContext", handle: int) -> None:
async def coro() -> str:
if not context.vm.is_completed(handle):
- await context.create_poll_or_cancel_coroutine([handle])
+ await context.create_poll_or_cancel_coroutine(SingleUnresolvedFuture(handle))
invocation_id = await context.must_take_notification(handle)
return typing.cast(str, invocation_id)
@@ -235,7 +236,7 @@ def resolve(self, value: Any) -> Awaitable[None]:
async def await_point():
if not self.server_context.vm.is_completed(handle):
- await self.server_context.create_poll_or_cancel_coroutine([handle])
+ await self.server_context.create_poll_or_cancel_coroutine(SingleUnresolvedFuture(handle))
await self.server_context.must_take_notification(handle)
return ServerDurableFuture(self.server_context, handle, await_point)
@@ -248,7 +249,7 @@ def reject(self, message: str, code: int = 500) -> Awaitable[None]:
async def await_point():
if not self.server_context.vm.is_completed(handle):
- await self.server_context.create_poll_or_cancel_coroutine([handle])
+ await self.server_context.create_poll_or_cancel_coroutine(SingleUnresolvedFuture(handle))
await self.server_context.must_take_notification(handle)
return ServerDurableFuture(self.server_context, handle, await_point)
@@ -527,11 +528,11 @@ async def must_take_notification(self, handle):
raise TerminalError(res.message, res.code)
return res
- async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> None:
- """Create a coroutine to poll the handle."""
+ async def create_poll_or_cancel_coroutine(self, unresolved_future: UnresolvedFuture) -> None:
+ """Create a coroutine to poll the unresolved future."""
while True:
await self.take_and_send_output()
- do_progress_response = self.vm.do_progress(handles)
+ do_progress_response = self.vm.do_progress(unresolved_future)
if isinstance(do_progress_response, BaseException):
logger.exception("Exception in do_progress", exc_info=do_progress_response)
raise SdkInternalException() from do_progress_response
@@ -556,7 +557,7 @@ async def wrapper(f):
task = asyncio.create_task(wrapper(fn))
self.tasks.add(task)
continue
- if isinstance(do_progress_response, (DoWaitPendingRun, DoProgressReadFromInput)):
+ if isinstance(do_progress_response, DoProgressWaitExternalProgress):
chunk = await self.receive()
if chunk.get("type") == "restate.run_completed":
continue
@@ -574,7 +575,7 @@ def _create_fetch_result_coroutine(self, handle: int, serde: Serde[T] | None = N
async def fetch_result():
if not self.vm.is_completed(handle):
- await self.create_poll_or_cancel_coroutine([handle])
+ await self.create_poll_or_cancel_coroutine(SingleUnresolvedFuture(handle))
res = await self.must_take_notification(handle)
if res is None or serde is None:
return res
@@ -593,7 +594,7 @@ def create_sleep_future(self, handle: int) -> ServerDurableSleepFuture:
async def transform():
if not self.vm.is_completed(handle):
- await self.create_poll_or_cancel_coroutine([handle])
+ await self.create_poll_or_cancel_coroutine(SingleUnresolvedFuture(handle))
await self.must_take_notification(handle)
return ServerDurableSleepFuture(self, handle, transform)
@@ -605,7 +606,7 @@ def create_call_future(
async def inv_id_factory():
if not self.vm.is_completed(invocation_id_handle):
- await self.create_poll_or_cancel_coroutine([invocation_id_handle])
+ await self.create_poll_or_cancel_coroutine(SingleUnresolvedFuture(invocation_id_handle))
return await self.must_take_notification(invocation_id_handle)
return ServerCallDurableFuture(self, handle, self._create_fetch_result_coroutine(handle, serde), inv_id_factory)
diff --git a/python/restate/vm.py b/python/restate/vm.py
index 3250a0a..2efaa8b 100644
--- a/python/restate/vm.py
+++ b/python/restate/vm.py
@@ -15,10 +15,10 @@
# pylint: disable=E1101,R0917
# pylint: disable=too-many-arguments
# pylint: disable=too-few-public-methods
-from typing import Optional
+from typing import List, Optional, Union
from datetime import timedelta
-from dataclasses import dataclass
+from dataclasses import dataclass, field
import typing
from restate._internal import (
PyVM,
@@ -30,10 +30,10 @@
PyStateKeys,
PyExponentialRetryConfig,
PyDoProgressAnyCompleted,
- PyDoProgressReadFromInput,
+ PyDoProgressWaitExternalProgress,
PyDoProgressExecuteRun,
- PyDoWaitForPendingRun,
PyDoProgressCancelSignalReceived,
+ PyUnresolvedFuture,
CANCEL_NOTIFICATION_HANDLE,
) # pylint: disable=import-error,no-name-in-module,line-too-long
@@ -105,9 +105,10 @@ class DoProgressAnyCompleted:
"""
-class DoProgressReadFromInput:
+class DoProgressWaitExternalProgress:
"""
- Represents a notification that the input needs to be read.
+ Represents a notification that external progress is required
+ (either new input from the server or a pending run proposal).
"""
@@ -128,26 +129,57 @@ class DoProgressCancelSignalReceived:
"""
-class DoWaitPendingRun:
- """
- Represents a notification that a run is pending
- """
-
-
DO_PROGRESS_ANY_COMPLETED = DoProgressAnyCompleted()
-DO_PROGRESS_READ_FROM_INPUT = DoProgressReadFromInput()
+DO_PROGRESS_WAIT_EXTERNAL_PROGRESS = DoProgressWaitExternalProgress()
DO_PROGRESS_CANCEL_SIGNAL_RECEIVED = DoProgressCancelSignalReceived()
-DO_WAIT_PENDING_RUN = DoWaitPendingRun()
DoProgressResult = typing.Union[
DoProgressAnyCompleted,
- DoProgressReadFromInput,
+ DoProgressWaitExternalProgress,
DoProgressExecuteRun,
DoProgressCancelSignalReceived,
- DoWaitPendingRun,
]
+@dataclass(frozen=True)
+class SingleUnresolvedFuture:
+ """A single leaf handle."""
+
+ handle: int
+
+
+@dataclass(frozen=True)
+class FirstCompletedUnresolvedFuture:
+ """first child to complete (success or failure) wins."""
+
+ children: List["UnresolvedFuture"] = field(default_factory=list)
+
+
+@dataclass(frozen=True)
+class AllCompletedUnresolvedFuture:
+ """wait for all children to complete."""
+
+ children: List["UnresolvedFuture"] = field(default_factory=list)
+
+
+UnresolvedFuture = Union[
+ SingleUnresolvedFuture,
+ FirstCompletedUnresolvedFuture,
+ AllCompletedUnresolvedFuture,
+]
+
+
+def _unresolved_future_to_pyo3(uf: UnresolvedFuture) -> PyUnresolvedFuture:
+ """Recursively convert a Python-side UnresolvedFuture dataclass to its PyO3 pyclass."""
+ if isinstance(uf, SingleUnresolvedFuture):
+ return PyUnresolvedFuture.single(uf.handle)
+ if isinstance(uf, FirstCompletedUnresolvedFuture):
+ return PyUnresolvedFuture.first_completed([_unresolved_future_to_pyo3(c) for c in uf.children])
+ if isinstance(uf, AllCompletedUnresolvedFuture):
+ return PyUnresolvedFuture.all_completed([_unresolved_future_to_pyo3(c) for c in uf.children])
+ raise TypeError(f"Unknown UnresolvedFuture variant: {type(uf).__name__}")
+
+
# pylint: disable=too-many-public-methods
class VMWrapper:
"""
@@ -195,24 +227,22 @@ def is_completed(self, handle: int) -> bool:
return self.vm.is_completed(handle)
# pylint: disable=R0911
- def do_progress(self, handles: list[int]) -> typing.Union[DoProgressResult, Exception, Suspended]:
+ def do_progress(self, unresolved_future: UnresolvedFuture) -> typing.Union[DoProgressResult, Exception, Suspended]:
"""Do progress with notifications."""
try:
- result = self.vm.do_progress(handles)
+ result = self.vm.do_progress(_unresolved_future_to_pyo3(unresolved_future))
except VMException as e:
return e
if isinstance(result, PySuspended):
return SUSPENDED
if isinstance(result, PyDoProgressAnyCompleted):
return DO_PROGRESS_ANY_COMPLETED
- if isinstance(result, PyDoProgressReadFromInput):
- return DO_PROGRESS_READ_FROM_INPUT
+ if isinstance(result, PyDoProgressWaitExternalProgress):
+ return DO_PROGRESS_WAIT_EXTERNAL_PROGRESS
if isinstance(result, PyDoProgressExecuteRun):
return DoProgressExecuteRun(result.handle)
if isinstance(result, PyDoProgressCancelSignalReceived):
return DO_PROGRESS_CANCEL_SIGNAL_RECEIVED
- if isinstance(result, PyDoWaitForPendingRun):
- return DO_WAIT_PENDING_RUN
return ValueError(f"Unknown progress type: {result}")
def take_notification(self, handle: int) -> typing.Union[NotificationType, Exception, Suspended]:
@@ -343,9 +373,8 @@ def sys_call(
headers: typing.Optional[typing.List[typing.Tuple[str, str]]] = None,
):
"""Call a service"""
- if headers:
- headers = [PyHeader(key=h[0], value=h[1]) for h in headers]
- return self.vm.sys_call(service, handler, parameter, key, idempotency_key, headers)
+ py_headers = [PyHeader(key=h[0], value=h[1]) for h in headers] if headers else None
+ return self.vm.sys_call(service, handler, parameter, key, idempotency_key, py_headers)
# pylint: disable=too-many-arguments
def sys_send(
@@ -362,9 +391,8 @@ def sys_send(
send an invocation to a service, and return the handle
to the promise that will resolve with the invocation id
"""
- if headers:
- headers = [PyHeader(key=h[0], value=h[1]) for h in headers]
- return self.vm.sys_send(service, handler, parameter, key, delay, idempotency_key, headers)
+ py_headers = [PyHeader(key=h[0], value=h[1]) for h in headers] if headers else None
+ return self.vm.sys_send(service, handler, parameter, key, delay, idempotency_key, py_headers)
def sys_run(self, name: str) -> int:
"""
@@ -391,17 +419,11 @@ def sys_reject_awakeable(self, name: str, failure: Failure):
py_failure = PyFailure(failure.code, failure.message)
self.vm.sys_complete_awakeable_failure(name, py_failure)
- def propose_run_completion_success(self, handle: int, output: bytes) -> int:
+ def propose_run_completion_success(self, handle: int, output: bytes) -> None:
"""
- Exit a side effect
-
- Args:
- output: The output of the side effect.
-
- Returns:
- handle
+ Exit a side effect with a success value.
"""
- return self.vm.propose_run_completion_success(handle, output)
+ self.vm.propose_run_completion_success(handle, output)
def sys_get_promise(self, name: str) -> int:
"""Returns the promise handle"""
@@ -420,16 +442,12 @@ def sys_complete_promise_failure(self, name: str, failure: Failure) -> int:
res = PyFailure(failure.code, failure.message)
return self.vm.sys_complete_promise_failure(name, res)
- def propose_run_completion_failure(self, handle: int, output: Failure) -> int:
+ def propose_run_completion_failure(self, handle: int, output: Failure) -> None:
"""
- Exit a side effect
-
- Args:
- name: The name of the side effect.
- output: The output of the side effect.
+ Exit a side effect with a terminal failure.
"""
res = PyFailure(output.code, output.message)
- return self.vm.propose_run_completion_failure(handle, res)
+ self.vm.propose_run_completion_failure(handle, res)
# pylint: disable=line-too-long
def propose_run_completion_transient(
diff --git a/src/lib.rs b/src/lib.rs
index 78b4dd4..e582919 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -3,9 +3,9 @@ use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyNone, PyString};
use restate_sdk_shared_core::fmt::{set_error_formatter, ErrorFormatter};
use restate_sdk_shared_core::{
- CallHandle, CoreVM, DoProgressResponse, Error, Header, IdentityVerifier, Input, NonEmptyValue,
+ AwaitResponse, CallHandle, CoreVM, Error, Header, IdentityVerifier, Input, NonEmptyValue,
NotificationHandle, ResponseHead, RetryPolicy, RunExitResult, TakeOutputResult, Target,
- TerminalFailure, VMOptions, Value, CANCEL_NOTIFICATION_HANDLE, VM,
+ TerminalFailure, UnresolvedFuture, VMOptions, Value, CANCEL_NOTIFICATION_HANDLE, VM,
};
use std::fmt;
use std::time::{Duration, SystemTime};
@@ -245,7 +245,7 @@ impl From for PyInput {
}
#[pyclass]
-struct PyDoProgressReadFromInput;
+struct PyDoProgressWaitExternalProgress;
#[pyclass]
struct PyDoProgressAnyCompleted;
@@ -260,7 +260,38 @@ struct PyDoProgressExecuteRun {
struct PyDoProgressCancelSignalReceived;
#[pyclass]
-struct PyDoWaitForPendingRun;
+#[derive(Clone)]
+pub struct PyUnresolvedFuture {
+ inner: UnresolvedFuture,
+}
+
+#[pymethods]
+impl PyUnresolvedFuture {
+ #[staticmethod]
+ fn single(handle: PyNotificationHandle) -> Self {
+ PyUnresolvedFuture {
+ inner: UnresolvedFuture::Single(NotificationHandle::from(handle)),
+ }
+ }
+
+ #[staticmethod]
+ fn first_completed(children: Vec>) -> Self {
+ PyUnresolvedFuture {
+ inner: UnresolvedFuture::FirstCompleted(
+ children.into_iter().map(|c| c.inner.clone()).collect(),
+ ),
+ }
+ }
+
+ #[staticmethod]
+ fn all_completed(children: Vec>) -> Self {
+ PyUnresolvedFuture {
+ inner: UnresolvedFuture::AllCompleted(
+ children.into_iter().map(|c| c.inner.clone()).collect(),
+ ),
+ }
+ }
+}
#[pyclass]
pub struct PyCallHandle {
@@ -362,41 +393,33 @@ impl PyVM {
self_.vm.is_completed(handle.into())
}
- fn do_progress(
- mut self_: PyRefMut<'_, Self>,
- any_handle: Vec,
- ) -> PyResult> {
- let res = self_.vm.do_progress(
- any_handle
- .into_iter()
- .map(NotificationHandle::from)
- .collect(),
- );
+ fn do_progress<'py>(
+ mut self_: PyRefMut<'py, Self>,
+ unresolved_future: PyRef<'py, PyUnresolvedFuture>,
+ ) -> PyResult> {
+ let res = self_.vm.do_await(unresolved_future.inner.clone());
let py = self_.py();
match res {
Err(e) if e.is_suspended_error() => Ok(Bound::new(py, PySuspended)?.into_any()),
Err(e) => Err(PyVMError::from(e))?,
- Ok(DoProgressResponse::AnyCompleted) => {
+ Ok(AwaitResponse::AnyCompleted) => {
Ok(Bound::new(py, PyDoProgressAnyCompleted)?.into_any())
}
- Ok(DoProgressResponse::ReadFromInput) => {
- Ok(Bound::new(py, PyDoProgressReadFromInput)?.into_any())
+ Ok(AwaitResponse::WaitingExternalProgress { .. }) => {
+ Ok(Bound::new(py, PyDoProgressWaitExternalProgress)?.into_any())
}
- Ok(DoProgressResponse::ExecuteRun(handle)) => Ok(Bound::new(
+ Ok(AwaitResponse::ExecuteRun(handle)) => Ok(Bound::new(
py,
PyDoProgressExecuteRun {
handle: handle.into(),
},
)?
.into_any()),
- Ok(DoProgressResponse::CancelSignalReceived) => {
+ Ok(AwaitResponse::CancelSignalReceived) => {
Ok(Bound::new(py, PyDoProgressCancelSignalReceived)?.into_any())
}
- Ok(DoProgressResponse::WaitingPendingRun) => {
- Ok(Bound::new(py, PyDoWaitForPendingRun)?.into_any())
- }
}
}
@@ -890,10 +913,10 @@ fn _internal(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::()?;
m.add_class::()?;
m.add_class::()?;
- m.add_class::()?;
+ m.add_class::()?;
m.add_class::()?;
m.add_class::()?;
- m.add_class::()?;
+ m.add_class::()?;
m.add_class::()?;
m.add("VMException", m.py().get_type::())?;
diff --git a/tests/disconnect_hotloop.py b/tests/disconnect_hotloop.py
index f81a724..46a7b57 100644
--- a/tests/disconnect_hotloop.py
+++ b/tests/disconnect_hotloop.py
@@ -105,17 +105,17 @@ async def mock_receive() -> ASGIReceiveEvent:
async def test_empty_body_frames_do_not_cause_hotloop():
"""
- When the VM returns DoProgressReadFromInput and the chunk has body=b'',
+ When the VM returns DoProgressWaitExternalProgress and the chunk has body=b'',
notify_input should NOT be called (it would cause a tight loop).
The loop should exit via DisconnectedException when http.disconnect arrives.
"""
from restate.server_context import ServerInvocationContext, DisconnectedException
- from restate.vm import DoProgressReadFromInput
+ from restate.vm import DoProgressWaitExternalProgress, SingleUnresolvedFuture
# Build a minimal mock context
vm = MagicMock()
vm.take_output.return_value = None
- vm.do_progress.return_value = DoProgressReadFromInput()
+ vm.do_progress.return_value = DoProgressWaitExternalProgress()
handler = MagicMock()
invocation = MagicMock()
@@ -149,7 +149,7 @@ async def mock_receive() -> ASGIReceiveEvent:
try:
with pytest.raises(DisconnectedException):
await asyncio.wait_for(
- ctx.create_poll_or_cancel_coroutine([0]),
+ ctx.create_poll_or_cancel_coroutine(SingleUnresolvedFuture(0)),
timeout=2.0,
)