-
Notifications
You must be signed in to change notification settings - Fork 180
feat: add workflow connector integration #499
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
78f246a
chore: add wf connector integration
pevers 61e1cd2
wip
pevers 2f28100
wip
pevers f2785f3
wip
pevers c09b000
wip
pevers 316a0a8
wip
pevers 8e61de9
fix: remove opiniated client interaction
pevers ce8d908
wip
pevers File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,262 @@ | ||
| """Helper for executing workflows that require connector OAuth authentication. | ||
|
|
||
| When a workflow uses connectors that need OAuth, it emits ``connector-auth`` | ||
| custom task events. This module provides a high-level async function that | ||
| automates the handshake: | ||
|
|
||
| 1. Start the workflow execution. | ||
| 2. Stream events, watching for ``connector-auth`` custom task events. | ||
| 3. When a ``waiting_for_auth`` event arrives, invoke a user-supplied callback. | ||
| 4. The interceptor polls for credentials server-side and resumes automatically. | ||
| 5. Return the final execution result once the workflow completes. | ||
|
|
||
| Example:: | ||
|
|
||
| from mistralai import Mistral | ||
| from mistralai.extra.workflows import ( | ||
| ConnectorAuthTaskState, | ||
| ConnectorSlot, | ||
| execute_with_connector_auth_async, | ||
| ) | ||
|
|
||
| async def prompt_user(state: ConnectorAuthTaskState) -> None: | ||
| print(f"Please authenticate: {state.auth_url}") | ||
|
|
||
| gmail = ConnectorSlot(connector_name="gmail") | ||
|
|
||
| client = Mistral(api_key="...") | ||
| result = await execute_with_connector_auth_async( | ||
| client, | ||
| workflow_identifier="my-workflow", | ||
| input_data={"query": "summarize my emails"}, | ||
| on_auth_required=prompt_user, | ||
| connectors=[gmail], | ||
| ) | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| import logging | ||
| from typing import ( | ||
| TYPE_CHECKING, | ||
| Any, | ||
| Awaitable, | ||
| Callable, | ||
| Dict, | ||
| Optional, | ||
| Sequence, | ||
| ) | ||
|
|
||
| import httpx | ||
| from pydantic import BaseModel | ||
|
|
||
| from mistralai.client.models import ( | ||
| CustomTaskStartedResponse, | ||
| WorkflowExecutionCanceledResponse, | ||
| WorkflowExecutionCompletedResponse, | ||
| WorkflowExecutionFailedResponse, | ||
| WorkflowExecutionResponse, | ||
| ) | ||
|
|
||
| from .connector_slot import ConnectorSlot, WorkflowExtensions | ||
|
|
||
| if TYPE_CHECKING: | ||
| from mistralai.client.sdk import Mistral | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| _TERMINAL_EVENT_TYPES = ( | ||
| WorkflowExecutionCompletedResponse, | ||
| WorkflowExecutionFailedResponse, | ||
| WorkflowExecutionCanceledResponse, | ||
| ) | ||
|
|
||
| _MAX_RECONNECT_ATTEMPTS = 10 | ||
|
|
||
|
|
||
| class ConnectorAuthTaskState(BaseModel): | ||
| """State emitted by a ``connector_auth`` custom task when it needs OAuth. | ||
|
|
||
| Attributes: | ||
| connector_name: Identifier of the connector requiring authentication. | ||
| connector_id: Server-side connector ID. | ||
| credentials_name: Optional named credential set used for this connector. | ||
| auth_url: URL the user should visit to complete authentication. | ||
| message: Optional human-readable context about the auth request. | ||
| """ | ||
|
|
||
| connector_name: str | ||
| connector_id: str | ||
| credentials_name: Optional[str] = None | ||
| auth_url: Optional[str] = None | ||
| message: Optional[str] = None | ||
|
|
||
|
|
||
| async def execute_with_connector_auth_async( | ||
| client: Mistral, | ||
| workflow_identifier: str, | ||
| input_data: Any = None, | ||
| *, | ||
| on_auth_required: Optional[ | ||
| Callable[[ConnectorAuthTaskState], Awaitable[None]] | ||
| ] = None, | ||
| execution_id: Optional[str] = None, | ||
| task_queue: Optional[str] = None, | ||
| deployment_name: Optional[str] = None, | ||
| connectors: Sequence[ConnectorSlot] = (), | ||
| polling_interval: float = 2, | ||
| max_polling_attempts: Optional[int] = None, | ||
| ) -> WorkflowExecutionResponse: | ||
| """Execute a workflow, automatically handling connector OAuth flows. | ||
|
|
||
| Args: | ||
| client: An initialised :class:`Mistral` client. | ||
| workflow_identifier: Name or ID of the workflow to execute. | ||
| input_data: Input payload for the workflow. Pydantic models are | ||
| serialised via ``model_dump(mode="json")``. | ||
| on_auth_required: Async callback invoked when a connector needs | ||
| the user to authenticate. Receives a | ||
| :class:`ConnectorAuthTaskState` whose ``auth_url`` field | ||
| contains the OAuth URL. The workflow resumes automatically | ||
| after this callback returns. | ||
| execution_id: Optional custom execution ID. | ||
| task_queue: Optional task queue name (deprecated upstream). | ||
| deployment_name: Optional deployment target. | ||
| connectors: Typed connector slots that declare which connectors | ||
| the workflow needs. | ||
| polling_interval: Seconds between status polls after the event | ||
| stream ends. | ||
| max_polling_attempts: Maximum number of polling iterations before | ||
| raising :class:`TimeoutError`. ``None`` means poll forever. | ||
|
|
||
| Returns: | ||
| The completed :class:`WorkflowExecutionResponse`. | ||
|
|
||
| Raises: | ||
| RuntimeError: If the workflow finishes with a non-COMPLETED status. | ||
| TimeoutError: If *max_polling_attempts* is set and exceeded. | ||
| """ | ||
| extensions = ( | ||
| WorkflowExtensions.from_connectors(connectors).to_dict() if connectors else None | ||
| ) | ||
|
|
||
| execute_kwargs: Dict[str, Any] = dict( | ||
| workflow_identifier=workflow_identifier, | ||
| input=input_data, | ||
| execution_id=execution_id, | ||
| task_queue=task_queue, | ||
| deployment_name=deployment_name, | ||
| ) | ||
| if extensions is not None: | ||
| execute_kwargs["extensions"] = extensions | ||
|
|
||
| execution = await client.workflows.execute_workflow_async(**execute_kwargs) | ||
| exec_id = execution.execution_id | ||
|
|
||
| await _stream_and_handle_auth(client, exec_id, on_auth_required) | ||
|
|
||
| return await _poll_until_done( | ||
| client, exec_id, polling_interval, max_polling_attempts | ||
| ) | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Internal helpers | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| async def _stream_and_handle_auth( | ||
| client: Mistral, | ||
| exec_id: str, | ||
| on_auth_required: Optional[Callable[[ConnectorAuthTaskState], Awaitable[None]]], | ||
| ) -> None: | ||
| """Stream workflow events, handling connector-auth tasks. | ||
|
|
||
| Reconnects automatically with exponential back-off when the SSE | ||
| connection drops. | ||
| """ | ||
| last_seq = 0 | ||
|
|
||
| for attempt in range(_MAX_RECONNECT_ATTEMPTS): | ||
| try: | ||
| event_stream = await client.workflows.events.get_stream_events_async( | ||
| root_workflow_exec_id=exec_id, | ||
| workflow_exec_id="*", | ||
| parent_workflow_exec_id="*", | ||
| start_seq=last_seq, | ||
| ) | ||
| async with event_stream: | ||
| async for sse_event in event_stream: | ||
| if sse_event.data is None: | ||
| continue | ||
|
|
||
| payload = sse_event.data | ||
| last_seq = payload.broker_sequence + 1 | ||
| event = payload.data | ||
|
|
||
| if isinstance(event, _TERMINAL_EVENT_TYPES): | ||
| return | ||
|
|
||
| if not isinstance(event, CustomTaskStartedResponse): | ||
| continue | ||
| if event.attributes.custom_task_type != "connector_auth": | ||
| continue | ||
|
|
||
| payload_value = ( | ||
| event.attributes.payload.value | ||
| if event.attributes.payload is not None | ||
| else None | ||
| ) | ||
| if not isinstance(payload_value, dict): | ||
| continue | ||
|
|
||
| state = ConnectorAuthTaskState.model_validate(payload_value) | ||
|
|
||
| if on_auth_required: | ||
| await on_auth_required(state) | ||
|
|
||
| # The interceptor polls for credentials server-side — | ||
| # no signal or update needed from the client. | ||
| else: | ||
| # Stream exhausted without a terminal event — retry. | ||
| continue | ||
| except (ConnectionError, httpx.RemoteProtocolError): | ||
| logger.debug( | ||
| "Event stream connection lost, reconnecting " | ||
| "(execution_id=%s, attempt=%d)", | ||
| exec_id, | ||
| attempt, | ||
| ) | ||
| await asyncio.sleep(min(2**attempt, 30)) | ||
| else: | ||
| logger.warning( | ||
| "Exhausted %d reconnect attempts for event stream (execution_id=%s)", | ||
| _MAX_RECONNECT_ATTEMPTS, | ||
| exec_id, | ||
| ) | ||
|
|
||
|
|
||
| async def _poll_until_done( | ||
| client: Mistral, | ||
| exec_id: str, | ||
| polling_interval: float, | ||
| max_attempts: Optional[int], | ||
| ) -> WorkflowExecutionResponse: | ||
| """Poll the execution status until it reaches a terminal state.""" | ||
| attempts = 0 | ||
| while True: | ||
| result = await client.workflows.executions.get_workflow_execution_async( | ||
| execution_id=exec_id, | ||
| ) | ||
| if result.status != "RUNNING": | ||
| if result.status == "COMPLETED": | ||
| return result | ||
| raise RuntimeError(f"Workflow failed with status: {result.status}") | ||
|
|
||
| attempts += 1 | ||
| if max_attempts is not None and attempts >= max_attempts: | ||
| raise TimeoutError( | ||
| f"Workflow still running after {max_attempts} polling attempts" | ||
| ) | ||
| await asyncio.sleep(polling_interval) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| """Typed descriptors for connector dependencies and extensions.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import Any, Dict, List, Optional, Sequence | ||
|
|
||
| from pydantic import BaseModel | ||
|
|
||
|
|
||
| class ConnectorSlot(BaseModel): | ||
| """A declared connector dependency for a workflow execution. | ||
|
|
||
| Mirrors the server-side ``ConnectorSlot`` from the workflow SDK plugin, | ||
| providing a typed interface for specifying connector bindings instead of | ||
| raw ``Dict[str, Any]`` extension dicts. | ||
|
|
||
| Example:: | ||
|
|
||
| from mistralai.extra.workflows import ConnectorSlot | ||
|
|
||
| gmail = ConnectorSlot(connector_name="gmail") | ||
| notion = ConnectorSlot(connector_name="notion", credentials_name="work-account") | ||
| """ | ||
|
|
||
| connector_name: str | ||
| credentials_name: Optional[str] = None | ||
|
|
||
|
|
||
| class ConnectorBindings(BaseModel): | ||
| """Container for a list of connector bindings.""" | ||
|
|
||
| bindings: List[ConnectorSlot] | ||
|
|
||
|
|
||
| class ConnectorExtension(BaseModel): | ||
| """Mistral-specific extension carrying connector configuration.""" | ||
|
|
||
| connectors: ConnectorBindings | ||
|
|
||
|
|
||
| class WorkflowExtensions(BaseModel): | ||
| """Top-level extensions dict passed to the workflow execution API. | ||
|
|
||
| Serialises to the shape expected by the API:: | ||
|
|
||
| {"mistralai": {"connectors": {"bindings": [...]}}} | ||
| """ | ||
|
|
||
| mistralai: ConnectorExtension | ||
|
|
||
| @classmethod | ||
| def from_connectors(cls, connectors: Sequence[ConnectorSlot]) -> WorkflowExtensions: | ||
| """Build extensions from a sequence of connector slots.""" | ||
| return cls( | ||
| mistralai=ConnectorExtension( | ||
| connectors=ConnectorBindings(bindings=list(connectors)) | ||
| ) | ||
| ) | ||
|
|
||
| def to_dict(self) -> Dict[str, Any]: | ||
| """Serialise to the ``Dict[str, Any]`` the API expects.""" | ||
| return self.model_dump(mode="json", exclude_none=True) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.