diff --git a/ucapi/api.py b/ucapi/api.py index 0204db4..47586e4 100644 --- a/ucapi/api.py +++ b/ucapi/api.py @@ -78,6 +78,18 @@ class _VoiceSessionContext: handler_task: asyncio.Task | None = None +@dataclass(slots=True) +class _WsContext: + """Websocket context.""" + + incoming: asyncio.Queue[str | bytes | None] + outgoing: asyncio.Queue[str | None] + pending: dict[int, asyncio.Future] + consumer_task: asyncio.Task | None = None + producer_task: asyncio.Task | None = None + router_task: asyncio.Task | None = None + + # pylint: disable=too-many-public-methods, too-many-lines class IntegrationAPI: """Integration API to communicate with Remote Two/3.""" @@ -107,12 +119,18 @@ def __init__(self, loop: AbstractEventLoop | None = None): self._available_entities = Entities("available", self._loop) self._configured_entities = Entities("configured", self._loop) + self._req_id = 1 # Request ID counter for outgoing requests + self._voice_handler: VoiceStreamHandler | None = None self._voice_session_timeout: int = self.DEFAULT_VOICE_SESSION_TIMEOUT_S # Active voice sessions self._voice_sessions: dict[VoiceSessionKey, _VoiceSessionContext] = {} # Enforce: at most one active session per entity_id (across all websockets) self._voice_session_by_entity: dict[str, VoiceSessionKey] = {} + # Websocket context with incoming & outgoing queues and handlers + self._ws_contexts: dict[Any, _WsContext] = {} + # Supported entity types + self._supported_entity_types: list[str] | None = None # Setup event loop asyncio.set_event_loop(self._loop) @@ -214,40 +232,74 @@ async def _start_web_socket_server(self, host: str, port: int) -> None: await asyncio.Future() async def _handle_ws(self, websocket) -> None: + # Initialize incoming and outgoing queues + incoming: asyncio.Queue[str | bytes | None] = asyncio.Queue(maxsize=100) + outgoing: asyncio.Queue[str | None] = asyncio.Queue(maxsize=100) + + ctx = _WsContext( + incoming=incoming, + outgoing=outgoing, + pending={}, + ) + + self._clients.add(websocket) + self._ws_contexts[websocket] = ctx + try: - self._clients.add(websocket) _LOG.info("WS: Client added: %s", websocket.remote_address) + ctx.consumer_task = self._loop.create_task( + self._ws_consumer(websocket, ctx) + ) + ctx.producer_task = self._loop.create_task( + self._ws_producer(websocket, ctx) + ) + ctx.router_task = self._loop.create_task(self._ws_router(websocket, ctx)) + # authenticate on connection await self._authenticate(websocket, True) - self._events.emit(uc.Events.CLIENT_CONNECTED, websocket=websocket) + tasks = [ + t + for t in [ctx.consumer_task, ctx.producer_task, ctx.router_task] + if t is not None + ] + done, pending = await asyncio.wait( + tasks, + return_when=asyncio.FIRST_COMPLETED, + ) - async for message in websocket: - # Distinguish between text (str) and binary (bytes-like) messages - if isinstance(message, str): - # JSON text message - await self._process_ws_message(websocket, message) - elif isinstance(message, (bytes, bytearray, memoryview)): - # Binary message (protobuf in future) - await self._process_ws_binary_message(websocket, bytes(message)) - else: - _LOG.warning( - "[%s] WS: Unsupported message type %s", + if pending: + # graceful shutdown: wait a bit for pending tasks to process sentinel 'None' + _LOG.debug("[%s] WS: Draining tasks", websocket.remote_address) + await asyncio.wait(pending, timeout=1.0) + + for task in pending: + task.cancel() + + results = await asyncio.gather(*done, *pending, return_exceptions=True) + for result in results: + if isinstance(result, Exception) and not isinstance( + result, asyncio.CancelledError + ): + _LOG.error( + "[%s] WS: Exception in task", websocket.remote_address, - type(message).__name__, + exc_info=result, ) except ConnectionClosedOK: _LOG.info("[%s] WS: Connection closed", websocket.remote_address) except websockets.exceptions.ConnectionClosedError as e: - # no idea why they made code & reason deprecated... + close = e.rcvd or e.sent + code = getattr(close, "code", None) + reason = getattr(close, "reason", None) _LOG.info( - "[%s] WS: Connection closed with error %d: %s", + "[%s] WS: Connection closed with error %s: %s", websocket.remote_address, - e.code, - e.reason, + code, + reason, ) except websockets.exceptions.WebSocketException as e: @@ -258,22 +310,101 @@ async def _handle_ws(self, websocket) -> None: ) finally: - # Cleanup any active voice sessions associated with this websocket - keys_to_cleanup = [k for k in self._voice_sessions if k[0] is websocket] - for key in keys_to_cleanup: - try: - await self._cleanup_voice_session(key, VoiceEndReason.REMOTE) - except Exception as ex: # pylint: disable=W0718 - _LOG.exception( - "[%s] WS: Error during voice session cleanup for session_id=%s: %s", + await self._cleanup_ws(websocket) + + async def _ws_consumer(self, websocket, ctx: _WsContext) -> None: + """Route incoming message (requests or events from remote or responses to driver).""" + try: + async for raw_message in websocket: + if isinstance(raw_message, str): + try: + data = json.loads(raw_message) + except json.JSONDecodeError: + _LOG.warning( + "[%s] WS: Invalid JSON message: %s", + websocket.remote_address, + raw_message, + ) + continue + + kind: str | None = None + if isinstance(data, dict): + kind = data.get("kind") + + # Handle the response to a previous driver request + if kind == "resp": + self._handle_pending_response(websocket, data) + # Otherwise handle the json request + else: + await ctx.incoming.put(data) + # Handle the binary message + elif isinstance(raw_message, (bytes, bytearray, memoryview)): + await ctx.incoming.put(bytes(raw_message)) + else: + _LOG.warning( + "[%s] WS: Unsupported message type %s", websocket.remote_address, - key[1], - ex, + type(raw_message).__name__, ) + finally: + await ctx.incoming.put(None) + await ctx.outgoing.put(None) + + async def _ws_producer(self, websocket, ctx: _WsContext) -> None: + """Route outgoing messages.""" + try: + while True: + msg = await ctx.outgoing.get() + if msg is None: + break + await websocket.send(msg) + except (ConnectionClosedOK, websockets.exceptions.ConnectionClosedError): + pass + + async def _ws_router(self, websocket, ctx: _WsContext) -> None: + """Route incoming requests.""" + while True: + message = await ctx.incoming.get() + if message is None: + break + if isinstance(message, dict): + await self._process_ws_message(websocket, message) + elif isinstance(message, bytes): + await self._process_ws_binary_message(websocket, message) + else: + _LOG.warning( + "[%s] WS: Unsupported routed message type %s", + websocket.remote_address, + type(message).__name__, + ) + + def _get_ws_context(self, websocket) -> _WsContext | None: + return self._ws_contexts.get(websocket) + + async def _enqueue_ws_payload(self, websocket, payload: dict[str, Any]) -> None: + ctx = self._get_ws_context(websocket) + if ctx is None or websocket not in self._clients: + _LOG.error("Error sending payload: connection no longer established") + return + + if _LOG.isEnabledFor(logging.DEBUG): + _LOG.debug( + "[%s] ->: %s", websocket.remote_address, filter_log_msg_data(payload) + ) - self._clients.remove(websocket) - _LOG.info("[%s] WS: Client removed", websocket.remote_address) - self._events.emit(uc.Events.CLIENT_DISCONNECTED, websocket=websocket) + match payload.get("kind"): + case "event": + try: + ctx.outgoing.put_nowait(json.dumps(payload)) + except asyncio.QueueFull: + _LOG.warning( + "[%s] Outgoing queue full, dropping event", + websocket.remote_address, + ) + case "req": + ctx.outgoing.put_nowait(json.dumps(payload)) + case _: + await ctx.outgoing.put(json.dumps(payload)) async def _send_ok_result( self, websocket, req_id: int, msg_data: dict[str, Any] | list | None = None @@ -312,7 +443,7 @@ async def _send_error_result( """ await self._send_ws_response(websocket, req_id, "result", msg_data, status_code) - # pylint: disable=R0917 + # pylint: disable=too-many-positional-arguments async def _send_ws_response( self, websocket, @@ -340,16 +471,7 @@ async def _send_ws_response( "msg": msg, "msg_data": msg_data if msg_data is not None else {}, } - - if websocket in self._clients: - data_dump = json.dumps(data) - if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug( - "[%s] ->: %s", websocket.remote_address, filter_log_msg_data(data) - ) - await websocket.send(data_dump) - else: - _LOG.error("Error sending response: connection no longer established") + await self._enqueue_ws_payload(websocket, data) async def _broadcast_ws_event( self, msg: str, msg_data: dict[str, Any], category: uc.EventCategory @@ -365,17 +487,13 @@ async def _broadcast_ws_event( :param category: event category """ data = {"kind": "event", "msg": msg, "msg_data": msg_data, "cat": category} - data_dump = json.dumps(data) - for websocket in self._clients.copy(): - if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug( - "[%s] =>: %s", websocket.remote_address, filter_log_msg_data(data) - ) try: - await websocket.send(data_dump) - except websockets.exceptions.WebSocketException: - pass + await self._enqueue_ws_payload(websocket, data) + except Exception: # pylint: disable=broad-exception-caught + _LOG.exception( + "Failed to enqueue broadcast for %s", websocket.remote_address + ) async def _send_ws_event( self, websocket, msg: str, msg_data: dict[str, Any], category: uc.EventCategory @@ -392,35 +510,116 @@ async def _send_ws_event( websockets.ConnectionClosed: When the connection is closed. """ data = {"kind": "event", "msg": msg, "msg_data": msg_data, "cat": category} - data_dump = json.dumps(data) + await self._enqueue_ws_payload(websocket, data) - if websocket in self._clients: - if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug( - "[%s] ->: %s", websocket.remote_address, filter_log_msg_data(data) - ) - await websocket.send(data_dump) - else: - _LOG.error("Error sending event: connection no longer established") + async def _process_ws_message(self, websocket, data: dict[str, Any]) -> None: + _LOG.debug("[%s] <-: %s", websocket.remote_address, data) - async def _process_ws_message(self, websocket, message) -> None: - _LOG.debug("[%s] <-: %s", websocket.remote_address, message) - - data = json.loads(message) kind = data["kind"] - req_id = data["id"] if "id" in data else None + req_id = data.get("id") msg = data["msg"] - msg_data = data["msg_data"] if "msg_data" in data else None + msg_data = data.get("msg_data") if kind == "req": if req_id is None: _LOG.warning( - "Ignoring request message with missing 'req_id': %s", message + "Ignoring request message with missing 'id': %s", + data, ) - else: - await self._handle_ws_request_msg(websocket, msg, req_id, msg_data) + return + await self._handle_ws_request_msg(websocket, msg, req_id, msg_data) elif kind == "event": await self._handle_ws_event_msg(websocket, msg, msg_data) + else: + _LOG.warning( + "[%s] WS: Unsupported routed message kind %s", + websocket.remote_address, + kind, + ) + + def _handle_pending_response(self, websocket, data: dict[str, Any]) -> None: + """Resolve the response message that corresponds to a pending request from the driver.""" + resp_id = data.get("req_id", data.get("id")) + if resp_id is None: + _LOG.warning( + "[%s] WS: Received resp without req_id/id: %s", + websocket.remote_address, + data, + ) + return + + ctx = self._get_ws_context(websocket) + if ctx is None: + _LOG.debug("[%s] WS: No context for resp", websocket.remote_address) + return + + fut = ctx.pending.get(int(resp_id)) + if fut is None: + _LOG.debug( + "[%s] WS: Unmatched resp_id=%s (not pending). msg=%s", + websocket.remote_address, + resp_id, + data.get("msg"), + ) + return + + if not fut.done(): + fut.set_result(data) + + async def _ws_request( + self, + websocket, + msg: str, + msg_data: dict[str, Any] | None = None, + *, + timeout: float = 10.0, + ) -> dict[str, Any]: + """ + Send a request over websocket and await the matching response. + + - Uses a Future stored in self._ws_pending[websocket][req_id] + - Reader task (_handle_ws -> _process_ws_message) completes the future on 'resp' + - Raises TimeoutError on timeout + :param websocket: client connection + :param msg: event message name + :param msg_data: message data payload + :param timeout: timeout for message + """ + # Ensure per-socket structures exist (in case you call before _handle_ws init) + ctx = self._get_ws_context(websocket) + if ctx is None: + raise ConnectionError("WebSocket context not found") + + # Allocate req_id safely + req_id = self._req_id + self._req_id += 1 + + fut = self._loop.create_future() + ctx.pending[req_id] = fut + + try: + payload: dict[str, Any] = {"kind": "req", "id": req_id, "msg": msg} + if msg_data is not None: + payload["msg_data"] = msg_data + + await self._enqueue_ws_payload(websocket, payload) + + # Await response from client until given timeout + resp = await asyncio.wait_for(fut, timeout=timeout) + return resp + + except asyncio.TimeoutError as ex: + _LOG.error( + "[%s] Timeout waiting for response to %s (req_id=%s) %s", + websocket.remote_address, + msg, + req_id, + ex, + ) + raise ex + finally: + # Cleanup pending future entry + ctx.pending.pop(req_id, None) async def _process_ws_binary_message(self, websocket, data: bytes) -> None: """Process a binary WebSocket message using protobuf IntegrationMessage. @@ -462,6 +661,30 @@ async def _process_ws_binary_message(self, websocket, data: bytes) -> None: kind, ) + async def _cleanup_ws(self, websocket) -> None: + ctx = self._ws_contexts.pop(websocket, None) + + keys_to_cleanup = [k for k in self._voice_sessions if k[0] is websocket] + for key in keys_to_cleanup: + try: + await self._cleanup_voice_session(key, VoiceEndReason.REMOTE) + except Exception as ex: # pylint: disable=broad-exception-caught + _LOG.exception( + "[%s] WS: Error during voice session cleanup for session_id=%s: %s", + websocket.remote_address, + key[1], + ex, + ) + + if ctx is not None: + for fut in ctx.pending.values(): + if not fut.done(): + fut.set_exception(ConnectionError("WebSocket disconnected")) + + self._clients.discard(websocket) + _LOG.info("[%s] WS: Client removed", websocket.remote_address) + self._events.emit(uc.Events.CLIENT_DISCONNECTED, websocket=websocket) + async def _on_remote_voice_begin(self, websocket, msg: RemoteVoiceBegin) -> None: """Handle a RemoteVoiceBegin protobuf message. @@ -702,13 +925,7 @@ async def _handle_ws_request_msg( {"state": self.device_state}, ) elif msg == uc.WsMessages.GET_AVAILABLE_ENTITIES: - available_entities = self._available_entities.get_all() - await self._send_ws_response( - websocket, - req_id, - uc.WsMsgEvents.AVAILABLE_ENTITIES, - {"available_entities": available_entities}, - ) + await self._get_available_entities(websocket, req_id) elif msg == uc.WsMessages.GET_ENTITY_STATES: entity_states = await self._configured_entities.get_states() await self._send_ws_response( @@ -1351,10 +1568,108 @@ def remove_all_listeners(self, event: uc.Events | None) -> None: """ self._events.remove_all_listeners(event) + async def get_supported_entity_types( + self, websocket, *, timeout: float = 5.0 + ) -> list[str]: + """Request supported entity types from client and return msg_data.""" + resp = await self._ws_request( + websocket, + "get_supported_entity_types", + timeout=timeout, + ) + if resp.get("msg") != "supported_entity_types": + _LOG.debug( + "[%s] Unexpected resp msg for get_supported_entity_types: %s", + websocket.remote_address, + resp.get("msg"), + ) + return resp.get("msg_data", []) + + async def get_version( + self, websocket, *, timeout: float = 5.0 + ) -> dict[str, Any] | None: + """Request client version and return msg_data.""" + resp = await self._ws_request( + websocket, + "get_version", + timeout=timeout, + ) + if resp.get("msg") != "version": + _LOG.debug( + "[%s] Unexpected resp msg for get_version: %s", + websocket.remote_address, + resp.get("msg"), + ) + + return resp.get("msg_data") + + async def get_localization_cfg( + self, websocket, *, timeout: float = 5.0 + ) -> dict[str, Any] | None: + """Request localization config and return msg_data.""" + resp = await self._ws_request( + websocket, + "get_localization_cfg", + timeout=timeout, + ) + + if resp.get("msg") != "localization_cfg": + _LOG.debug( + "[%s] Unexpected resp msg for get_localization_cfg: %s", + websocket.remote_address, + resp.get("msg"), + ) + + return resp.get("msg_data") + + async def _update_supported_entity_types( + self, websocket, *, timeout: float = 5.0 + ) -> None: + """Update supported entity types by remote.""" + await asyncio.sleep(0) + try: + self._supported_entity_types = await self.get_supported_entity_types( + websocket, timeout=timeout + ) + _LOG.debug( + "[%s] Supported entity types %s", + websocket.remote_address, + self._supported_entity_types, + ) + except Exception as ex: # pylint: disable=W0718 + _LOG.error( + "[%s] Unable to retrieve entity types %s", + websocket.remote_address, + ex, + ) + + async def _get_available_entities(self, websocket, req_id) -> None: + if self._supported_entity_types is None: + # Request supported entity types from remote + await self._update_supported_entity_types(websocket) + available_entities = self._available_entities.get_all() + if self._supported_entity_types: + available_entities = [ + entity + for entity in available_entities + if entity.get("entity_type") in self._supported_entity_types + ] + await self._send_ws_response( + websocket, + req_id, + uc.WsMsgEvents.AVAILABLE_ENTITIES, + {"available_entities": available_entities}, + ) + ############## # Properties # ############## + @property + def clients(self) -> set: + """Return all clients.""" + return self._clients.copy() + @property def client_count(self) -> int: """Return number of WebSocket clients."""