diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3d44fcc0..34237922 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,6 +10,7 @@ repos: - id: check-merge-conflict - id: check-ast fail_fast: True + language_version: python3.12 - id: check-json - id: check-added-large-files args: ['--maxkb=200'] @@ -35,4 +36,5 @@ repos: hooks: - id: mypy files: 'src/.*\.py$' - additional_dependencies: ['types-setuptools==57.0.2'] + additional_dependencies: ['types-setuptools>=57.0.2'] + language_version: python3.12 diff --git a/README.rst b/README.rst index 400a20ae..ed71f7db 100644 --- a/README.rst +++ b/README.rst @@ -22,18 +22,10 @@ Workflows :target: https://pypi.org/project/workflows/ :alt: Supported Python versions -.. image:: https://img.shields.io/badge/code%20style-black-000000.svg - :target: https://github.com/psf/black +.. image:: https://img.shields.io/badge/code%20style-ruff-000000.svg + :target: https://github.com/astral-sh/ruff :alt: Code style: black -.. image:: https://img.shields.io/lgtm/grade/python/g/DiamondLightSource/python-workflows.svg?logo=lgtm&logoWidth=18 - :target: https://lgtm.com/projects/g/DiamondLightSource/python-workflows/context:python - :alt: Language grade: Python - -.. image:: https://img.shields.io/lgtm/alerts/g/DiamondLightSource/python-workflows.svg?logo=lgtm&logoWidth=18 - :target: https://lgtm.com/projects/g/DiamondLightSource/python-workflows/alerts/ - :alt: Total alerts - Workflows enables light-weight services to process tasks in a message-oriented environment. diff --git a/pyproject.toml b/pyproject.toml index e55a62f9..1c6d3c6c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,14 +99,20 @@ filename = "pyproject.toml" filename = "src/workflows/__init__.py" [tool.ruff.lint] -select = ["E", "F", "W", "C4", "I"] +select = ["E", "F", "W", "C4", "I", "ANN"] unfixable = ["F841"] # E501 line too long (handled by formatter) -ignore = ["E501"] +ignore = ["E501", "ANN204", "ANN401"] [tool.ruff.lint.isort] known-first-party = ["dxtbx_*", "dxtbx"] required-imports = ["from __future__ import annotations"] +[tool.ruff.lint.per-file-ignores] +"tests/**" = ["ANN"] + [tool.mypy] mypy_path = "src/" + +[tool.pyright] +exclude = ["tests", ".venv"] diff --git a/src/workflows/contrib/start_service.py b/src/workflows/contrib/start_service.py index 6545f703..cb6da236 100644 --- a/src/workflows/contrib/start_service.py +++ b/src/workflows/contrib/start_service.py @@ -1,8 +1,10 @@ from __future__ import annotations +import optparse import sys from collections.abc import Callable from optparse import SUPPRESS_HELP, OptionParser +from typing import Any import workflows import workflows.frontend @@ -17,20 +19,22 @@ class ServiceStarter: used in a number of scenarios.""" @staticmethod - def on_parser_preparation(parser): + def on_parser_preparation(parser: OptionParser) -> OptionParser | None: """Plugin hook to manipulate the OptionParser object before command line parsing. If a value is returned here it will replace the OptionParser object.""" @staticmethod - def on_parsing(options, args): + def on_parsing( + options: optparse.Values, args: list[str] + ) -> tuple[optparse.Values, list[str]] | None: """Plugin hook to manipulate the command line parsing results. A tuple of values can be returned, which will replace (options, args). """ @staticmethod def on_transport_factory_preparation( - transport_factory, + transport_factory: Callable[[], CommonTransport], ) -> Callable[[], CommonTransport] | None: """Plugin hook to intercept/manipulate newly created Transport factories before first invocation.""" @@ -41,28 +45,32 @@ def on_transport_preparation(transport: CommonTransport) -> CommonTransport | No before connecting.""" @staticmethod - def before_frontend_construction(kwargs): + def before_frontend_construction(kwargs: dict[str, Any]) -> dict[str, Any] | None: """Plugin hook to manipulate the Frontend object constructor arguments. If a value is returned here it will replace the keyword arguments dictionary passed to the constructor.""" @staticmethod - def on_frontend_preparation(frontend): + def on_frontend_preparation( + frontend: workflows.frontend.Frontend, + ) -> workflows.frontend.Frontend | None: """Plugin hook to manipulate the Frontend object before starting it. If a value is returned here it will replace the Frontend object.""" def run( self, - cmdline_args=None, - program_name="start_service", - version=None, + cmdline_args: list[str] | None = None, + program_name: str = "start_service", + version: str | None = None, add_metrics_option: bool = False, - **kwargs, - ): + **kwargs: Any, + ) -> None: """Example command line interface to start services. - :param cmdline_args: List of command line arguments to pass to parser - :param program_name: Name of the command line tool to display in help - :param version: Version number to print when run with '--version' + + Args: + cmdline_args: List of command line arguments to pass to parser + program_name: Name of the command line tool to display in help + version: Version number to print when run with '--version' """ # Enumerate all known services @@ -163,10 +171,12 @@ def on_transport_preparation_hook() -> CommonTransport: if options.service not in known_services: # First check whether the provided service name is a case-insensitive match. service_lower = options.service.lower() - match = {s.lower(): s for s in known_services}.get(service_lower, None) - match = ( - [match] - if match + exact_match = {s.lower(): s for s in known_services}.get( + service_lower, None + ) + match: list[str] = ( + [exact_match] + if exact_match # Next, check whether the provided service name is a partial # case-sensitive match. else [s for s in known_services if s.startswith(options.service)] diff --git a/src/workflows/contrib/status_monitor.py b/src/workflows/contrib/status_monitor.py index cc109e3d..83d73aa3 100644 --- a/src/workflows/contrib/status_monitor.py +++ b/src/workflows/contrib/status_monitor.py @@ -3,12 +3,12 @@ import curses import threading import time -from typing import Any +from collections.abc import Mapping +from typing import Any, Callable import workflows.transport from workflows.services.common_service import CommonService - -basestring = (str, bytes) +from workflows.transport.common_transport import CommonTransport class Monitor: # pragma: no cover @@ -19,7 +19,7 @@ class Monitor: # pragma: no cover shutdown = False """Set to true to end the main loop and shut down the service monitor.""" - cards: dict[Any, Any] = {} + cards: list """Register card shown for seen services""" border_chars = () @@ -27,21 +27,21 @@ class Monitor: # pragma: no cover border_chars_text = ("|", "|", "=", "=", "/", "\\", "\\", "/") """Example alternative set of frame border characters.""" - def __init__(self, transport=None): + def __init__(self, transport: Callable[[], CommonTransport] | str | None = None): """Set up monitor and connect to the network transport layer""" - if transport is None or isinstance(transport, basestring): - self._transport = workflows.transport.lookup(transport)() - else: + if callable(transport): self._transport = transport() + else: + self._transport = workflows.transport.lookup(transport)() assert self._transport.connect(), "Could not connect to transport layer" self._lock = threading.RLock() - self._node_status = {} - self.message_box = None + self._node_status: dict = {} + self.message_box: curses.window | None = None self._transport.subscribe_broadcast( "transient.status", self.update_status, retroactive=True ) - def update_status(self, header, message): + def update_status(self, header: Mapping[str, Any], message: Any) -> None: """Process incoming status message. Acquire lock for status dictionary before updating.""" with self._lock: if self.message_box: @@ -70,14 +70,21 @@ def update_status(self, header, message): self._node_status[message["host"]] = message self._node_status[message["host"]]["last_seen"] = receipt_time - def run(self): + def run(self) -> None: """A wrapper for the real _run() function to cleanly enable/disable the curses environment.""" curses.wrapper(self._run) def _boxwin( - self, height, width, row, column, title=None, title_x=7, color_pair=None - ): + self, + height: int, + width: int, + row: int, + column: int, + title: str | None = None, + title_x: int = 7, + color_pair: int | None = None, + ) -> curses.window: with self._lock: box = curses.newwin(height, width, row, column) box.clear() @@ -91,7 +98,7 @@ def _boxwin( box.noutrefresh() return curses.newwin(height - 2, width - 2, row + 1, column + 1) - def _redraw_screen(self, stdscr): + def _redraw_screen(self, stdscr: curses.window) -> None: """Redraw screen. This could be to initialize, or to redraw after resizing.""" with self._lock: stdscr.clear() @@ -105,7 +112,7 @@ def _redraw_screen(self, stdscr): self.message_box.scrollok(True) self.cards = [] - def _get_card(self, number): + def _get_card(self, number: int) -> curses.window: with self._lock: if number < len(self.cards): return self.cards[number] @@ -123,7 +130,7 @@ def _get_card(self, number): return self.cards[number] raise RuntimeError("Card number too high") - def _erase_card(self, number): + def _erase_card(self, number: int) -> None: """Destroy cards with this or higher number.""" with self._lock: if number < (len(self.cards) - 1): @@ -141,7 +148,7 @@ def _erase_card(self, number): obliterate.noutrefresh() del self.cards[number] - def _run(self, stdscr): + def _run(self, stdscr: curses.window) -> None: """Start the actual service monitor""" with self._lock: curses.use_default_colors() diff --git a/src/workflows/frontend/__init__.py b/src/workflows/frontend/__init__.py index 01962dda..17d67eab 100644 --- a/src/workflows/frontend/__init__.py +++ b/src/workflows/frontend/__init__.py @@ -4,7 +4,8 @@ import multiprocessing import threading import time -from collections.abc import Callable +from collections.abc import Callable, Mapping +from typing import Any import workflows import workflows.frontend.utilization @@ -14,8 +15,6 @@ from workflows.services.common_service import CommonService from workflows.transport.common_transport import CommonTransport -basestring = (str, bytes) - # Pin the fork start method: service instances carry pipes and transport # state that aren't pickleable, so spawn/forkserver (the 3.14+ default on # Linux) can't serialize them. fork is deprecated upstream and will @@ -35,42 +34,44 @@ class Frontend: def __init__( self, transport: Callable[[], CommonTransport] | str | None = None, - service=None, - transport_command_channel=None, - restart_service=False, - verbose_service=False, - environment=None, + service: type[CommonService] | str | None = None, + transport_command_channel: str | None = None, + restart_service: bool = False, + verbose_service: bool = False, + environment: dict[str, Any] | None = None, ): """Create a frontend instance. Connect to the transport layer, start any requested service, begin broadcasting status information and listen for control commands. - :param restart_service: - If the service process dies unexpectedly the frontend should start - a new instance. - :param service: - A class or name of the class to be instantiated in a subprocess as - service. - :param transport: - Either the name of a transport class, a transport class, or a - transport class object. - :param transport_command_channel: - An optional channel of a transport subscription to be listened to for - commands. - :param verbose_service: - If set, run services with increased logging level (DEBUG). - :param environment: - An optional dictionary that is passed to started services. + + Args: + restart_service: + If the service process dies unexpectedly the frontend should start + a new instance. + service: + A class or name of the class to be instantiated in a subprocess as + service. + transport: + Either the name of a transport class, a transport class, or a + transport class object. + transport_command_channel: + An optional channel of a transport subscription to be listened to + for commands. + verbose_service: + If set, run services with increased logging level (DEBUG). + environment: + An optional dictionary that is passed to started services. """ self.__lock = threading.RLock() self.__hostid = workflows.util.generate_unique_host_id() - self._service = None # pointer to the service instance - self._service_class_name = None - self._service_factory = None # pointer to the service class - self._service_name = None - self._service_starttime = None - self._service_rapidstarts = None - self._pipe_commands = None # frontend -> service - self._pipe_service = None # frontend <- service + self._service: multiprocessing.Process | None = None + self._service_class_name: str | None = None + self._service_factory: type[CommonService] | str | None = None + self._service_name: str | None = None + self._service_starttime: float | None = None + self._service_rapidstarts: int | None = None + self._pipe_commands: Any = None # frontend -> service + self._pipe_service: Any = None # frontend <- service self._service_status = CommonService.SERVICE_STATUS_NONE self._service_status_announced = CommonService.SERVICE_STATUS_NONE @@ -79,8 +80,8 @@ def __init__( # Status broadcast related variables self._status_interval = 6 - self._status_last_broadcast = 0 - self._status_idle_since = None + self._status_last_broadcast: float = 0 + self._status_idle_since: float | None = None self._utilization = workflows.frontend.utilization.UtilizationStatistics( summation_period=self._status_interval ) @@ -101,7 +102,7 @@ def __iter__(self): self.status = {"workflows_" + k: v for k, v in self.status_fn().items()} return self.status.__iter__() - def __getitem__(self, key): + def __getitem__(self, key: str) -> Any: """Return a value from the status dictionary.""" return self.status.__getitem__(key) @@ -111,7 +112,7 @@ def __getitem__(self, key): ) # Connect to the network transport layer - if transport is None or isinstance(transport, basestring): + if transport is None or isinstance(transport, str): self._transport_factory = workflows.transport.lookup(transport) else: self._transport_factory = transport @@ -141,16 +142,19 @@ def __getitem__(self, key): if environment and "liveness" in environment: self._start_liveness_endpoint(environment["liveness"]["port"]) - def update_status(self, status_code=None): + def update_status(self, status_code: int | None = None) -> None: """Update the service status kept inside the frontend (_service_status). The status is broadcast over the network immediately. If the status changes to IDLE then this message is delayed. The IDLE status is only broadcast if it is held for over 0.5 seconds. When the status does not change it is still broadcast every _status_interval seconds. - :param status_code: Either an integer describing the service status - (see workflows.services.common_service), or None - if the status is unchanged. + + Args: + status_code: + Either an integer describing the service status (see + workflows.services.common_service), or None if the status is + unchanged. """ if status_code is not None: self._service_status = status_code @@ -179,7 +183,7 @@ def update_status(self, status_code=None): self._transport.broadcast_status(self.get_status()) self._status_last_broadcast = time.time() - def run(self): + def run(self) -> None: """The main loop of the frontend. This is where the frontend process will spend most of its time.""" self.log.debug("Entered main loop") @@ -204,7 +208,7 @@ def run(self): self._transport.disconnect() self.log.debug("Terminating.") - def _iterate_main_loop(self): + def _iterate_main_loop(self) -> None: """Collection of steps that are run over and over again in the main loop of the frontend. Here incoming messages from the service are processed and forwarded to their corresponding callback methods.""" @@ -226,7 +230,7 @@ def _iterate_main_loop(self): self.log.warning("Invalid message received %s", str(message)) except EOFError: # Service has gone away - error_message = False + error_message: str | bool = False if self._service_status == CommonService.SERVICE_STATUS_END: self.log.info("Service terminated") elif self._service_status == CommonService.SERVICE_STATUS_ERROR: @@ -266,7 +270,7 @@ def _iterate_main_loop(self): self.update_status(status_code=CommonService.SERVICE_STATUS_NEW) self.switch_service() - def send_command(self, command): + def send_command(self, command: Any) -> None: """Send command to service via the command queue.""" if self._pipe_commands: self._pipe_commands.send(command) @@ -282,7 +286,9 @@ def send_command(self, command): "No command queue pipe found for command\n%s", str(command) ) - def process_transport_command(self, header, message): + def process_transport_command( + self, header: Mapping[str, Any], message: Any + ) -> None: """Parse a command coming in through the transport command subscription""" if not isinstance(message, dict): return @@ -308,7 +314,7 @@ def process_transport_command(self, header, message): else: self.log.warning("Received invalid transport command message") - def parse_band_log(self, message): + def parse_band_log(self, message: dict[str, Any]) -> None: """Process incoming logging messages from the service.""" try: record = message["payload"] @@ -331,14 +337,14 @@ def parse_band_log(self, message): setattr(record, "workflows_" + k, v) logging.getLogger(record_name).handle(record) - def parse_band_request_termination(self, message): + def parse_band_request_termination(self, message: dict[str, Any]) -> None: """Service declares it should be terminated.""" self.log.debug("Service requests termination") self._terminate_service() if not self.restart_service: self.shutdown = True - def parse_band_set_name(self, message): + def parse_band_set_name(self, message: dict[str, Any]) -> None: """Process incoming message indicating service name change.""" if message.get("name"): self._service_name = message["name"] @@ -347,21 +353,21 @@ def parse_band_set_name(self, message): "Received broken record on set_name band\nMessage: %s", str(message) ) - def parse_band_status_update(self, message): + def parse_band_status_update(self, message: dict[str, Any]) -> None: """Process incoming status updates from the service.""" self.log.debug("Status update: " + str(message)) self.update_status(status_code=message["statuscode"]) - def parse_band_liveness_check(self, message): + def parse_band_liveness_check(self, message: dict[str, Any]) -> None: """Respond by sending message to backend to let it know we are alive.""" self.log.debug("Service liveness check: alive!") self.__alive = True - def get_host_id(self): + def get_host_id(self) -> str: """Get a cached copy of the host id.""" return self.__hostid - def get_status(self): + def get_status(self) -> dict[str, Any]: """Returns a dictionary containing all relevant status information to be broadcast across the network.""" return { @@ -376,7 +382,7 @@ def get_status(self): "workflows": workflows.version(), } - def exponential_backoff(self): + def exponential_backoff(self) -> None: """A function that keeps waiting longer and longer the more rapidly it is called. It can be used to increasingly slow down service starts when they keep failing. """ @@ -397,11 +403,17 @@ def exponential_backoff(self): self.log.debug("Slowing down service starts (%.1f seconds)", minimum_wait) time.sleep(minimum_wait) - def switch_service(self, new_service=None): + def switch_service( + self, new_service: type[CommonService] | str | None = None + ) -> bool: """Start a new service in a subprocess. - :param new_service: Either a service name or a service class. If not set, - start up a new instance of the previous class - :return: True on success, False on failure. + + Args: + new_service: Either a service name or a service class. If not set, + start up a new instance of the previous class. + + Returns: + True on success, False on failure. """ if new_service: self._service_factory = new_service @@ -411,7 +423,7 @@ def switch_service(self, new_service=None): self._terminate_service() # Find service class if necessary - if isinstance(self._service_factory, basestring): + if isinstance(self._service_factory, str): self._service_factory = workflows.services.lookup(self._service_factory) if not self._service_factory: return False @@ -450,7 +462,7 @@ def switch_service(self, new_service=None): self.log.info("Started service: %s", self._service_name) return True - def _terminate_service(self): + def _terminate_service(self) -> None: """Force termination of running service. Disconnect queues, end queue feeder threads. Wait for service process to clear, drop all references.""" @@ -471,10 +483,13 @@ def _terminate_service(self): self._service.join() # must wait for process to be actually destroyed self._service = None - def _start_liveness_endpoint(self, port: int): + def _start_liveness_endpoint(self, port: int) -> None: from wsgiref.simple_server import make_server + from wsgiref.types import StartResponse, WSGIEnvironment - def alive(environ, start_response): + def alive( + environ: WSGIEnvironment, start_response: StartResponse + ) -> list[bytes]: self.__alive = False self.send_command({"band": "command", "payload": "liveness_check"}) diff --git a/src/workflows/frontend/utilization.py b/src/workflows/frontend/utilization.py index d88378d3..604d04b9 100644 --- a/src/workflows/frontend/utilization.py +++ b/src/workflows/frontend/utilization.py @@ -1,6 +1,7 @@ from __future__ import annotations import time +from typing import Any from workflows.services.common_service import CommonService @@ -9,15 +10,15 @@ class UtilizationStatistics: """Generate statistics about the percentage of time spent in different statuses over a fixed time slice. This class is not thread-safe.""" - def __init__(self, summation_period=10): + def __init__(self, summation_period: float = 10): """Reports will always cover the most recent period of summation_period seconds.""" self.period = summation_period - self.status_history = [ + self.status_history: list[dict[str, Any]] = [ {"start": 0, "end": None, "status": CommonService.SERVICE_STATUS_NEW} ] - def update_status(self, new_status): + def update_status(self, new_status: int) -> None: """Record a status change with a current timestamp.""" timestamp = time.time() self.status_history[-1]["end"] = timestamp @@ -25,14 +26,14 @@ def update_status(self, new_status): {"start": timestamp, "end": None, "status": new_status} ) - def report(self): + def report(self) -> dict[int, float]: """Return a dictionary of different status codes and the percentage of time spent in each throughout the last summation_period seconds. Truncate the aggregated history appropriately.""" timestamp = time.time() cutoff = timestamp - self.period truncate = 0 - summary = {} + summary: dict[int, float] = {} for event in self.status_history[:-1]: if event["end"] < cutoff: truncate = truncate + 1 diff --git a/src/workflows/logging.py b/src/workflows/logging.py index 627e6c2d..2b12649e 100644 --- a/src/workflows/logging.py +++ b/src/workflows/logging.py @@ -4,12 +4,15 @@ import logging import os.path import sys +from typing import Callable -def get_exception_source(): +def get_exception_source() -> tuple[str, str, int, str, str | None]: """Returns full file path, file name, line number, function name, and line contents causing the last exception.""" _, _, tb = sys.exc_info() + if tb is None: + raise RuntimeError("No exception currently being handled") while tb.tb_next: tb = tb.tb_next f = tb.tb_frame @@ -19,7 +22,7 @@ def get_exception_source(): filename = os.path.basename(filefullpath) name = co.co_name linecache.checkcache(filefullpath) - line = linecache.getline(filefullpath, lineno, f.f_globals) + line: str | None = linecache.getline(filefullpath, lineno, f.f_globals) if line: line = line.strip() else: @@ -30,12 +33,12 @@ def get_exception_source(): class CallbackHandler(logging.Handler): """This handler sends logrecords to a callback function.""" - def __init__(self, callback): + def __init__(self, callback: Callable[[logging.LogRecord], None]): """Set up a handler instance, record the callback function.""" super().__init__() self._callback = callback - def prepare(self, record): + def prepare(self, record: logging.LogRecord) -> logging.LogRecord: # Function taken from Python 3.6 QueueHandler """ Prepares a record for queuing. The object returned by this method is @@ -59,7 +62,7 @@ def prepare(self, record): record.exc_info = None return record - def emit(self, record): + def emit(self, record: logging.LogRecord) -> None: """Send a LogRecord to the callback function, after preparing it for serialization.""" try: @@ -67,8 +70,10 @@ def emit(self, record): except Exception: self.handleError(record) - def handleError(self, record): + def handleError(self, record: logging.LogRecord) -> None: t, v, _ = sys.exc_info() + if t is None: + raise RuntimeError("Trying to handle error when no exception active") try: sys.stderr.write( f"--- Logging error --- {t.__name__}: {v}\n" diff --git a/src/workflows/recipe/__init__.py b/src/workflows/recipe/__init__.py index 0f70c8ea..0f7fc84f 100644 --- a/src/workflows/recipe/__init__.py +++ b/src/workflows/recipe/__init__.py @@ -3,14 +3,15 @@ import functools import logging from collections.abc import Callable -from contextlib import ExitStack -from typing import Any +from contextlib import AbstractContextManager, ExitStack +from typing import Any, Literal, overload from opentelemetry import trace from workflows.recipe.recipe import Recipe from workflows.recipe.validate import validate_recipe from workflows.recipe.wrapper import RecipeWrapper +from workflows.transport.common_transport import CommonTransport __all__ = [ "Recipe", @@ -24,49 +25,54 @@ def _wrap_subscription( - transport_layer, - subscription_call, - channel, - callback, - *args, - mangle_for_receiving: Callable[[Any], Any] | None = None, - **kwargs, -): + transport_layer: CommonTransport, + subscription_call: Callable[..., int], + channel: str, + callback: Callable[..., Any], + *args: Any, + mangle_for_receiving: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + allow_non_recipe_messages: bool = False, + log_extender: Callable[[str, Any], AbstractContextManager[Any]] | None = None, + **kwargs: Any, +) -> int: """Internal method to create an intercepting function for incoming messages to interpret recipes. This function is then used to subscribe to a channel on the transport layer. - :param transport_layer: Reference to underlying transport object. - :param subscription_call: Reference to the subscribing function of the - transport layer. - :param channel: Channel name to subscribe to. - :param callback: Real function to be called when messages are received. - The callback will pass three arguments, - a RecipeWrapper object (details below), the header as - a dictionary structure, and the message. - - :param allow_non_recipe_messages: Pass on incoming messages that do not - include recipe information. In this case the first - argument to the callback function will be 'None'. - :param log_extender: If the recipe contains useful contextual information - for log messages, such as a unique ID which can be used - to connect all messages originating from the same - recipe, then the information will be passed to this - function, which must be a context manager factory. - :return: Return value of call to subscription_call. - """ - allow_non_recipe_messages = kwargs.pop("allow_non_recipe_messages", False) - log_extender = kwargs.pop("log_extender", None) + Args: + transport_layer: Reference to underlying transport object. + subscription_call: Reference to the subscribing function of the transport layer. + channel: Channel name to subscribe to. + callback: + Real function to be called when messages are received. + The callback will pass three arguments: a RecipeWrapper object, + the header as a dictionary structure, and the message. + allow_non_recipe_messages: + Pass on incoming messages that do not include recipe information. + In this case the first argument to the callback function will be None. + log_extender: + If the recipe contains useful contextual information for log messages, + such as a unique ID which can be used to connect all messages + originating from the same recipe, then the information will be passed + to this function, which must be a context manager factory. + + Returns: + Return value of call to subscription_call. + """ @functools.wraps(callback) - def unwrap_recipe(header, message): - """This is a helper function unpacking incoming messages when they are - in a recipe format. Other messages are passed through unmodified. - :param header: A dictionary of message headers. If the header contains - an entry 'workflows-recipe' then the message is parsed - and the embedded recipe information is passed on in a - RecipeWrapper object to the target function. - :param message: Incoming deserialized message object. + def unwrap_recipe(header: dict[str, Any], message: dict[str, Any]) -> Any: + """Unpack incoming messages when they are in a recipe format. + + Other messages are passed through unmodified. + + Args: + header: + A dictionary of message headers. If the header contains an entry + 'workflows-recipe' then the message is parsed and the embedded + recipe information is passed on in a RecipeWrapper object to the + target function. + message: Incoming deserialized message object. """ if mangle_for_receiving: message = mangle_for_receiving(message) @@ -107,31 +113,64 @@ def unwrap_recipe(header, message): "Unable to process incoming message." ) transport_layer.nack(header) + return None if mangle_for_receiving: kwargs = {**kwargs, "disable_mangling": True} return subscription_call(channel, unwrap_recipe, *args, **kwargs) +@overload +def wrap_subscribe( + transport_layer: CommonTransport, + channel: str, + callback: Callable[[RecipeWrapper, dict, dict], None], + *args: Any, + allow_non_recipe_messages: Literal[False] = False, + mangle_for_receiving: Callable[[Any], Any] | None = None, + log_extender: Callable[[str, Any], AbstractContextManager[Any]] | None = None, + **kwargs: Any, +) -> int: ... + + +@overload def wrap_subscribe( - transport_layer, - channel, - callback, - *args, + transport_layer: CommonTransport, + channel: str, + callback: Callable[[RecipeWrapper | None, dict, dict | bytes], None], + *args: Any, + allow_non_recipe_messages: Literal[True], mangle_for_receiving: Callable[[Any], Any] | None = None, - **kwargs, -): + log_extender: Callable[[str, Any], AbstractContextManager[Any]] | None = None, + **kwargs: Any, +) -> int: ... + + +def wrap_subscribe( + transport_layer: CommonTransport, + channel: str, + callback: Callable[..., Any], + *args: Any, + allow_non_recipe_messages: bool = False, + mangle_for_receiving: Callable[[Any], Any] | None = None, + log_extender: Callable[[str, Any], AbstractContextManager[Any]] | None = None, + **kwargs: Any, +) -> int: """Listen to a queue on the transport layer, similar to the subscribe call in transport/common_transport.py. Intercept all incoming messages and parse for recipe information. See common_transport.subscribe for possible additional keyword arguments. - :param transport_layer: Reference to underlying transport object. - :param channel: Queue name to subscribe to. - :param callback: Function to be called when messages are received. - The callback will pass three arguments, - a RecipeWrapper object (details below), the header as - a dictionary structure, and the message. - :return: A unique subscription ID + + Args: + transport_layer: Reference to underlying transport object. + channel: Queue name to subscribe to. + callback: + Function to be called when messages are received. The callback will + pass three arguments: a RecipeWrapper object, the header as a + dictionary structure, and the message. + + Returns: + A unique subscription ID """ return _wrap_subscription( @@ -141,29 +180,35 @@ def wrap_subscribe( callback, *args, mangle_for_receiving=mangle_for_receiving, + allow_non_recipe_messages=allow_non_recipe_messages, + log_extender=log_extender, **kwargs, ) def wrap_subscribe_broadcast( - transport_layer, - channel, - callback, - *args, + transport_layer: CommonTransport, + channel: str, + callback: Callable[..., Any], + *args: Any, mangle_for_receiving: Callable[[Any], Any] | None = None, - **kwargs, -): + **kwargs: Any, +) -> int: """Listen to a topic on the transport layer, similar to the subscribe_broadcast call in transport/common_transport.py. Intercept all incoming messages and parse for recipe information. See common_transport.subscribe_broadcast for possible arguments. - :param transport_layer: Reference to underlying transport object. - :param channel: Topic name to subscribe to. - :param callback: Function to be called when messages are received. - The callback will pass three arguments, - a RecipeWrapper object (details below), the header as - a dictionary structure, and the message. - :return: A unique subscription ID + + Args: + transport_layer: Reference to underlying transport object. + channel: Topic name to subscribe to. + callback: + Function to be called when messages are received. The callback will + pass three arguments: a RecipeWrapper object, the header as a + dictionary structure, and the message. + + Returns: + A unique subscription ID """ return _wrap_subscription( diff --git a/src/workflows/recipe/recipe.py b/src/workflows/recipe/recipe.py index e157fc3f..74832b99 100644 --- a/src/workflows/recipe/recipe.py +++ b/src/workflows/recipe/recipe.py @@ -3,11 +3,12 @@ import copy import json import string -from typing import Any +from typing import Any, Literal import workflows -basestring = (str, bytes) +type RecipeKey = Literal["start", "error"] | int +type RawRecipe = dict[RecipeKey, Any] class Recipe: @@ -15,81 +16,83 @@ class Recipe: A recipe describes how all involved services are connected together, how data should be passed and how errors should be handled.""" - recipe: dict[Any, Any] = {} + recipe: RawRecipe """The processing recipe is encoded in this dictionary.""" - # TODO: Describe format - def __init__(self, recipe=None): + def __init__[KT: RecipeKey | str](self, recipe: dict[KT, Any] | str | None = None): """Constructor allows passing in a recipe dictionary.""" - if isinstance(recipe, basestring): + if isinstance(recipe, str): self.recipe = self.deserialize(recipe) elif recipe: self.recipe = self._sanitize(recipe) + elif not hasattr(self, "recipe"): + self.recipe = {} - def deserialize(self, string): + def deserialize(self, data: str) -> RawRecipe: """Convert a recipe that has been stored as serialized json string to a data structure.""" - return self._sanitize(json.loads(string)) + return self._sanitize(json.loads(data)) @staticmethod - def _sanitize(recipe): + def _sanitize[KT: RecipeKey | str](recipe: dict[KT, Any]) -> RawRecipe: """Clean up a recipe that may have been stored as serialized json string. Convert any numerical pointers that are stored as strings to integers.""" - recipe = recipe.copy() - for k in list(recipe): - if k not in ("start", "error") and int(k) and k != int(k): - recipe[int(k)] = recipe[k] - del recipe[k] - for k in list(recipe): - if "output" in recipe[k] and not isinstance( - recipe[k]["output"], (list, dict) + sanitized: RawRecipe = {} + for k, v in recipe.items(): + if k == "start" or k == "error": + sanitized[k] = v + elif isinstance(k, int): + sanitized[k] = v + elif isinstance(k, str): + sanitized[int(k)] = v + + for node in sanitized.values(): + if ( + isinstance(node, dict) + and "output" in node + and not isinstance(node["output"], (list, dict)) ): - recipe[k]["output"] = [recipe[k]["output"]] - # dicts should be normalized, too - if "start" in recipe: - recipe["start"] = [tuple(x) for x in recipe["start"]] - return recipe + node["output"] = [node["output"]] - def serialize(self): + if "start" in sanitized: + sanitized["start"] = [tuple(x) for x in sanitized["start"]] + return sanitized + + def serialize(self) -> str: """Write out the current recipe as serialized json string.""" return json.dumps(self.recipe) - def pretty(self): + def pretty(self) -> str: """Write out the current recipe as serialized json string with pretty formatting.""" return json.dumps(self.recipe, indent=2) - def __getitem__(self, item): + def __getitem__(self, item: RecipeKey) -> Any: """Allow direct dictionary access to recipe elements.""" return self.recipe.__getitem__(item) - def __contains__(self, item): + def __contains__(self, item: object) -> bool: """Testing for presence of recipe elements.""" return item in self.recipe - def __eq__(self, other): + def __eq__(self, other: object) -> bool: """Overload equality operator (!=) to allow comparing recipe objects with one another and with their string representations.""" if isinstance(other, Recipe): return self.recipe == other.recipe if isinstance(other, dict): return self.recipe == self._sanitize(other) - return self.recipe == self.deserialize(other) + if isinstance(other, str): + return self.recipe == self.deserialize(other) + return NotImplemented - def __ne__(self, other): - """Overload inequality operator (!=) to allow comparing recipe objects - with one another and with their string representations.""" - result = self.__eq__(other) - if result is NotImplemented: - return result - return not result - - def __hash__(self): - """Recipe objects are mutable and therefore should not be hashable.""" - return None - - def validate(self): - """Check whether the encoded recipe is valid. It must describe a directed - acyclical graph, all connections must be defined, etc.""" + def validate(self) -> None: + """Check whether the encoded recipe is valid. + + It must describe a directed acyclical graph, all connections + must be defined, etc. + + Raises if the recipe is not valid. + """ if not self.recipe: raise workflows.Error("Invalid recipe: No recipe defined") @@ -107,7 +110,7 @@ def validate(self): # Check that 'error' node points to regular nodes only if "error" in self.recipe and isinstance( - self.recipe["error"], (list, tuple, basestring) + self.recipe["error"], (list, tuple, str) ): if "start" in self.recipe["error"]: raise workflows.Error( @@ -129,7 +132,7 @@ def validate(self): # Detect cycles touched_nodes = {"start", "error"} - def flatten_links(struct): + def flatten_links(struct: Any) -> list[int]: """Take an output/error link object, list or dictionary and return flat list of linked nodes.""" if struct is None: return [] @@ -150,7 +153,7 @@ def flatten_links(struct): "Invalid recipe: Invalid link in recipe (%s)" % str(struct) ) - def find_cycles(path): + def find_cycles(path: list[Any]) -> None: """Depth-First-Search helper function to identify cycles.""" if path[-1] not in self.recipe: raise workflows.Error( @@ -186,7 +189,7 @@ def find_cycles(path): 'Invalid recipe: Recipe contains unreferenced node "%s"' % str(node) ) - def apply_parameters(self, parameters): + def apply_parameters(self, parameters: dict[str, Any]) -> None: """Recursively apply dictionary entries in 'parameters' to {item}s in recipe structure, leaving undefined {item}s as they are. A special case is a {$REPLACE:item}, which replaces the string with a copy of the referenced @@ -212,23 +215,23 @@ def apply_parameters(self, parameters): """ class SafeString: - def __init__(self, s): + def __init__(self, s: str): self.string = s - def __repr__(self): + def __repr__(self) -> str: return "{" + self.string + "}" - def __str__(self): + def __str__(self) -> str: return "{" + self.string + "}" - def __getitem__(self, item): + def __getitem__(self, item: str) -> SafeString: return SafeString(self.string + "[" + item + "]") class SafeDict(dict): """A dictionary that returns undefined keys as {keyname}. This can be used to selectively replace variables in datastructures.""" - def __missing__(self, key): + def __missing__(self, key: str) -> SafeString: return SafeString(key) # By default the python formatter class is used to resolve {item} references @@ -239,23 +242,23 @@ def __missing__(self, key): # string. ds_formatter = string.Formatter() - def ds_format_field(value, spec): - ds_format_field.last = value + def ds_format_field(value: Any, spec: str) -> str: + ds_format_field.last = value # type: ignore return "" - ds_formatter.format_field = ds_format_field + ds_formatter.format_field = ds_format_field # type: ignore params = SafeDict(parameters) - def _recursive_apply(item): + def _recursive_apply(item: Any) -> Any: """Helper function to recursively apply replacements.""" - if isinstance(item, basestring): + if isinstance(item, str): if item.startswith("{$REPLACE") and item.endswith("}"): try: ds_formatter.vformat("{" + item[10:-1] + "}", (), parameters) except KeyError: return None - return copy.deepcopy(ds_formatter.format_field.last) + return copy.deepcopy(ds_formatter.format_field.last) # type: ignore else: return formatter.vformat(item, (), params) if isinstance(item, dict): @@ -271,13 +274,16 @@ def _recursive_apply(item): self.recipe = _recursive_apply(self.recipe) - def merge(self, other): + def merge(self, other: Recipe | str) -> Recipe: """Merge two recipes together, returning a single recipe containing all nodes. Note: This does NOT yet return a minimal recipe. - :param other: A Recipe object that should be merged with the current - Recipe object. - :return: A new Recipe object containing information from both recipes. + + Args: + other: A Recipe object that should be merged with the current Recipe object. + + Returns: + A new Recipe object containing information from both recipes. """ # Merging empty values returns a copy of the original @@ -285,7 +291,7 @@ def merge(self, other): return Recipe(self.recipe) # When a string is passed, merge with a constructed recipe object - if isinstance(other, basestring): + if isinstance(other, str): return self.merge(Recipe(other)) # Merging empty recipes returns a copy of the original @@ -304,7 +310,7 @@ def merge(self, other): new_recipe = self.recipe # Find the maximum index of the current recipe - max_index = max(1, *filter(lambda x: isinstance(x, int), self.recipe.keys())) + max_index = max(1, *(k for k in self.recipe if isinstance(k, int))) next_index = max_index + 1 # Set up a translation table for indices and copy all entries @@ -317,7 +323,7 @@ def merge(self, other): new_recipe[translation[key]] = value # Rewrite all copied entries to point to new keys - def translate(x): + def translate(x: Any) -> Any: if isinstance(x, list): return list(map(translate, x)) elif isinstance(x, tuple): @@ -351,18 +357,4 @@ def translate(x): else: new_recipe["error"].append(translate(other.recipe["error"])) - # # Minimize DAG - # queuehash, topichash = {}, {} - # for k, v in new_recipe.items(): - # if isinstance(v, dict): - # if 'queue' in v: - # queuehash[v['queue']] = queuehash.get(v['queue'], []) - # queuehash[v['queue']].append(k) - # if 'topic' in v: - # topichash[v['topic']] = topichash.get(v['topic'], []) - # topichash[v['topic']].append(k) - # - # print queuehash - # print topichash - return Recipe(new_recipe) diff --git a/src/workflows/recipe/validate.py b/src/workflows/recipe/validate.py index 7bb72063..6a6e6d83 100644 --- a/src/workflows/recipe/validate.py +++ b/src/workflows/recipe/validate.py @@ -20,13 +20,14 @@ import argparse import json import logging +import os import sys import workflows import workflows.recipe -def validate_recipe(json_filename): +def validate_recipe(json_filename: str | os.PathLike[str]) -> None: """Reads a json file, tries to turn it into a recipe and then validates it. Exits on exception with non-zero error""" @@ -53,7 +54,7 @@ def validate_recipe(json_filename): raise e -def main(): +def main() -> None: """Run the program from entry point""" parser = argparse.ArgumentParser() parser.add_argument( diff --git a/src/workflows/recipe/wrapper.py b/src/workflows/recipe/wrapper.py index c15de95d..bc27ac2b 100644 --- a/src/workflows/recipe/wrapper.py +++ b/src/workflows/recipe/wrapper.py @@ -3,24 +3,104 @@ import logging import time from collections.abc import Callable -from typing import Any +from typing import Any, overload import workflows.recipe +from workflows.recipe.recipe import Recipe +from workflows.transport.common_transport import CommonTransport logger = logging.getLogger("workflows.recipe.wrapper") class RecipeWrapper: - """A wrapper object which contains a recipe and a number of functions to make - life easier for recipe users. """ + Represent a "Live" recipe, including the current state and history. + + Makes it possible to keep track of the current state of the recipe + and any parameters passed between services. If provided with a + transport class, then the transport convenience methods can be used + to manage all the bookkeeping to send onward messages/messages + between services. + + Services normally do not construct wrappers directly. Instead, the + :func:`~workflows.recipe.wrap_subscribe` and + :func:`~workflows.recipe.wrap_subscribe_broadcast` helpers intercept + incoming messages, build a wrapper positioned at the relevant recipe + step, and pass it to the service's callback. The callback then uses + :meth:`send` or :meth:`send_to` to dispatch results onward without + needing to know the names of the next services in the chain. + + Wrappers can also be constructed from a bare recipe (rather than a + received message) to kick off a new recipe via :meth:`start`. + + Attributes: + recipe: The underlying :class:`Recipe` object. + recipe_pointer: Index of the current step within ``recipe``, or + ``None`` if the wrapper was built from a recipe that has not + yet been started. + recipe_step: The recipe node at ``recipe_pointer``, or ``None`` + before the recipe has started. Inspect this to read output + channel definitions for the current step. + recipe_path: History of nodes passed through to reach this node. + environment: Dictionary of contextual information carried with + the recipe (e.g. a recipe ``ID`` used for log correlation). + Propagated unchanged to every downstream message. + payload: The message payload delivered to this step, or ``None`` + when the wrapper was built from a bare recipe. + default_channel: Named output channel treated as the target for + :meth:`send` when the current step has named outputs. Set via + :meth:`set_default_channel`. + transport: The transport layer used to dispatch messages. + Accessing this attribute raises :class:`RuntimeError` if no + transport was supplied at construction time. + """ + + recipe_pointer: int | None + default_channel: str | None + @overload def __init__( - self, message=None, transport=None, recipe=None, environment=None, **kwargs + self, + *, + transport: CommonTransport | None = None, + environment: dict[str, Any] | None = None, + message: dict[str, Any], + ) -> None: ... + + @overload + def __init__( + self, + *, + transport: CommonTransport | None = None, + environment: dict[str, Any] | None = None, + recipe: Recipe | dict[str, Any], + ) -> None: ... + + def __init__( + self, + *, + transport: CommonTransport | None = None, + environment: dict[str, Any] | None = None, + message: dict[str, Any] | None = None, + recipe: Recipe | dict[str, Any] | None = None, + **kwargs, ): - """Create a RecipeWrapper object from a wrapped message. - References to the transport layer are required to send directly to - connected downstream processes. + """ + Create a RecipeWrapper object from a wrapped message. + + transport: + References to the transport layer, required to make use of + methods that send directly to downstream processes, such as + :meth:`send`, :meth:`start`, and :meth:`checkpoint`. If not + provided, then the RecipeWrapper can only be inspected. + environment: + Optional environment dictionary propagated to all downstream + messages. Used for Recipe-workflow global information, such + as `ID`. + recipe: + A :class:`Recipe` instance, or a raw recipe dictionary that + will be validated and wrapped in one. Used for construction + of a wrapper that has not yet been started. """ if message: self.recipe = workflows.recipe.Recipe(message["recipe"]) @@ -47,21 +127,64 @@ def __init__( "A message or recipe is required to create a RecipeWrapper object." ) self.default_channel = None - self.transport = transport - - def send(self, *args, **kwargs): - """Send messages to another service that is connected to the currently - running service via the recipe. The 'send' method will either use a - default channel name, set via the set_default_channel method, or an - unnamed output definition. - """ + self._transport = transport - if not self.transport: - raise ValueError( - "This RecipeWrapper object does not contain " - "a reference to a transport object." + @property + def transport(self) -> CommonTransport: + if self._transport is None: + raise RuntimeError( + "This RecipeWrapper object does not contain a reference to a transport object." ) + return self._transport + def send( + self, + message: Any, + header: dict[str, Any] | None = None, + *, + mangle_for_sending: Callable[[Any], Any] | None = None, + **kwargs: Any, + ) -> None: + """ + Send a message to the current step's downstream services. + + This is another service that is connected to the currently + running service via the recipe, specified by the "output" field + of the current step. There are three behaviours: + + - If ``output`` is unset or the step has no ``output`` entry, + then the message will be discarded. + - If ``output`` is a single integer or an array of integers + naming destination nodes in the recipe, then a copy of the + message will be delivered to each one of them. + - If ``output`` is a dict of named target nodes; If a default + channel name has been set with :meth:`set_default_channel` + then the message will only be sent to any destinations listed + there. + + For general sending to named output channels, use the + :meth:`send_to` method instead. + + Args: + message: + The payload to deliver. Wrapped in the standard recipe + envelope (recipe, pointer, path, environment) before + handing off to the transport. + header: + Optional dictionary of transport headers. This is merged + with any set by the transport (e.g. the ``workflows-recipe`` + header flag is automatically added). + mangle_for_sending: + Optional callable applied to the fully formed message + right before it is handed to the transport. If specified + here, then the default transport serialization is not + applied. The default mangling function for transports is + usually "encode to JSON". + **kwargs: + Any additional keyword arguments forwarded to the + transport's ``send``/``broadcast`` calls, in addition + to any declared in the recipe. + """ if not self.recipe_step: raise ValueError( "This RecipeWrapper object does not contain " @@ -72,26 +195,71 @@ def send(self, *args, **kwargs): # The current recipe step does not have output channels. return + # Merge down the annotated function call to bare kwargs + kwargs["message"] = message + if header: + kwargs["header"] = header + if mangle_for_sending: + kwargs["mangle_for_sending"] = mangle_for_sending + if isinstance(self.recipe_step["output"], dict): # The current recipe step does have named output channels. if self.default_channel: # Use named output channel - self.send_to(self.default_channel, *args, **kwargs) + self.send_to(self.default_channel, **kwargs) else: # The current recipe step does have unnamed output channels. - self._send_to_destinations(self.recipe_step["output"], *args, **kwargs) + self._send_to_destinations(self.recipe_step["output"], **kwargs) - def send_to(self, channel, *args, **kwargs): - """Send messages to another service that is connected to the currently - running service via the recipe. Discard messages if the recipe does - not have anything connected to the specified output channel. + def send_to( + self, + channel: str, + message: Any, + header: dict[str, Any] | None = None, + *, + mangle_for_sending: Callable[[Any], Any] | None = None, + **kwargs: Any, + ) -> None: + """ + Send a message to a service named by this steps' output channel. + + This is another service that is connected to the currently + running service via the recipe, specified by the "output" field + of the current step. There are three behaviours: + + - If ``output`` is unset or the step has no ``output`` entry, + then the message will be discarded. + - If ``output`` is *NOT* a dictionary, then the message is sent + to all services specified *ONLY IF* the passed channel name is + the same as the default channel name. + - The channel name is looked up in the ``outputs`` dictionary + and a copy of the message is sent to every destination + specified there. + + Args: + channel: + The name of the output channel, corresponding to an + entry in the ``outputs`` dictionary. + message: + The payload to deliver. Wrapped in the standard recipe + envelope (recipe, pointer, path, environment) before + handing off to the transport. + header: + Optional dictionary of transport headers. This is merged + with any set by the transport (e.g. the ``workflows-recipe`` + header flag is automatically added). + mangle_for_sending: + Optional callable applied to the fully formed message + right before it is handed to the transport. If specified + here, then the default transport serialization is not + applied. The default mangling function for transports is + usually "encode to JSON". + **kwargs: + Any additional keyword arguments forwarded to the + transport's ``send``/``broadcast`` calls, in addition + to any declared in the recipe. """ - if not self.transport: - raise ValueError( - "This RecipeWrapper object does not contain " - "a reference to a transport object." - ) if not self.recipe_step: raise ValueError( @@ -103,55 +271,111 @@ def send_to(self, channel, *args, **kwargs): # The current recipe step does not have output channels. return + # Merge down the annotated function call to bare kwargs + kwargs["message"] = message + if header: + kwargs["header"] = header + if mangle_for_sending: + kwargs["mangle_for_sending"] = mangle_for_sending + if not isinstance(self.recipe_step["output"], dict): # The current recipe step does not have named output channels. if self.default_channel == channel: # Use unnamed output channels - self.send(*args, **kwargs) + self.send(**kwargs) return if channel not in self.recipe_step["output"]: # The current recipe step does not have an output channel with this name. return - self._send_to_destinations(self.recipe_step["output"][channel], *args, **kwargs) + self._send_to_destinations(self.recipe_step["output"][channel], **kwargs) - def set_default_channel(self, channel): + def set_default_channel(self, channel: str) -> None: """Define one named output channel to be equivalent to unnamed output channels. For this channel send() and send_to() will be identical.""" self.default_channel = channel - def start(self, header=None, **kwargs): - """Trigger the start of a recipe, sending the defined payloads to the - recipients set in the recipe. Any parameters to this function are - passed to the transport send/broadcast methods. - If the wrapped recipe has already been started then a ValueError will - be raised. + def start( + self, + header: dict[str, Any] | None = None, + *, + mangle_for_sending: Callable[[Any], Any] | None = None, + **kwargs: Any, + ) -> None: + """ + Trigger the start of a recipe. + + Dispatches the payloads defined in the recipe's ``start`` node + to their respective recipients via the transport layer. Only + valid on a wrapper constructed from a bare recipe; calling + :meth:`start` on a wrapper that already has a selected step + raises :class:`ValueError`. + + Args: + header: + Optional dictionary of transport headers. This is merged + with any set by the transport (e.g. the ``workflows-recipe`` + header flag is automatically added). + mangle_for_sending: + Optional callable applied to the fully formed message + right before it is handed to the transport. If specified + here, then the default transport serialization is not + applied. The default mangling function for transports is + usually "encode to JSON". + **kwargs: Keywords passed on to the transport. + + Raises: + ValueError: If the wrapped recipe has already been started. """ - if not self.transport: - raise ValueError( - "This RecipeWrapper object does not contain " - "a reference to a transport object." - ) if self.recipe_step: raise ValueError("This recipe has already been started.") for destination, payload in self.recipe["start"]: - self._send_to_destination(destination, header, payload, kwargs) - - def checkpoint(self, message, header=None, delay=0, **kwargs): - """Send a message to the current recipe destination. This can be used to - keep a state for longer processing tasks. - :param delay: Delay transport of message by this many seconds - """ - if not self.transport: - raise ValueError( - "This RecipeWrapper object does not contain " - "a reference to a transport object." + self._send_to_destination( + destination, + header=header, + payload=payload, + transport_kwargs=kwargs, + mangle_for_sending=mangle_for_sending, ) - if not self.recipe_step: + def checkpoint( + self, + message: Any, + header: dict[str, Any] | None = None, + delay: float = 0.0, + *, + mangle_for_sending: Callable[[Any], Any] | None = None, + **kwargs: Any, + ) -> None: + """ + Send a message back to "yourself" e.g. the current recipe destination. + + This can be used to store state in the rabbitmq queue for + processing of longer-term tasks, without making individual + services stateful. + + Args: + message: + The payload to deliver. Wrapped in the standard recipe + envelope (recipe, pointer, path, environment) before + handing off to the transport. + header: + Optional dictionary of transport headers. This is merged + with any set by the transport (e.g. the ``workflows-recipe`` + header flag is automatically added). + delay: Time, in seconds, to delay delivery of this message. + mangle_for_sending: + Optional callable applied to the fully formed message + right before it is handed to the transport. If specified + here, then the default transport serialization is not + applied. The default mangling function for transports is + usually "encode to JSON". + **kwargs: Keywords passed on to the transport. + """ + if not self.recipe_step or self.recipe_pointer is None: raise ValueError( "This RecipeWrapper object does not contain " "a recipe with a selected step." @@ -160,20 +384,43 @@ def checkpoint(self, message, header=None, delay=0, **kwargs): kwargs["delay"] = delay self._send_to_destination( - self.recipe_pointer, header, message, kwargs, add_path_step=False + self.recipe_pointer, + header, + message, + kwargs, + add_path_step=False, + mangle_for_sending=mangle_for_sending, ) - def apply_parameters(self, parameters): - """Recursively apply parameter replacement (see recipe.py) to the wrapped - recipe, updating internal references afterwards. - While this operation is useful for testing it should not be used in - production. Replacing parameters means that the recipe changes as it is - passed down the chain of services. This makes debugging very difficult. + def apply_parameters(self, parameters: dict[str, Any]) -> None: + """ + Recursively substitute parameter values into the wrapped recipe. + + Delegates to :meth:`Recipe.apply_parameters` and then refreshes + :attr:`recipe_step` so subsequent :meth:`send` / :meth:`send_to` + calls observe the substituted outputs and queue names. + + Note: + This is primarily useful in tests. Mutating the recipe as it + is passed down the chain of services means each hop sees a + different recipe, which makes failures very difficult to + diagnose; prefer baking parameters in before dispatch in + production. + + Args: + parameters: Mapping of parameter names to replacement values. + Keys are referenced from the recipe as ``{name}``, or as + ``{$REPLACE:name}`` to substitute an entire data + structure in place of the string. See + :meth:`Recipe.apply_parameters` for the full grammar. """ self.recipe.apply_parameters(parameters) + assert self.recipe_pointer is not None self.recipe_step = self.recipe[self.recipe_pointer] - def _generate_full_recipe_message(self, destination, message, add_path_step): + def _generate_full_recipe_message( + self, destination: int, message: Any, add_path_step: bool + ) -> dict[str, Any]: """Factory function to generate independent message objects for downstream recipients with different destinations.""" if add_path_step and self.recipe_pointer: @@ -191,17 +438,17 @@ def _generate_full_recipe_message(self, destination, message, add_path_step): def _send_to_destinations( self, - destinations, - message, - header=None, + destinations: int | list[int], + message: Any, + header: dict[str, Any] | None = None, mangle_for_sending: Callable[[Any], Any] | None = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: """Send messages to a list of numbered destinations. This is an internal helper method used by the public 'send' methods. """ if not isinstance(destinations, list): - destinations = (destinations,) + destinations = [destinations] for destination in destinations: self._send_to_destination( destination, @@ -213,13 +460,13 @@ def _send_to_destinations( def _send_to_destination( self, - destination, - header, - payload, - transport_kwargs, - add_path_step=True, + destination: int, + header: dict[str, Any] | None, + payload: Any, + transport_kwargs: dict[str, Any], + add_path_step: bool = True, mangle_for_sending: Callable[[Any], Any] | None = None, - ): + ) -> None: """Helper function to send a message to a specific recipe destination.""" if header: header = header.copy() @@ -271,7 +518,9 @@ def _send_to_destination( **dest_kwargs, ) - def _retry_transport(self, function, *args, **kwargs): + def _retry_transport( + self, function: Callable[..., Any], *args: Any, **kwargs: Any + ) -> Any: """Attempt to send a message, and in case the connection has been lost, attempt to reconnect. Reconnecting only works on the assumption that the previous connection did not include any subscriptions, which should diff --git a/src/workflows/services/__init__.py b/src/workflows/services/__init__.py index 5a6a1d45..0e356f76 100644 --- a/src/workflows/services/__init__.py +++ b/src/workflows/services/__init__.py @@ -1,12 +1,21 @@ from __future__ import annotations +from collections.abc import Callable from importlib.metadata import entry_points +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from workflows.services.common_service import CommonService -def lookup(service: str): + +def lookup(service: str) -> type[CommonService] | None: """Find a service class based on a name. - :param service: Name of the service - :return: A service class + + Args: + service: Name of the service + + Returns: + A service class """ service_factory = get_known_services().get(service) if service_factory: @@ -15,11 +24,13 @@ def lookup(service: str): return None -def get_known_services(): +def get_known_services() -> dict[str, Callable[[], type[CommonService]]]: """Return a dictionary of all known services. - :return: A dictionary containing entries { service name : service class factory } - A factory is a function that takes no arguments and returns an - uninstantiated service class. + + Returns: + A dictionary containing entries { service name : service class factory } + where a factory is a function that takes no arguments and returns an + uninstantiated service class. """ if not hasattr(get_known_services, "cache"): setattr( @@ -27,5 +38,5 @@ def get_known_services(): "cache", {e.name: e.load for e in entry_points(group="workflows.services")}, ) - register = get_known_services.cache.copy() + register = get_known_services.cache.copy() # type: ignore return register diff --git a/src/workflows/services/common_service.py b/src/workflows/services/common_service.py index f6e34040..4d6b9573 100644 --- a/src/workflows/services/common_service.py +++ b/src/workflows/services/common_service.py @@ -4,9 +4,11 @@ import enum import itertools import logging +import multiprocessing.connection import queue import threading import time +from collections.abc import Callable, Generator, Mapping from typing import Any from opentelemetry import trace @@ -17,6 +19,7 @@ import workflows import workflows.logging +from workflows.transport.common_transport import CommonTransport, MessageCallback from workflows.transport.middleware.otel_tracing import OTELTracingMiddleware @@ -56,7 +59,7 @@ class Status(enum.Enum): NONE = (-1, "no service loaded") # Node has no service instance loaded TEARDOWN = (-2, "shutdown") # Node is shutting down - def __init__(self, intval, description): + def __init__(self, intval: int, description: str): """ Each status is defined as a tuple of a unique integer value and a descriptive string. These are available via enum properties @@ -97,18 +100,19 @@ class CommonService: # Logger name --------------------------------------------------------------- - _logger_name = "workflows.service" # The logger can be accessed via self.log + #: The logger can be accessed via self.log + _logger_name = "workflows.service" # Overrideable functions ---------------------------------------------------- - def initializing(self): + def initializing(self) -> None: """Service initialization. This function is run before any commands are received from the frontend. This is the place to request channel subscriptions with the messaging layer, and register callbacks. This function can be overridden by specific service implementations.""" pass - def in_shutdown(self): + def in_shutdown(self) -> None: """Service shutdown. This function is run before the service is terminated. No more commands are received, but communications can still be sent. This function can be overridden by specific service implementations.""" @@ -135,54 +139,65 @@ def in_shutdown(self): # Any keyword arguments set on service invocation - start_kwargs: dict[Any, Any] = {} + start_kwargs: dict[str, Any] + _transport_interceptor_counter: itertools.count[int] # Not so overrideable functions --------------------------------------------- - def __init__(self, *args, **kwargs): - """Service constructor. Parameters include optional references to two - pipes: frontend= for messages from the service to the frontend, - and commands= for messages from the frontend to the service. - A dictionary can optionally be passed with environment=, which is then - available to the service during runtime.""" - self.__pipe_frontend = None - self.__pipe_commands = None - self._environment = kwargs.get("environment", {}) - self._transport = None - self.__callback_register = {} - self.__log_extensions = [] - self.__service_status = None - self.__shutdown = False - self.__update_service_status(self.SERVICE_STATUS_NEW) - self.__queue = queue.PriorityQueue() - self._idle_callback = None - self._idle_time = None + def __init__(self, *, environment: dict[str, Any] | None = None): + """ + Service constructor. + + Args: + environment: + Optional dictionary made available to the service at runtime + via ``self._environment``. Typically used by the frontend to + pass configuration (e.g. ``config``, ``metrics``, ``liveness``) + into the spawned service process. + """ + self.__pipe_frontend: multiprocessing.connection.Connection | None = None + self.__pipe_commands: multiprocessing.connection.Connection | None = None + self._environment = environment if environment is not None else {} + self._transport: CommonTransport | None = None + self.__callback_register: dict[str, Callable[[Any], None]] = {} + self.__log_extensions: list[tuple[str, Any]] = [] + self.__service_status: int = self.SERVICE_STATUS_NEW + self.__shutdown: bool = False + self.__queue: queue.PriorityQueue[tuple[Priority, int, Any]] = ( + queue.PriorityQueue() + ) + self._idle_callback: Callable[[], None] | None = None + self._idle_time: float | None = None + self.start_kwargs = {} # Logger will be overwritten in start() function self.log = logging.getLogger(self._logger_name) - def __send_to_frontend(self, data_structure): + def __send_to_frontend(self, data_structure: Any) -> None: """Put a message in the pipe for the frontend.""" if self.__pipe_frontend: self.__pipe_frontend.send(data_structure) @property - def config(self): + def config(self) -> Any: return self._environment.get("config") @property - def transport(self): + def transport(self) -> CommonTransport: + # Handle the fact that we apparently allow missing transport layers + if self._transport is None: + raise RuntimeError("Transport layer has not yet been defined") return self._transport @transport.setter - def transport(self, value): + def transport(self, value: CommonTransport) -> None: if self._transport: raise AttributeError("Transport already defined") self._transport = value - def start_transport(self): - """If a transport object has been defined then connect it now.""" - if self.transport: + def start_transport(self) -> None: + """If a transport object has been defined, then connect it.""" + if self._transport: if self.transport.connect(): self.log.debug("Service successfully connected to transport layer") else: @@ -220,7 +235,7 @@ def start_transport(self): otel_middleware = OTELTracingMiddleware( tracer, service_name=self._service_name ) - self._transport.add_middleware(otel_middleware) + self.transport.add_middleware(otel_middleware) metrics = self._environment.get("metrics") if metrics: @@ -233,24 +248,24 @@ def start_transport(self): self.log.debug("Instrumenting transport") source = f"{self.__module__}:{self.__class__.__name__}" instrument = PrometheusMiddleware(source=source) - self._transport.add_middleware(instrument) + self.transport.add_middleware(instrument) port = metrics["port"] self.log.debug(f"Starting metrics endpoint on port {port}") prometheus_client.start_http_server(port=port) else: self.log.debug("No transport layer defined for service. Skipping.") - def stop_transport(self): + def stop_transport(self) -> None: """If a transport object has been defined then tear it down.""" - if self.transport: + if self._transport: self.log.debug("Stopping transport object") self.transport.disconnect() - def _transport_interceptor(self, callback): + def _transport_interceptor(self, callback: MessageCallback) -> MessageCallback: """Takes a callback function and returns a function that takes headers and messages and places them on the main service queue.""" - def add_item_to_queue(header, message): + def add_item_to_queue(header: Mapping[str, Any], message: Any) -> None: queue_item = ( Priority.TRANSPORT, next( @@ -264,22 +279,52 @@ def add_item_to_queue(header, message): return add_item_to_queue - def connect(self, frontend=None, commands=None): - """Inject pipes connecting the service to the frontend. Two arguments are - supported: frontend= for messages from the service to the frontend, - and commands= for messages from the frontend to the service. - The injection should happen before the service is started, otherwise the - underlying file descriptor references may not be handled correctly.""" - if frontend: + def connect( + self, + frontend: multiprocessing.connection.Connection | None = None, + commands: multiprocessing.connection.Connection | None = None, + ) -> None: + """Inject the pipes connecting this service to the frontend. + + Injection should happen before :meth:`start` is called, otherwise + the underlying file descriptor references may not be handled + correctly across the process boundary. + + Args: + frontend: Write end of the pipe used to send messages from the + service to the frontend (status updates, log records, etc.). + Setting this also triggers an immediate status broadcast. + commands: Read end of the pipe used to receive command messages + from the frontend. If left as ``None`` the service has no + way to receive commands and will shut itself down shortly + after :meth:`start`. + """ + if frontend is not None: self.__pipe_frontend = frontend self.__send_service_status_to_frontend() - if commands: + if commands is not None: self.__pipe_commands = commands @contextlib.contextmanager - def extend_log(self, field, value): - """A context wherein a specified extra field in log messages is populated - with a fixed value. This affects all log messages within the context.""" + def extend_log(self, field: str, value: Any) -> Generator[None, None, None]: + """Annotate log records emitted within the context with an extra field. + + The ``(field, value)`` pair is attached to every log record produced + while the context is active, and removed on exit. If an exception + propagates out of the block, the field is also stashed on the + exception as ``workflows_log_`` so downstream handlers + (notably :meth:`process_uncaught_exception`) can surface it. + + Args: + field: Name of the extra field to attach to log records. Must be + a valid Python identifier suffix, as it is also used to + build the attribute name on any escaping exception. + value: Value to associate with ``field``. Anything that the + log handler can serialize is acceptable. + + Yields: + Control to the wrapped block. No value is yielded. + """ self.__log_extensions.append((field, value)) try: yield @@ -289,7 +334,7 @@ def extend_log(self, field, value): finally: self.__log_extensions.remove((field, value)) - def __command_queue_listener(self): + def __command_queue_listener(self) -> None: """Function to continuously retrieve data from the frontend. Commands are sent to the central priority queue. If the pipe from the frontend is closed the service shutdown is initiated. Check every second if service @@ -297,6 +342,9 @@ def __command_queue_listener(self): This function is run by a separate daemon thread, which is started by the __start_command_queue_listener function. """ + assert self.__pipe_commands is not None, ( + "Listener started without command queue connection" + ) self.log.debug("Queue listener thread started") counter = itertools.count() # insertion sequence to keep messages in order while not self.__shutdown: @@ -325,14 +373,14 @@ def __command_queue_listener(self): time.sleep(0.05) self.log.debug("Queue listener thread terminating") - def __start_command_queue_listener(self): + def __start_command_queue_listener(self) -> None: """Start the function __command_queue_listener in a separate thread. This function continuously listens to the pipe connected to the frontend. """ thread_function = self.__command_queue_listener class QueueListenerThread(threading.Thread): - def run(qltself): + def run(self) -> None: thread_function() assert not hasattr(self, "__queue_listener_thread") @@ -342,52 +390,52 @@ def run(qltself): self.__queue_listener_thread.name = "Command Queue Listener" self.__queue_listener_thread.start() - def _log_send(self, logrecord): + def _log_send(self, logrecord: logging.LogRecord) -> None: """Forward log records to the frontend.""" for field, value in self.__log_extensions: setattr(logrecord, field, value) self.__send_to_frontend({"band": "log", "payload": logrecord}) - def _register(self, message_band, callback): + def _register(self, message_band: str, callback: Callable[[Any], None]) -> None: """Register a callback function for a specific message band.""" self.__callback_register[message_band] = callback - def _register_idle(self, idle_time, callback): + def _register_idle(self, idle_time: float, callback: Callable[[], None]) -> None: """Register a callback function that is run when idling for a given time span (in seconds).""" self._idle_callback = callback self._idle_time = idle_time - def __update_service_status(self, statuscode): + def __update_service_status(self, statuscode: int) -> None: """Set the internal status of the service object, and notify frontend.""" if self.__service_status != statuscode: self.__service_status = statuscode self.__send_service_status_to_frontend() - def __send_service_status_to_frontend(self): + def __send_service_status_to_frontend(self) -> None: """Actually send the internal status of the service object to the frontend.""" self.__send_to_frontend( {"band": "status_update", "statuscode": self.__service_status} ) - def get_name(self): + def get_name(self) -> str: """Get the name for this service.""" return self._service_name - def _set_name(self, name): + def _set_name(self, name: str) -> None: """Set a new name for this service, and notify the frontend accordingly.""" self._service_name = name self.__send_to_frontend({"band": "set_name", "name": self._service_name}) - def _request_termination(self): + def _request_termination(self) -> None: """Terminate the service from the frontend side""" self.__send_to_frontend({"band": "request_termination"}) - def _shutdown(self): + def _shutdown(self) -> None: """Terminate the service from the service side.""" self.__shutdown = True - def initialize_logging(self): + def initialize_logging(self) -> None: """Reset the logging for the service process. All logged messages are forwarded to the frontend. If any filtering is desired, then this must take place on the service side.""" @@ -419,14 +467,28 @@ def initialize_logging(self): console.setLevel(logging.CRITICAL) root_logger.addHandler(console) - def start(self, **kwargs): - """Start listening to command queue, process commands in main loop, - set status, etc... - This function is most likely called by the frontend in a separate - process.""" - + def start(self, *, verbose_log: bool = False, **kwargs: Any) -> None: + """Run the service main loop until shutdown. + + This is the entry point invoked by the frontend in the spawned service + process. It sets up logging and transport, calls :meth:`initializing`, + then enters the main loop, dispatching command-band and transport-band + messages off the internal priority queue and emitting status updates as + the service state changes. On shutdown - ``clean``, or via an unhandled + exception - :meth:`in_shutdown`, is invoked and the transport is torn + down. + + Args: + verbose_log: + If set, initialises the service logger level to ``DEBUG``. + **kwargs: + Other arbitrary keyword arguments, forwarded by the frontend. + Stored on :attr:`start_kwargs` for use by subclasses. + """ # Keep a copy of keyword arguments for use in subclasses self.start_kwargs.update(kwargs) + if verbose_log: + self.start_kwargs["verbose_log"] = verbose_log try: self.initialize_logging() @@ -449,14 +511,14 @@ def start(self, **kwargs): try: task = self.__queue.get(True, self._idle_time or 2) - run_idle_task = False except queue.Empty: - run_idle_task = True + task = None - if self.transport and not self.transport.is_connected(): + if self._transport and not self.transport.is_connected(): raise workflows.Disconnected("Connection lost") - if run_idle_task: + if task is None: + # Run the idle task if self._idle_time: # run this outside the 'except' to avoid exception chaining self.__update_service_status(self.SERVICE_STATUS_TIMER) @@ -505,7 +567,7 @@ def start(self, **kwargs): self.process_uncaught_exception(e) self.__update_service_status(self.SERVICE_STATUS_ERROR) - def process_uncaught_exception(self, e): + def process_uncaught_exception(self, e: BaseException) -> None: """This is called to handle otherwise uncaught exceptions from the service. The service will terminate either way, but here we can do things such as gathering useful environment information and logging for posterity.""" @@ -532,7 +594,7 @@ def process_uncaught_exception(self, e): "Unhandled service exception: %s", e, exc_info=True, extra=added_information ) - def __process_command(self, command): + def __process_command(self, command: str) -> None: """Process an incoming command message from the frontend.""" if command == Commands.SHUTDOWN: self.__shutdown = True diff --git a/src/workflows/services/sample_consumer.py b/src/workflows/services/sample_consumer.py index 61087b36..e67dabb8 100644 --- a/src/workflows/services/sample_consumer.py +++ b/src/workflows/services/sample_consumer.py @@ -2,6 +2,8 @@ import json import time +from collections.abc import Mapping +from typing import Any from workflows.services.common_service import CommonService @@ -17,11 +19,11 @@ class SampleConsumer(CommonService): # Logger name _logger_name = "workflows.service.sample_consumer" - def initializing(self): + def initializing(self) -> None: """Subscribe to a channel.""" - self._transport.subscribe("transient.destination", self.consume_message) + self.transport.subscribe("transient.destination", self.consume_message) - def consume_message(self, header, message): + def consume_message(self, header: Mapping[str, Any], message: Any) -> None: """Consume a message""" t = (time.time() % 1000) * 1000 diff --git a/src/workflows/services/sample_pipethrough.py b/src/workflows/services/sample_pipethrough.py index 51488abe..e9975174 100644 --- a/src/workflows/services/sample_pipethrough.py +++ b/src/workflows/services/sample_pipethrough.py @@ -2,8 +2,10 @@ import json import time +from typing import Any import workflows.recipe +from workflows.recipe.wrapper import RecipeWrapper from workflows.services.common_service import CommonService @@ -18,15 +20,15 @@ class SamplePipethrough(CommonService): # Logger name _logger_name = "workflows.service.sample_pipethrough" - def initializing(self): + def initializing(self) -> None: """Subscribe to a channel.""" workflows.recipe.wrap_subscribe( - self._transport, + self.transport, "transient.destination", self.process, ) - def process(self, rw, header, message): + def process(self, rw: RecipeWrapper, header: dict, message: Any) -> None: """Consume message and send to output pipe.""" t = (time.time() % 1000) * 1000 diff --git a/src/workflows/services/sample_producer.py b/src/workflows/services/sample_producer.py index 221ef2b5..ff8d358d 100644 --- a/src/workflows/services/sample_producer.py +++ b/src/workflows/services/sample_producer.py @@ -18,7 +18,7 @@ class SampleProducer(CommonService): counter = 0 - def initializing(self): + def initializing(self) -> None: """Service initialization. This function is run before any commands are received from the frontend. This is the place to request channel subscriptions with the messaging layer, and register callbacks. @@ -26,11 +26,11 @@ def initializing(self): self.log.info("Starting message producer") self._register_idle(3, self.create_message) - def create_message(self): + def create_message(self) -> None: """Create and send a unique message for this service.""" self.counter += 1 self.log.info("Sending message #%d", self.counter) - self._transport.send( + self.transport.send( "transient.destination", "Message #%d\n++++++++Produced @%10.3f ms" % (self.counter, (time.time() % 1000) * 1000), diff --git a/src/workflows/services/sample_transaction.py b/src/workflows/services/sample_transaction.py index e50af554..c4ab2813 100644 --- a/src/workflows/services/sample_transaction.py +++ b/src/workflows/services/sample_transaction.py @@ -2,6 +2,8 @@ import random import time +from collections.abc import Mapping +from typing import Any from workflows.services.common_service import CommonService @@ -16,9 +18,9 @@ class SampleTxn(CommonService): # Human readable service name _service_name = "Transaction sample" - def initializing(self): + def initializing(self) -> None: """Subscribe to a channel. Received messages must be acknowledged.""" - self.subid = self._transport.subscribe( + self.subid = self.transport.subscribe( "transient.transaction", self.receive_message, acknowledgement=True, @@ -26,12 +28,12 @@ def initializing(self): ) @staticmethod - def crashpoint(): + def crashpoint() -> bool: """Return true if the service should malfunction at this point.""" # Probability of not crashing is 90% return random.uniform(0, 1) > 0.90 - def receive_message(self, header, message): + def receive_message(self, header: Mapping[str, Any], message: Any) -> None: """Receive a message""" self.log.info("=== Receive ===") @@ -41,29 +43,29 @@ def receive_message(self, header, message): self.log.info("MsgID: {}".format(header["message-id"])) assert header["message-id"] - txn = self._transport.transaction_begin() + txn = self.transport.transaction_begin() self.log.info(f" 1. Txn: {txn}") if self.crashpoint(): - self._transport.transaction_abort(txn) + self.transport.transaction_abort(txn) self.log.info("--- Abort ---") return - self._transport.ack(header["message-id"], self.subid, transaction=txn) + self.transport.ack(header["message-id"], self.subid, transaction=txn) self.log.info(" 2. Ack") if self.crashpoint(): - self._transport.transaction_abort(txn) + self.transport.transaction_abort(txn) self.log.info("--- Abort ---") return - self._transport.send("transient.destination", message, transaction=txn) + self.transport.send("transient.destination", message, transaction=txn) self.log.info(" 3. Send") if self.crashpoint(): - self._transport.transaction_abort(txn) + self.transport.transaction_abort(txn) self.log.info("--- Abort ---") return - self._transport.transaction_commit(txn) + self.transport.transaction_commit(txn) self.log.info(" 4. Commit") self.log.info("=== Done ===") @@ -78,17 +80,17 @@ class SampleTxnProducer(CommonService): counter = 0 - def initializing(self): + def initializing(self) -> None: """Service initialization. This function is run before any commands are received from the frontend. This is the place to request channel subscriptions with the messaging layer, and register callbacks. This function can be overridden by specific service implementations.""" self._register_idle(3, self.create_message) - def create_message(self): + def create_message(self) -> None: """Create and send a unique message for this service.""" self.counter += 1 - self._transport.send( + self.transport.send( "transient.transaction", "TXMessage #%d\n++++++++Produced@ %f" % (self.counter, (time.time() % 1000) * 1000), diff --git a/src/workflows/transport/common_transport.py b/src/workflows/transport/common_transport.py index 8e264604..551af0a7 100644 --- a/src/workflows/transport/common_transport.py +++ b/src/workflows/transport/common_transport.py @@ -1,7 +1,9 @@ from __future__ import annotations +import argparse import decimal import logging +import optparse from collections.abc import Callable, Mapping from typing import Any, NamedTuple @@ -17,8 +19,10 @@ class TemporarySubscription(NamedTuple): class CommonTransport: - """A common transport class, containing e.g. the logic to manage - subscriptions and transactions.""" + """A common transport class. + + Contains e.g. the logic to manage subscriptions and transactions. + """ __callback_interceptor = None __subscriptions: dict[int, dict[str, Any]] = {} @@ -33,89 +37,126 @@ class CommonTransport: # def __init__( - self, middleware: list[type[middleware.BaseTransportMiddleware]] | None = None + self, middleware: list[middleware.BaseTransportMiddleware] | None = None ): if middleware is None: self.middleware = [] else: self.middleware = middleware - def add_middleware(self, middleware: type[middleware.BaseTransportMiddleware]): + def add_middleware(self, middleware: middleware.BaseTransportMiddleware) -> None: self.middleware.insert(0, middleware) @classmethod - def add_command_line_options(cls, parser): - """Function to inject command line parameters.""" + def add_command_line_options( + cls, parser: argparse.ArgumentParser | optparse.OptionParser + ) -> None: + """Inject command line parameters.""" pass def connect(self) -> bool: """Connect the transport class. This function must be overridden. - :return: True-like value when connection successful, - False-like value otherwise.""" + + Returns: + True-like value when connection successful, False-like value + otherwise. + """ return False def is_connected(self) -> bool: - """Returns the current connection status. This function must be overridden. - :return: True-like value when connection is available, - False-like value otherwise.""" + """Return the current connection status. This function must be overridden. + + Returns: + True-like value when connection is available, False-like value + otherwise. + """ return False - def disconnect(self): - """Gracefully disconnect the transport class. This function should be - overridden.""" + def disconnect(self) -> None: + """Gracefully disconnect the transport class. + + This function should be overridden. + """ @middleware.wrap - def subscribe(self, channel, callback, **kwargs) -> int: + def subscribe( + self, + channel: str, + callback: MessageCallback, + *, + disable_mangling: bool = False, + acknowledgement: bool = False, + **kwargs: Any, + ) -> int: """Listen to a queue, notify via callback function. - :param channel: Queue name to subscribe to - :param callback: Function to be called when messages are received. - The callback will pass two arguments, the header as a - dictionary structure, and the message. - :param **kwargs: Further parameters for the transport layer. For example - disable_mangling: Receive messages as unprocessed strings. - exclusive: Attempt to become exclusive subscriber to the queue. - acknowledgement: If true receipt of each message needs to be - acknowledged. - :return: A unique subscription ID + + Args: + channel: Queue name to subscribe to. + callback: Function to be called when messages are received. + The callback will pass two arguments, the header as a + dictionary structure, and the message. + disable_mangling: Receive messages as unprocessed strings. + acknowledgement: If true receipt of each message needs to be + acknowledged. + **kwargs: Further parameters for the transport layer. For example: + exclusive: Attempt to become exclusive subscriber to the queue. + + Returns: + A unique subscription ID. """ self.__subscription_id += 1 - def mangled_callback(header, message): + def mangled_callback(header: Mapping[str, Any], message: Any, /) -> Any: return callback(header, self._mangle_for_receiving(message)) - if "disable_mangling" in kwargs: - if kwargs["disable_mangling"]: - mangled_callback = callback # noqa:F811 - del kwargs["disable_mangling"] + if disable_mangling: + mangled_callback = callback # noqa:F811 self.__subscriptions[self.__subscription_id] = { "channel": channel, "callback": mangled_callback, - "ack": kwargs.get("acknowledgement"), + "ack": acknowledgement, "unsubscribed": False, } self.log.debug("Subscribing to %s with ID %d", channel, self.__subscription_id) - self._subscribe(self.__subscription_id, channel, mangled_callback, **kwargs) + self._subscribe( + self.__subscription_id, + channel, + mangled_callback, + acknowledgement=acknowledgement, + **kwargs, + ) return self.__subscription_id @middleware.wrap def subscribe_temporary( - self, channel_hint: str | None, callback: MessageCallback, **kwargs + self, + channel_hint: str | None, + callback: MessageCallback, + *, + disable_mangling: bool = False, + acknowledgement: bool = False, + **kwargs: Any, ) -> TemporarySubscription: - """Listen to a new queue that is specifically created for this connection, - and has a limited lifetime. Notify for messages via callback function. - :param channel_hint: Suggested queue name to subscribe to, the actual - queue name will be decided by both transport layer - and server. - :param callback: Function to be called when messages are received. - The callback will pass two arguments, the header as a - dictionary structure, and the message. - :param **kwargs: Further parameters for the transport layer. For example - disable_mangling: Receive messages as unprocessed strings. - acknowledgement: If true receipt of each message needs to be - acknowledged. - :return: A named tuple containing a unique subscription ID and the actual - queue name which can then be referenced by other senders. + """Listen to a new queue specifically created for this connection. + + The queue has a limited lifetime. Notify for messages via callback + function. + + Args: + channel_hint: Suggested queue name to subscribe to, the actual + queue name will be decided by both transport layer and server. + callback: Function to be called when messages are received. + The callback will pass two arguments, the header as a + dictionary structure, and the message. + disable_mangling: Receive messages as unprocessed strings. + acknowledgement: If true receipt of each message needs to be + acknowledged. + **kwargs: Further parameters for the transport layer. + + Returns: + A named tuple containing a unique subscription ID and the actual + queue name which can then be referenced by other senders. """ self.__subscription_id += 1 @@ -125,14 +166,12 @@ def _(header: Mapping[str, Any], message: Any) -> None: mangled_callback: MessageCallback = _ - if "disable_mangling" in kwargs: - if kwargs["disable_mangling"]: - mangled_callback = callback # noqa:F811 - del kwargs["disable_mangling"] + if disable_mangling: + mangled_callback = callback # noqa:F811 self.__subscriptions[self.__subscription_id] = { # "channel": channel, "callback": mangled_callback, - "ack": kwargs.get("acknowledgement"), + "ack": acknowledgement, "unsubscribed": False, } self.log.debug( @@ -141,7 +180,11 @@ def _(header: Mapping[str, Any], message: Any) -> None: self.__subscription_id, ) queue_name = self._subscribe_temporary( - self.__subscription_id, channel_hint, mangled_callback, **kwargs + self.__subscription_id, + channel_hint, + mangled_callback, + acknowledgement=acknowledgement, + **kwargs, ) return TemporarySubscription( @@ -149,16 +192,22 @@ def _(header: Mapping[str, Any], message: Any) -> None: ) @middleware.wrap - def unsubscribe(self, subscription: int, drop_callback_reference=False, **kwargs): - """Stop listening to a queue or a broadcast - :param subscription: Subscription ID to cancel - :param drop_callback_reference: Drop the reference to the registered - callback function immediately. This - means any buffered messages still in - flight will not arrive at the intended - destination and cause exceptions to be - raised instead. - :param **kwargs: Further parameters for the transport layer. + def unsubscribe( + self, + subscription: int, + *, + drop_callback_reference: bool = False, + **kwargs: Any, + ) -> None: + """Stop listening to a queue or a broadcast. + + Args: + subscription: Subscription ID to cancel. + drop_callback_reference: Drop the reference to the registered + callback function immediately. This means any buffered + messages still in flight will not arrive at the intended + destination and cause exceptions to be raised instead. + **kwargs: Further parameters for the transport layer. """ if subscription not in self.__subscriptions: @@ -172,11 +221,14 @@ def unsubscribe(self, subscription: int, drop_callback_reference=False, **kwargs if drop_callback_reference: self.drop_callback_reference(subscription) - def drop_callback_reference(self, subscription: int): + def drop_callback_reference(self, subscription: int) -> None: """Drop reference to the callback function after unsubscribing. + Any future messages arriving for that subscription will result in exceptions being raised. - :param subscription: Subscription ID to delete callback reference for. + + Args: + subscription: Subscription ID to delete callback reference for. """ if subscription not in self.__subscriptions: raise workflows.Error( @@ -189,27 +241,36 @@ def drop_callback_reference(self, subscription: int): del self.__subscriptions[subscription] @middleware.wrap - def subscribe_broadcast(self, channel, callback, **kwargs) -> int: + def subscribe_broadcast( + self, + channel: str, + callback: MessageCallback, + *, + disable_mangling: bool = False, + **kwargs: Any, + ) -> int: """Listen to a broadcast topic, notify via callback function. - :param channel: Topic name to subscribe to - :param callback: Function to be called when messages are received. - The callback will pass two arguments, the header as a - dictionary structure, and the message. - :param **kwargs: Further parameters for the transport layer. For example - disable_mangling: Receive messages as unprocessed strings. - retroactive: Ask broker to send old messages if possible - :return: A unique subscription ID + + Args: + channel: Topic name to subscribe to. + callback: Function to be called when messages are received. + The callback will pass two arguments, the header as a + dictionary structure, and the message. + disable_mangling: Receive messages as unprocessed strings. + **kwargs: Further parameters for the transport layer. For example: + retroactive: Ask broker to send old messages if possible. + + Returns: + A unique subscription ID. """ self.__subscription_id += 1 - def mangled_callback(header, message): + def mangled_callback(header: Mapping[str, Any], message: Any, /) -> Any: return callback(header, self._mangle_for_receiving(message)) - if "disable_mangling" in kwargs: - if kwargs["disable_mangling"]: - mangled_callback = callback # noqa:F811 - del kwargs["disable_mangling"] + if disable_mangling: + mangled_callback = callback # noqa:F811 self.__subscriptions[self.__subscription_id] = { "channel": channel, "callback": mangled_callback, @@ -227,12 +288,19 @@ def mangled_callback(header, message): return self.__subscription_id def subscription_callback(self, subscription: int) -> MessageCallback: - """Retrieve the callback function for a subscription. Raise a - workflows.Error if the subscription does not exist. - All transport callbacks can be intercepted by setting an - interceptor function with subscription_callback_intercept(). - :param subscription: Subscription ID to look up - :return: Callback function + """Retrieve the callback function for a subscription. + + All transport callbacks can be intercepted by setting an interceptor + function with subscription_callback_intercept(). + + Args: + subscription: Subscription ID to look up. + + Returns: + Callback function. + + Raises: + workflows.Error: If the subscription does not exist. """ subscription_record = self.__subscriptions.get(subscription) if not subscription_record: @@ -242,96 +310,131 @@ def subscription_callback(self, subscription: int) -> MessageCallback: return self.__callback_interceptor(callback) return callback - def subscription_callback_set_intercept(self, interceptor): - """Set a function to intercept all callbacks. This is useful to, for - example, keep a thread barrier between the transport related functions - and processing functions. - :param interceptor: A function that takes the original callback function - and returns a modified callback function. Or None to - disable interception. + def subscription_callback_set_intercept( + self, + interceptor: Callable[[MessageCallback], MessageCallback] | None, + ) -> None: + """Set a function to intercept all callbacks. + + This is useful to, for example, keep a thread barrier between the + transport related functions and processing functions. + + Args: + interceptor: A function that takes the original callback function + and returns a modified callback function. Or None to disable + interception. """ self.__callback_interceptor = interceptor @middleware.wrap - def send(self, destination, message, **kwargs): + def send( + self, + destination: str, + message: Any, + *, + headers: dict | None = None, + **kwargs: Any, + ) -> None: """Send a message to a queue. - :param destination: Queue name to send to - :param message: Either a string or a serializable object to be sent - :param **kwargs: Further parameters for the transport layer. For example - delay: Delay transport of message by this many seconds - headers: Optional dictionary of header entries - expiration: Optional expiration time, relative to sending time - transaction: Transaction ID if message should be part of a - transaction + + Args: + destination: Queue name to send to. + message: The message. Usually string-like or json-serializable but + exact specification depends on the concrete transport. + headers: Optional dictionary of header entries to set. + **kwargs: Further parameters for the transport layer. For example: + delay: Delay transport of message by this many seconds. + expiration: Optional expiration time, relative to sending time. + transaction: Transaction ID if message should be part of a + transaction. + + Raises: On failure. """ message = self._mangle_for_sending(message) - self._send(destination, message, **kwargs) + self._send(destination, message, headers=headers, **kwargs) @middleware.wrap - def raw_send(self, destination, message, **kwargs): + def raw_send(self, destination: str, message: Any, **kwargs: Any) -> None: """Send a raw (unmangled) message to a queue. + This may cause errors if the receiver expects a mangled message. - :param destination: Queue name to send to - :param message: Either a string or a serializable object to be sent - :param **kwargs: Further parameters for the transport layer. For example - delay: Delay transport of message by this many seconds - headers: Optional dictionary of header entries - expiration: Optional expiration time, relative to sending time - transaction: Transaction ID if message should be part of a - transaction + + Args: + destination: Queue name to send to. + message: Either a string or a serializable object to be sent. + **kwargs: Further parameters for the transport layer. For example: + delay: Delay transport of message by this many seconds. + headers: Optional dictionary of header entries. + expiration: Optional expiration time, relative to sending time. + transaction: Transaction ID if message should be part of a + transaction. """ self._send(destination, message, **kwargs) @middleware.wrap - def broadcast(self, destination, message, **kwargs): + def broadcast(self, destination: str, message: Any, **kwargs: Any) -> None: """Broadcast a message. - :param destination: Topic name to send to - :param message: Either a string or a serializable object to be sent - :param **kwargs: Further parameters for the transport layer. For example - delay: Delay transport of message by this many seconds - headers: Optional dictionary of header entries - expiration: Optional expiration time, relative to sending time - transaction: Transaction ID if message should be part of a - transaction + + Args: + destination: Topic name to send to. + message: Either a string or a serializable object to be sent. + **kwargs: Further parameters for the transport layer. For example: + delay: Delay transport of message by this many seconds. + headers: Optional dictionary of header entries. + expiration: Optional expiration time, relative to sending time. + transaction: Transaction ID if message should be part of a + transaction. """ message = self._mangle_for_sending(message) self._broadcast(destination, message, **kwargs) @middleware.wrap - def raw_broadcast(self, destination, message, **kwargs): + def raw_broadcast(self, destination: str, message: Any, **kwargs: Any) -> None: """Broadcast a raw (unmangled) message. + This may cause errors if the receiver expects a mangled message. - :param destination: Topic name to send to - :param message: Either a string or a serializable object to be sent - :param **kwargs: Further parameters for the transport layer. For example - delay: Delay transport of message by this many seconds - headers: Optional dictionary of header entries - expiration: Optional expiration time, relative to sending time - transaction: Transaction ID if message should be part of a - transaction + + Args: + destination: Topic name to send to. + message: Either a string or a serializable object to be sent. + **kwargs: Further parameters for the transport layer. For example: + delay: Delay transport of message by this many seconds. + headers: Optional dictionary of header entries. + expiration: Optional expiration time, relative to sending time. + transaction: Transaction ID if message should be part of a + transaction. """ self._broadcast(destination, message, **kwargs) - def broadcast_status(self, status: dict) -> None: - """Broadcast transient status information to all listeners""" + def broadcast_status(self, status: Mapping) -> None: + """Broadcast transient status information to all listeners.""" raise NotImplementedError @middleware.wrap - def ack(self, message, subscription_id: int | None = None, **kwargs): - """Acknowledge receipt of a message. This only makes sense when the - 'acknowledgement' flag was set for the relevant subscription. - :param message: ID of the message to be acknowledged, OR a dictionary - containing a field 'message-id'. - :param subscription_id: ID of the associated subscription. Optional when - a dictionary is passed as first parameter and - that dictionary contains field 'subscription'. - :param **kwargs: Further parameters for the transport layer. For example - transaction: Transaction ID if acknowledgement should be part of - a transaction + def ack( + self, + message: Any, + subscription_id: int | None = None, + **kwargs: Any, + ) -> None: + """Acknowledge receipt of a message. + + This only makes sense when the 'acknowledgement' flag was set for the + relevant subscription. + + Args: + message: ID of the message to be acknowledged, OR a dictionary + containing a field 'message-id'. + subscription_id: ID of the associated subscription. Optional when + a dictionary is passed as first parameter and that dictionary + contains field 'subscription'. + **kwargs: Further parameters for the transport layer. For example: + transaction: Transaction ID if acknowledgement should be part + of a transaction. """ if isinstance(message, dict): @@ -352,17 +455,26 @@ def ack(self, message, subscription_id: int | None = None, **kwargs): self._ack(message_id, subscription_id=subscription_id, **kwargs) @middleware.wrap - def nack(self, message, subscription_id: int | None = None, **kwargs): - """Reject receipt of a message. This only makes sense when the - 'acknowledgement' flag was set for the relevant subscription. - :param message: ID of the message to be rejected, OR a dictionary - containing a field 'message-id'. - :param subscription_id: ID of the associated subscription. Optional when - a dictionary is passed as first parameter and - that dictionary contains field 'subscription'. - :param **kwargs: Further parameters for the transport layer. For example - transaction: Transaction ID if rejection should be part of a - transaction + def nack( + self, + message: Any, + subscription_id: int | None = None, + **kwargs: Any, + ) -> None: + """Reject receipt of a message. + + This only makes sense when the 'acknowledgement' flag was set for the + relevant subscription. + + Args: + message: ID of the message to be rejected, OR a dictionary + containing a field 'message-id'. + subscription_id: ID of the associated subscription. Optional when + a dictionary is passed as first parameter and that dictionary + contains field 'subscription'. + **kwargs: Further parameters for the transport layer. For example: + transaction: Transaction ID if rejection should be part of a + transaction. """ if isinstance(message, dict): @@ -381,10 +493,17 @@ def nack(self, message, subscription_id: int | None = None, **kwargs): self._nack(message_id, subscription_id=subscription_id, **kwargs) @middleware.wrap - def transaction_begin(self, subscription_id: int | None = None, **kwargs) -> int: + def transaction_begin( + self, subscription_id: int | None = None, **kwargs: Any + ) -> int: """Start a new transaction. - :param **kwargs: Further parameters for the transport layer. - :return: A transaction ID that can be passed to other functions. + + Args: + subscription_id: ID of the subscription to scope this transaction to. + **kwargs: Further parameters for the transport layer. + + Returns: + A transaction ID that can be passed to other functions. """ self.__transaction_id += 1 @@ -403,10 +522,12 @@ def transaction_begin(self, subscription_id: int | None = None, **kwargs) -> int return self.__transaction_id @middleware.wrap - def transaction_abort(self, transaction_id: int, **kwargs): + def transaction_abort(self, transaction_id: int, **kwargs: Any) -> None: """Abort a transaction and roll back all operations. - :param transaction_id: ID of transaction to be aborted. - :param **kwargs: Further parameters for the transport layer. + + Args: + transaction_id: ID of transaction to be aborted. + **kwargs: Further parameters for the transport layer. """ if transaction_id not in self.__transactions: @@ -416,10 +537,12 @@ def transaction_abort(self, transaction_id: int, **kwargs): self._transaction_abort(transaction_id, **kwargs) @middleware.wrap - def transaction_commit(self, transaction_id: int, **kwargs): + def transaction_commit(self, transaction_id: int, **kwargs: Any) -> None: """Commit a transaction. - :param transaction_id: ID of transaction to be committed. - :param **kwargs: Further parameters for the transport layer. + + Args: + transaction_id: ID of transaction to be committed. + **kwargs: Further parameters for the transport layer. """ if transaction_id not in self.__transactions: @@ -429,34 +552,52 @@ def transaction_commit(self, transaction_id: int, **kwargs): self._transaction_commit(transaction_id, **kwargs) @property - def is_reconnectable(self): - """Check if the transport object is in a status where reconnecting is - supported. There must not be any active subscriptions or transactions.""" + def is_reconnectable(self) -> bool: + """Check if the transport object is in a status where reconnecting is supported. + + There must not be any active subscriptions or transactions. + """ return not self.__subscriptions and not self.__transactions # # -- Low level communication calls to be implemented by subclass ----------- # - def _subscribe(self, sub_id: int, channel, callback, **kwargs): + def _subscribe( + self, + sub_id: int, + channel: str, + callback: MessageCallback, + **kwargs: Any, + ) -> None: """Listen to a queue, notify via callback function. - :param sub_id: ID for this subscription in the transport layer - :param channel: Queue name to subscribe to - :param callback: Function to be called when messages are received - :param **kwargs: Further parameters for the transport layer. For example - exclusive: Attempt to become exclusive subscriber to the queue. - acknowledgement: If true receipt of each message needs to be - acknowledged. + + Args: + sub_id: ID for this subscription in the transport layer. + channel: Queue name to subscribe to. + callback: Function to be called when messages are received. + **kwargs: Further parameters for the transport layer. For example: + exclusive: Attempt to become exclusive subscriber to the queue. + acknowledgement: If true receipt of each message needs to be + acknowledged. """ raise NotImplementedError("Transport interface not implemented") - def _subscribe_broadcast(self, sub_id: int, channel, callback, **kwargs): + def _subscribe_broadcast( + self, + sub_id: int, + channel: str, + callback: MessageCallback, + **kwargs: Any, + ) -> None: """Listen to a broadcast topic, notify via callback function. - :param sub_id: ID for this subscription in the transport layer - :param channel: Topic name to subscribe to - :param callback: Function to be called when messages are received - :param **kwargs: Further parameters for the transport layer. For example - retroactive: Ask broker to send old messages if possible + + Args: + sub_id: ID for this subscription in the transport layer. + channel: Topic name to subscribe to. + callback: Function to be called when messages are received. + **kwargs: Further parameters for the transport layer. For example: + retroactive: Ask broker to send old messages if possible. """ raise NotImplementedError("Transport interface not implemented") @@ -465,91 +606,121 @@ def _subscribe_temporary( sub_id: int, channel_hint: str | None, callback: MessageCallback, - **kwargs, + **kwargs: Any, ) -> str: """Create and then listen to a temporary queue, notify via callback function. - :param sub_id: ID for this subscription in the transport layer - :param channel_hint: Name suggestion for the temporary queue - :param callback: Function to be called when messages are received - :param **kwargs: Further parameters for the transport layer. For example - acknowledgement: If true receipt of each message needs to be - acknowledged. - :returns: The name of the temporary queue + + Args: + sub_id: ID for this subscription in the transport layer. + channel_hint: Name suggestion for the temporary queue. + callback: Function to be called when messages are received. + **kwargs: Further parameters for the transport layer. For example: + acknowledgement: If true receipt of each message needs to be + acknowledged. + + Returns: + The name of the temporary queue. """ raise NotImplementedError("Transport interface not implemented") - def _unsubscribe(self, sub_id: int, **kwargs): - """Stop listening to a queue or a broadcast - :param sub_id: ID for this subscription in the transport layer + def _unsubscribe(self, sub_id: int, **kwargs: Any) -> None: + """Stop listening to a queue or a broadcast. + + Args: + sub_id: ID for this subscription in the transport layer. + **kwargs: Further parameters for the transport layer. """ raise NotImplementedError("Transport interface not implemented") - def _send(self, destination, message, **kwargs): + def _send(self, destination: str, message: Any, **kwargs: Any) -> None: """Send a message to a queue. - :param destination: Queue name to send to - :param message: A string to be sent - :param **kwargs: Further parameters for the transport layer. For example - headers: Optional dictionary of header entries - expiration: Optional expiration time, relative to sending time - transaction: Transaction ID if message should be part of a - transaction + + Args: + destination: Queue name to send to. + message: A string to be sent. + **kwargs: Further parameters for the transport layer. For example: + headers: Optional dictionary of header entries. + expiration: Optional expiration time, relative to sending time. + transaction: Transaction ID if message should be part of a + transaction. """ raise NotImplementedError("Transport interface not implemented") - def _broadcast(self, destination, message, **kwargs): + def _broadcast(self, destination: str, message: Any, **kwargs: Any) -> None: """Broadcast a message. - :param destination: Topic name to send to - :param message: A string to be broadcast - :param **kwargs: Further parameters for the transport layer. For example - headers: Optional dictionary of header entries - expiration: Optional expiration time, relative to sending time - transaction: Transaction ID if message should be part of a - transaction + + Args: + destination: Topic name to send to. + message: A string to be broadcast. + **kwargs: Further parameters for the transport layer. For example: + headers: Optional dictionary of header entries. + expiration: Optional expiration time, relative to sending time. + transaction: Transaction ID if message should be part of a + transaction. """ raise NotImplementedError("Transport interface not implemented") - def _ack(self, message_id, subscription_id, **kwargs): - """Acknowledge receipt of a message. This only makes sense when the - 'acknowledgement' flag was set for the relevant subscription. - :param message_id: ID of the message to be acknowledged. - :param subscription_id: ID of the associated subscription. - :param **kwargs: Further parameters for the transport layer. For example - transaction: Transaction ID if acknowledgement should be part of - a transaction + def _ack(self, message_id: Any, subscription_id: int, **kwargs: Any) -> None: + """Acknowledge receipt of a message. + + This only makes sense when the 'acknowledgement' flag was set for the + relevant subscription. + + Args: + message_id: ID of the message to be acknowledged. + subscription_id: ID of the associated subscription. + **kwargs: Further parameters for the transport layer. For example: + transaction: Transaction ID if acknowledgement should be part + of a transaction. """ raise NotImplementedError("Transport interface not implemented") - def _nack(self, message_id, subscription_id, **kwargs): - """Reject receipt of a message. This only makes sense when the - 'acknowledgement' flag was set for the relevant subscription. - :param message_id: ID of the message to be rejected. - :param subscription_id: ID of the associated subscription. - :param **kwargs: Further parameters for the transport layer. For example - transaction: Transaction ID if rejection should be part of a - transaction + def _nack(self, message_id: Any, subscription_id: int, **kwargs: Any) -> None: + """Reject receipt of a message. + + This only makes sense when the 'acknowledgement' flag was set for the + relevant subscription. + + Args: + message_id: ID of the message to be rejected. + subscription_id: ID of the associated subscription. + **kwargs: Further parameters for the transport layer. For example: + transaction: Transaction ID if rejection should be part of a + transaction. """ raise NotImplementedError("Transport interface not implemented") def _transaction_begin( - self, transaction_id: int, *, subscription_id: int | None = None, **kwargs + self, + transaction_id: int, + *, + subscription_id: int | None = None, + **kwargs: Any, ) -> None: """Start a new transaction. - :param transaction_id: ID for this transaction in the transport layer. - :param **kwargs: Further parameters for the transport layer. + + Args: + transaction_id: ID for this transaction in the transport layer. + subscription_id: ID of the subscription to scope this transaction to. + **kwargs: Further parameters for the transport layer. """ raise NotImplementedError("Transport interface not implemented") - def _transaction_abort(self, transaction_id: int, **kwargs) -> None: + def _transaction_abort(self, transaction_id: int, **kwargs: Any) -> None: """Abort a transaction and roll back all operations. - :param transaction_id: ID of transaction to be aborted. - :param **kwargs: Further parameters for the transport layer. + + Args: + transaction_id: ID of transaction to be aborted. + **kwargs: Further parameters for the transport layer. """ raise NotImplementedError("Transport interface not implemented") - def _transaction_commit(self, transaction_id: int, **kwargs) -> None: + def _transaction_commit(self, transaction_id: int, **kwargs: Any) -> None: """Commit a transaction. - :param transaction_id: ID of transaction to be committed. - :param **kwargs: Further parameters for the transport layer. + + Args: + transaction_id: ID of transaction to be committed. + **kwargs: Further parameters for the transport layer. """ raise NotImplementedError("Transport interface not implemented") @@ -562,23 +733,32 @@ def _transaction_commit(self, transaction_id: int, **kwargs) -> None: # The canonical example is serialization/deserialization, see stomp_transport @staticmethod - def _mangle_for_sending(message): - """Function that any message will pass through before it being forwarded to - the actual _send* functions.""" + def _mangle_for_sending(message: Any) -> Any: + """Pass any message through this before forwarding to the actual _send* functions.""" return message @staticmethod - def _mangle_for_receiving(message): - """Function that any message will pass through before it being forwarded to - the receiving subscribed callback functions.""" + def _mangle_for_receiving(message: Any) -> Any: + """Pass any message through this before forwarding to the receiving subscribed callback functions.""" return message -def json_serializer(obj): - """A helper function for JSON serialization, where it can be used as - the default= argument. This function helps the serializer to translate - objects that otherwise would not be understood. Note that this is - one-way only - these objects are not restored on the receiving end.""" +def json_serializer(obj: Any) -> Any: + """Helper function for JSON serialization, usable as the ``default=`` argument. + + This function helps the serializer to translate objects that otherwise + would not be understood. Note that this is one-way only - these objects + are not restored on the receiving end. + + Args: + obj: The object to serialize. + + Returns: + A JSON-serializable representation of obj. + + Raises: + TypeError: If obj is not JSON serializable. + """ if isinstance(obj, decimal.Decimal): # turn all Decimals into floats diff --git a/src/workflows/transport/middleware/__init__.py b/src/workflows/transport/middleware/__init__.py index a8d8f76b..7ed665d9 100644 --- a/src/workflows/transport/middleware/__init__.py +++ b/src/workflows/transport/middleware/__init__.py @@ -4,8 +4,8 @@ import inspect import logging import time -from collections.abc import Callable -from typing import TYPE_CHECKING +from collections.abc import Callable, Mapping +from typing import TYPE_CHECKING, Any, Concatenate if TYPE_CHECKING: from workflows.transport.common_transport import ( @@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) -def get_callback_source(callable: Callable): +def get_callback_source(callable: Callable) -> str: if isinstance(callable, functools.partial): # functools.partial objects don't have a __qualname__ attribute # account for possibility of nested stack of functools.partials @@ -31,82 +31,148 @@ def get_callback_source(callable: Callable): class BaseTransportMiddleware: - def subscribe(self, call_next: Callable, channel, callback, **kwargs) -> int: - return call_next(channel, callback, **kwargs) + def subscribe( + self, + call_next: Callable[..., int], + channel: str, + callback: MessageCallback, + *, + disable_mangling: bool = False, + acknowledgement: bool = False, + **kwargs: Any, + ) -> int: + return call_next( + channel, + callback, + disable_mangling=disable_mangling, + acknowledgement=acknowledgement, + **kwargs, + ) def subscribe_temporary( self, - call_next: Callable, + call_next: Callable[..., TemporarySubscription], channel_hint: str | None, callback: MessageCallback, - **kwargs, + *, + disable_mangling: bool = False, + acknowledgement: bool = False, + **kwargs: Any, ) -> TemporarySubscription: - return call_next(channel_hint, callback, **kwargs) + return call_next( + channel_hint, + callback, + disable_mangling=disable_mangling, + acknowledgement=acknowledgement, + **kwargs, + ) def subscribe_broadcast( - self, call_next: Callable, channel, callback, **kwargs + self, + call_next: Callable[..., int], + channel: str, + callback: MessageCallback, + *, + disable_mangling: bool = False, + **kwargs: Any, ) -> int: - return call_next(channel, callback, **kwargs) + return call_next(channel, callback, disable_mangling=disable_mangling, **kwargs) def unsubscribe( self, - call_next: Callable, + call_next: Callable[..., None], subscription: int, - drop_callback_reference=False, - **kwargs, - ): + *, + drop_callback_reference: bool = False, + **kwargs: Any, + ) -> None: call_next( subscription, drop_callback_reference=drop_callback_reference, **kwargs ) - def send(self, call_next: Callable, destination, message, **kwargs): - call_next(destination, message, **kwargs) - - def raw_send(self, call_next: Callable, destination, message, **kwargs): + def send( + self, + call_next: Callable[..., None], + destination: str, + message: Any, + *, + headers: dict | None = None, + **kwargs: Any, + ) -> None: + call_next(destination, message, headers=headers, **kwargs) + + def raw_send( + self, + call_next: Callable[..., None], + destination: str, + message: Any, + **kwargs: Any, + ) -> None: call_next(destination, message, **kwargs) - def broadcast(self, call_next: Callable, destination, message, **kwargs): + def broadcast( + self, + call_next: Callable[..., None], + destination: str, + message: Any, + **kwargs: Any, + ) -> None: call_next(destination, message, **kwargs) - def raw_broadcast(self, call_next: Callable, destination, message, **kwargs): + def raw_broadcast( + self, + call_next: Callable[..., None], + destination: str, + message: Any, + **kwargs: Any, + ) -> None: call_next(destination, message, **kwargs) def ack( self, - call_next: Callable, - message, + call_next: Callable[..., None], + message: Any, subscription_id: int | None = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: call_next(message, subscription_id=subscription_id, **kwargs) def nack( self, - call_next: Callable, - message, + call_next: Callable[..., None], + message: Any, subscription_id: int | None = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: call_next(message, subscription_id=subscription_id, **kwargs) def transaction_begin( - self, call_next: Callable, subscription_id: int | None = None, **kwargs + self, + call_next: Callable[..., int], + subscription_id: int | None = None, + **kwargs: Any, ) -> int: return call_next(subscription_id=subscription_id, **kwargs) def transaction_abort( - self, call_next: Callable, transaction_id: int | None = None, **kwargs - ): + self, + call_next: Callable[..., None], + transaction_id: int, + **kwargs: Any, + ) -> None: call_next(transaction_id, **kwargs) def transaction_commit( - self, call_next: Callable, transaction_id: int | None = None, **kwargs - ): + self, + call_next: Callable[..., None], + transaction_id: int, + **kwargs: Any, + ) -> None: call_next(transaction_id, **kwargs) class CounterMiddleware(BaseTransportMiddleware): - def __init__(self): + def __init__(self) -> None: self.subscribe_count = 0 self.subscribe_broadcast_count = 0 self.send_count = 0 @@ -118,131 +184,209 @@ def __init__(self): self.transaction_commit_count = 0 super().__init__() - def subscribe(self, call_next: Callable, channel, callback, **kwargs) -> int: + def subscribe( + self, + call_next: Callable[..., int], + channel: str, + callback: MessageCallback, + *, + disable_mangling: bool = False, + acknowledgement: bool = False, + **kwargs: Any, + ) -> int: self.subscribe_count += 1 logger.info(f"subscribe() count: {self.subscribe_count}") - return call_next(channel, callback, **kwargs) + return call_next( + channel, + callback, + disable_mangling=disable_mangling, + acknowledgement=acknowledgement, + **kwargs, + ) - def send(self, call_next: Callable, destination, message, **kwargs): - call_next(destination, message, **kwargs) + def send( + self, + call_next: Callable[..., None], + destination: str, + message: Any, + *, + headers: dict | None = None, + **kwargs: Any, + ) -> None: + call_next(destination, message, headers=headers, **kwargs) self.send_count += 1 logger.info(f"send() count: {self.send_count}") - def broadcast(self, call_next: Callable, destination, message, **kwargs): + def broadcast( + self, + call_next: Callable[..., None], + destination: str, + message: Any, + **kwargs: Any, + ) -> None: call_next(destination, message, **kwargs) self.broadcast_count += 1 logger.info(f"broadcast() count: {self.broadcast_count}") def ack( self, - call_next: Callable, - message, + call_next: Callable[..., None], + message: Any, subscription_id: int | None = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: call_next(message, subscription_id=subscription_id, **kwargs) self.ack_count += 1 logger.info(f"ack() count: {self.ack_count}") def nack( self, - call_next: Callable, - message, + call_next: Callable[..., None], + message: Any, subscription_id: int | None = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: call_next(message, subscription_id=subscription_id, **kwargs) self.nack_count += 1 logger.info(f"nack() count: {self.nack_count}") - def transaction_begin(self, call_next: Callable, *args, **kwargs) -> int: + def transaction_begin( + self, + call_next: Callable[..., int], + subscription_id: int | None = None, + **kwargs: Any, + ) -> int: self.transaction_begin_count += 1 logger.info(f"transaction_begin() count: {self.transaction_begin_count}") - return call_next(*args, **kwargs) + return call_next(subscription_id=subscription_id, **kwargs) - def transaction_abort(self, call_next: Callable, *args, **kwargs): - call_next(*args, **kwargs) + def transaction_abort( + self, + call_next: Callable[..., None], + transaction_id: int, + **kwargs: Any, + ) -> None: + call_next(transaction_id, **kwargs) self.transaction_abort_count += 1 logger.info(f"transaction_abort() count: {self.transaction_abort_count}") - def transaction_commit(self, call_next: Callable, *args, **kwargs): - call_next(*args, **kwargs) + def transaction_commit( + self, + call_next: Callable[..., None], + transaction_id: int, + **kwargs: Any, + ) -> None: + call_next(transaction_id, **kwargs) self.transaction_commit_count += 1 logger.info(f"transaction_commit() count: {self.transaction_commit_count}") class TimerMiddleware(BaseTransportMiddleware): - def __init__(self, logger: logging.Logger | None = None, level=logging.INFO): + def __init__( + self, logger: logging.Logger | None = None, level: int = logging.INFO + ) -> None: if logger is None: logger = logging.getLogger(__name__) self.logger = logger self.level = level - def subscribe(self, call_next: Callable, channel, callback, **kwargs) -> int: + def subscribe( + self, + call_next: Callable[..., int], + channel: str, + callback: MessageCallback, + *, + disable_mangling: bool = False, + acknowledgement: bool = False, + **kwargs: Any, + ) -> int: @functools.wraps(callback) - def wrapped_callback(header, message): + def wrapped_callback(header: Mapping[str, Any], message: Any) -> None: start_time = time.perf_counter() - result = callback(header, message) + callback(header, message) end_time = time.perf_counter() source = get_callback_source(callback) self.logger.log( self.level, f"Callback for {source} took {end_time - start_time:.4f} seconds", ) - return result - return call_next(channel, wrapped_callback, **kwargs) + return call_next( + channel, + wrapped_callback, + disable_mangling=disable_mangling, + acknowledgement=acknowledgement, + **kwargs, + ) def subscribe_temporary( self, - call_next: Callable, + call_next: Callable[..., TemporarySubscription], channel_hint: str | None, callback: MessageCallback, - **kwargs, + *, + disable_mangling: bool = False, + acknowledgement: bool = False, + **kwargs: Any, ) -> TemporarySubscription: @functools.wraps(callback) - def wrapped_callback(header, message): + def wrapped_callback(header: Mapping[str, Any], message: Any) -> None: start_time = time.perf_counter() - result = callback(header, message) + callback(header, message) end_time = time.perf_counter() source = get_callback_source(callback) self.logger.log( self.level, f"Callback for {source} took {end_time - start_time:.4f} seconds", ) - return result - return call_next(channel_hint, wrapped_callback, **kwargs) + return call_next( + channel_hint, + wrapped_callback, + disable_mangling=disable_mangling, + acknowledgement=acknowledgement, + **kwargs, + ) def subscribe_broadcast( - self, call_next: Callable, channel, callback, **kwargs + self, + call_next: Callable[..., int], + channel: str, + callback: MessageCallback, + *, + disable_mangling: bool = False, + **kwargs: Any, ) -> int: @functools.wraps(callback) - def wrapped_callback(header, message): + def wrapped_callback(header: Mapping[str, Any], message: Any) -> None: start_time = time.perf_counter() - result = callback(header, message) + callback(header, message) end_time = time.perf_counter() source = get_callback_source(callback) self.logger.log( self.level, f"Callback for {source} took {end_time - start_time:.4f} seconds", ) - return result - return call_next(channel, wrapped_callback, **kwargs) + return call_next( + channel, + wrapped_callback, + disable_mangling=disable_mangling, + **kwargs, + ) -def wrap(f: Callable): +def wrap[S, **P, R]( + f: Callable[Concatenate[S, P], R], +) -> Callable[Concatenate[S, P], R]: @functools.wraps(f) - def wrapper(self, *args, **kwargs): + def wrapper(self: S, *args: P.args, **kwargs: P.kwargs) -> R: return functools.reduce( lambda call_next, m: ( - lambda *args, **kwargs: getattr(m, f.__name__)( - call_next, *args, **kwargs - ) + lambda *a, **kw: getattr(m, f.__name__)(call_next, *a, **kw) ), - reversed(self.middleware), - lambda *args, **kwargs: f(self, *args, **kwargs), + reversed(self.middleware), # type: ignore[attr-defined] + lambda *a, **kw: f(self, *a, **kw), )(*args, **kwargs) return wrapper diff --git a/src/workflows/transport/middleware/otel_tracing.py b/src/workflows/transport/middleware/otel_tracing.py index aed057ec..97cf5697 100644 --- a/src/workflows/transport/middleware/otel_tracing.py +++ b/src/workflows/transport/middleware/otel_tracing.py @@ -1,28 +1,39 @@ from __future__ import annotations import functools -from collections.abc import Callable +from collections.abc import Callable, Mapping +from typing import Any from opentelemetry import trace from opentelemetry.context import Context from opentelemetry.propagate import extract, inject +from opentelemetry.trace import Span from workflows.transport.common_transport import MessageCallback, TemporarySubscription +from workflows.transport.middleware import BaseTransportMiddleware -class OTELTracingMiddleware: - def __init__(self, tracer: trace.Tracer, service_name: str): +class OTELTracingMiddleware(BaseTransportMiddleware): + def __init__(self, tracer: trace.Tracer, service_name: str) -> None: self.tracer = tracer self.service_name = service_name - def _set_span_attributes(self, span, **attributes): + def _set_span_attributes(self, span: Span, **attributes: Any) -> None: """Helper method to set common span attributes""" span.set_attribute("service_name", self.service_name) for key, value in attributes.items(): if value is not None: span.set_attribute(key, value) - def send(self, call_next: Callable, destination: str, message, **kwargs): + def send( + self, + call_next: Callable[..., None], + destination: str, + message: Any, + *, + headers: dict | None = None, + **kwargs: Any, + ) -> None: # Get current span context (may be None if this is the root span) current_span = trace.get_current_span() parent_context = ( @@ -36,19 +47,24 @@ def send(self, call_next: Callable, destination: str, message, **kwargs): self._set_span_attributes(span, destination=destination) # Inject the current trace context into the message headers - headers = kwargs.get("headers", {}) if headers is None: headers = {} inject(headers) # This modifies headers in-place - kwargs["headers"] = headers - return call_next(destination, message, **kwargs) + call_next(destination, message, headers=headers, **kwargs) def subscribe( - self, call_next: Callable, channel: str, callback: Callable, **kwargs + self, + call_next: Callable[..., int], + channel: str, + callback: MessageCallback, + *, + disable_mangling: bool = False, + acknowledgement: bool = False, + **kwargs: Any, ) -> int: @functools.wraps(callback) - def wrapped_callback(header, message): + def wrapped_callback(header: Mapping[str, Any], message: Any) -> None: # Extract trace context from message headers ctx = extract(header) if header else Context() @@ -63,13 +79,25 @@ def wrapped_callback(header, message): # and potentially call send() which will pick up this context return callback(header, message) - return call_next(channel, wrapped_callback, **kwargs) + return call_next( + channel, + wrapped_callback, + disable_mangling=disable_mangling, + acknowledgement=acknowledgement, + **kwargs, + ) def subscribe_broadcast( - self, call_next: Callable, channel: str, callback: Callable, **kwargs + self, + call_next: Callable[..., int], + channel: str, + callback: MessageCallback, + *, + disable_mangling: bool = False, + **kwargs: Any, ) -> int: @functools.wraps(callback) - def wrapped_callback(header, message): + def wrapped_callback(header: Mapping[str, Any], message: Any) -> None: # Extract trace context from message headers ctx = extract(header) if header else Context() @@ -82,17 +110,25 @@ def wrapped_callback(header, message): return callback(header, message) - return call_next(channel, wrapped_callback, **kwargs) + return call_next( + channel, + wrapped_callback, + disable_mangling=disable_mangling, + **kwargs, + ) def subscribe_temporary( self, - call_next: Callable, + call_next: Callable[..., TemporarySubscription], channel_hint: str | None, callback: MessageCallback, - **kwargs, + *, + disable_mangling: bool = False, + acknowledgement: bool = False, + **kwargs: Any, ) -> TemporarySubscription: @functools.wraps(callback) - def wrapped_callback(header, message): + def wrapped_callback(header: Mapping[str, Any], message: Any) -> None: # Extract trace context from message headers ctx = extract(header) if header else Context() @@ -105,9 +141,21 @@ def wrapped_callback(header, message): return callback(header, message) - return call_next(channel_hint, wrapped_callback, **kwargs) + return call_next( + channel_hint, + wrapped_callback, + disable_mangling=disable_mangling, + acknowledgement=acknowledgement, + **kwargs, + ) - def raw_send(self, call_next: Callable, destination: str, message, **kwargs): + def raw_send( + self, + call_next: Callable[..., None], + destination: str, + message: Any, + **kwargs: Any, + ) -> None: # Get current span context (may be None if this is the root span) current_span = trace.get_current_span() parent_context = ( @@ -127,9 +175,15 @@ def raw_send(self, call_next: Callable, destination: str, message, **kwargs): inject(headers) # This modifies headers in-place kwargs["headers"] = headers - return call_next(destination, message, **kwargs) + call_next(destination, message, **kwargs) - def broadcast(self, call_next: Callable, destination: str, message, **kwargs): + def broadcast( + self, + call_next: Callable[..., None], + destination: str, + message: Any, + **kwargs: Any, + ) -> None: # Get current span context (may be None if this is the root span) current_span = trace.get_current_span() parent_context = ( @@ -149,9 +203,15 @@ def broadcast(self, call_next: Callable, destination: str, message, **kwargs): inject(headers) # This modifies headers in-place kwargs["headers"] = headers - return call_next(destination, message, **kwargs) + call_next(destination, message, **kwargs) - def raw_broadcast(self, call_next: Callable, destination: str, message, **kwargs): + def raw_broadcast( + self, + call_next: Callable[..., None], + destination: str, + message: Any, + **kwargs: Any, + ) -> None: # Get current span context (may be None if this is the root span) current_span = trace.get_current_span() parent_context = ( @@ -171,15 +231,16 @@ def raw_broadcast(self, call_next: Callable, destination: str, message, **kwargs inject(headers) # This modifies headers in-place kwargs["headers"] = headers - return call_next(destination, message, **kwargs) + call_next(destination, message, **kwargs) def unsubscribe( self, - call_next: Callable, + call_next: Callable[..., None], subscription: int, - drop_callback_reference=False, - **kwargs, - ): + *, + drop_callback_reference: bool = False, + **kwargs: Any, + ) -> None: # Get current span context current_span = trace.get_current_span() current_context = ( @@ -198,11 +259,11 @@ def unsubscribe( def ack( self, - call_next: Callable, - message, + call_next: Callable[..., None], + message: Any, subscription_id: int | None = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: # Get current span context current_span = trace.get_current_span() current_context = ( @@ -219,11 +280,11 @@ def ack( def nack( self, - call_next: Callable, - message, + call_next: Callable[..., None], + message: Any, subscription_id: int | None = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: # Get current span context current_span = trace.get_current_span() current_context = ( @@ -239,7 +300,10 @@ def nack( call_next(message, subscription_id=subscription_id, **kwargs) def transaction_begin( - self, call_next: Callable, subscription_id: int | None = None, **kwargs + self, + call_next: Callable[..., int], + subscription_id: int | None = None, + **kwargs: Any, ) -> int: """Start a new transaction span""" # Get current span context (may be None if this is the root span) @@ -257,8 +321,11 @@ def transaction_begin( return call_next(subscription_id=subscription_id, **kwargs) def transaction_abort( - self, call_next: Callable, transaction_id: int | None = None, **kwargs - ): + self, + call_next: Callable[..., None], + transaction_id: int, + **kwargs: Any, + ) -> None: """Abort a transaction span""" # Get current span context current_span = trace.get_current_span() @@ -272,11 +339,14 @@ def transaction_abort( ) as span: self._set_span_attributes(span, transaction_id=transaction_id) - call_next(transaction_id=transaction_id, **kwargs) + call_next(transaction_id, **kwargs) def transaction_commit( - self, call_next: Callable, transaction_id: int | None = None, **kwargs - ): + self, + call_next: Callable[..., None], + transaction_id: int, + **kwargs: Any, + ) -> None: """Commit a transaction span""" # Get current span context current_span = trace.get_current_span() @@ -290,4 +360,4 @@ def transaction_commit( ) as span: self._set_span_attributes(span, transaction_id=transaction_id) - call_next(transaction_id=transaction_id, **kwargs) + call_next(transaction_id, **kwargs) diff --git a/src/workflows/transport/middleware/prometheus.py b/src/workflows/transport/middleware/prometheus.py index 6dd57f42..eca2f75b 100644 --- a/src/workflows/transport/middleware/prometheus.py +++ b/src/workflows/transport/middleware/prometheus.py @@ -2,7 +2,8 @@ import functools import time -from collections.abc import Callable +from collections.abc import Callable, Mapping +from typing import Any from prometheus_client import Counter, Gauge, Histogram @@ -79,109 +80,165 @@ class PrometheusMiddleware(BaseTransportMiddleware): - def __init__(self, source: str): + def __init__(self, source: str) -> None: self.source = source - def subscribe(self, call_next: Callable, channel, callback, **kwargs) -> int: + def subscribe( + self, + call_next: Callable[..., int], + channel: str, + callback: MessageCallback, + *, + disable_mangling: bool = False, + acknowledgement: bool = False, + **kwargs: Any, + ) -> int: @functools.wraps(callback) - def wrapped_callback(header, message): + def wrapped_callback(header: Mapping[str, Any], message: Any) -> None: start_time = time.perf_counter() - result = callback(header, message) + callback(header, message) end_time = time.perf_counter() CALLBACK_PROCESSING_TIME.labels( source=get_callback_source(callback) ).observe(end_time - start_time) - return result SUBSCRIPTIONS.labels(source=self.source).inc() ACTIVE_SUBSCRIPTIONS.labels(source=self.source).inc() - return call_next(channel, wrapped_callback, **kwargs) + return call_next( + channel, + wrapped_callback, + disable_mangling=disable_mangling, + acknowledgement=acknowledgement, + **kwargs, + ) def subscribe_temporary( self, - call_next: Callable, + call_next: Callable[..., TemporarySubscription], channel_hint: str | None, callback: MessageCallback, - **kwargs, + *, + disable_mangling: bool = False, + acknowledgement: bool = False, + **kwargs: Any, ) -> TemporarySubscription: @functools.wraps(callback) - def wrapped_callback(header, message): + def wrapped_callback(header: Mapping[str, Any], message: Any) -> None: start_time = time.perf_counter() - result = callback(header, message) + callback(header, message) end_time = time.perf_counter() CALLBACK_PROCESSING_TIME.labels( source=get_callback_source(callback) ).observe(end_time - start_time) - return result TEMPORARY_SUBSCRIPTIONS.labels(source=self.source).inc() ACTIVE_SUBSCRIPTIONS.labels(source=self.source).inc() - return call_next(channel_hint, wrapped_callback, **kwargs) + return call_next( + channel_hint, + wrapped_callback, + disable_mangling=disable_mangling, + acknowledgement=acknowledgement, + **kwargs, + ) def subscribe_broadcast( - self, call_next: Callable, channel, callback, **kwargs + self, + call_next: Callable[..., int], + channel: str, + callback: MessageCallback, + *, + disable_mangling: bool = False, + **kwargs: Any, ) -> int: @functools.wraps(callback) - def wrapped_callback(header, message): + def wrapped_callback(header: Mapping[str, Any], message: Any) -> None: start_time = time.perf_counter() - result = callback(header, message) + callback(header, message) end_time = time.perf_counter() CALLBACK_PROCESSING_TIME.labels( source=get_callback_source(callback) ).observe(end_time - start_time) - return result BROADCAST_SUBSCRIPTIONS.labels(source=self.source).inc() ACTIVE_SUBSCRIPTIONS.labels(source=self.source).inc() - return call_next(channel, wrapped_callback, **kwargs) + return call_next( + channel, + wrapped_callback, + disable_mangling=disable_mangling, + **kwargs, + ) def unsubscribe( self, - call_next: Callable, + call_next: Callable[..., None], subscription: int, - drop_callback_reference=False, - **kwargs, - ): + *, + drop_callback_reference: bool = False, + **kwargs: Any, + ) -> None: ACTIVE_SUBSCRIPTIONS.labels(source=self.source).dec() call_next( subscription, drop_callback_reference=drop_callback_reference, **kwargs ) - def send(self, call_next: Callable, destination, message, **kwargs): + def send( + self, + call_next: Callable[..., None], + destination: str, + message: Any, + *, + headers: dict | None = None, + **kwargs: Any, + ) -> None: SENDS.labels(source=self.source).inc() - call_next(destination, message, **kwargs) + call_next(destination, message, headers=headers, **kwargs) def ack( self, - call_next: Callable, - message, + call_next: Callable[..., None], + message: Any, subscription_id: int | None = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: ACKS.labels(source=self.source).inc() call_next(message, subscription_id=subscription_id, **kwargs) def nack( self, - call_next: Callable, - message, + call_next: Callable[..., None], + message: Any, subscription_id: int | None = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: NACKS.labels(source=self.source).inc() call_next(message, subscription_id=subscription_id, **kwargs) - def transaction_begin(self, call_next: Callable, *args, **kwargs) -> int: + def transaction_begin( + self, + call_next: Callable[..., int], + subscription_id: int | None = None, + **kwargs: Any, + ) -> int: TRANSACTION_BEGIN.labels(source=self.source).inc() TRANSACTIONS_IN_PROGRESS.labels(source=self.source).inc() - return call_next(*args, **kwargs) + return call_next(subscription_id=subscription_id, **kwargs) - def transaction_abort(self, call_next: Callable, *args, **kwargs): + def transaction_abort( + self, + call_next: Callable[..., None], + transaction_id: int, + **kwargs: Any, + ) -> None: TRANSACTION_ABORT.labels(source=self.source).inc() TRANSACTIONS_IN_PROGRESS.labels(source=self.source).dec() - call_next(*args, **kwargs) + call_next(transaction_id, **kwargs) - def transaction_commit(self, call_next: Callable, *args, **kwargs): + def transaction_commit( + self, + call_next: Callable[..., None], + transaction_id: int, + **kwargs: Any, + ) -> None: TRANSACTION_COMMIT.labels(source=self.source).inc() TRANSACTIONS_IN_PROGRESS.labels(source=self.source).dec() - call_next(*args, **kwargs) + call_next(transaction_id, **kwargs) diff --git a/src/workflows/transport/offline_transport.py b/src/workflows/transport/offline_transport.py index b83b749a..c877d2db 100644 --- a/src/workflows/transport/offline_transport.py +++ b/src/workflows/transport/offline_transport.py @@ -6,6 +6,7 @@ import logging import pprint import uuid +from collections.abc import Mapping from typing import Any import workflows.util @@ -28,30 +29,36 @@ class OfflineTransport(CommonTransport): config: dict[Any, Any] = {} def __init__( - self, middleware: list[type[middleware.BaseTransportMiddleware]] | None = None + self, middleware: list[middleware.BaseTransportMiddleware] | None = None ): self._connected = False super().__init__(middleware=middleware) - def connect(self): + def connect(self) -> bool: self._connected = True return True - def is_connected(self): + def is_connected(self) -> bool: return self._connected - def disconnect(self): + def disconnect(self) -> None: self._connected = False - def _output(self, message, details=None): + def _output(self, message: str, details: Any = None) -> None: _offlog.info(f"Offline Transport: {message}") if details: _offlog.debug(details) - def broadcast_status(self, status): + def broadcast_status(self, status: Mapping) -> None: self._output("Writing status message", pprint.pformat(status)) - def _subscribe(self, sub_id, channel, callback, **kwargs): + def _subscribe( + self, + sub_id: int, + channel: str, + callback: MessageCallback, + **kwargs: Any, + ) -> None: self._output( f"Subscribing to messages on {channel}", f"subscription ID {sub_id}, callback function {callback}, further keywords: {kwargs}", @@ -62,7 +69,7 @@ def _subscribe_temporary( sub_id: int, channel_hint: str | None, callback: MessageCallback, - **kwargs, + **kwargs: Any, ) -> str: channel = channel_hint or workflows.util.generate_unique_host_id() channel = channel + "." + str(uuid.uuid4()) @@ -72,46 +79,62 @@ def _subscribe_temporary( self._subscribe(sub_id, channel, callback, **kwargs) return channel - def _subscribe_broadcast(self, sub_id, channel, callback, **kwargs): + def _subscribe_broadcast( + self, + sub_id: int, + channel: str, + callback: MessageCallback, + **kwargs: Any, + ) -> None: self._output( f"Subscribing to broadcasts on {channel}", f"subscription ID {sub_id}, callback function {callback}, further keywords: {kwargs}", ) - def _unsubscribe(self, subscription, **kwargs): - self._output( - f"Ending subscription #{subscription}", f"further keywords: {kwargs}" - ) + def _unsubscribe(self, sub_id: int, **kwargs: Any) -> None: + self._output(f"Ending subscription #{sub_id}", f"further keywords: {kwargs}") def _send( - self, destination, message, headers=None, delay=None, expiration=None, **kwargs - ): + self, + destination: str, + message: Any, + headers: dict | None = None, + delay: float | None = None, + expiration: int | None = None, + **kwargs: Any, + ) -> None: self._output(f"Sending {len(message)} bytes to {destination}", message) def _broadcast( - self, destination, message, headers=None, delay=None, expiration=None, **kwargs - ): + self, + destination: str, + message: Any, + headers: dict | None = None, + delay: float | None = None, + expiration: int | None = None, + **kwargs: Any, + ) -> None: self._output(f"Broadcasting {len(message)} bytes to {destination}", message) - def _transaction_begin(self, transaction_id, **kwargs): + def _transaction_begin(self, transaction_id: int, **kwargs: Any) -> None: self._output(f"Starting transaction {transaction_id}") - def _transaction_abort(self, transaction_id, **kwargs): + def _transaction_abort(self, transaction_id: int, **kwargs: Any) -> None: self._output(f"Rolling back transaction {transaction_id}") - def _transaction_commit(self, transaction_id, **kwargs): + def _transaction_commit(self, transaction_id: int, **kwargs: Any) -> None: self._output(f"Committing transaction {transaction_id}") - def _ack(self, message_id, subscription_id, **kwargs): + def _ack(self, message_id: Any, subscription_id: int, **kwargs: Any) -> None: self._output( f"Acknowledging message {message_id} in subscription {subscription_id}" ) - def _nack(self, message_id, subscription_id, **kwargs): + def _nack(self, message_id: Any, subscription_id: int, **kwargs: Any) -> None: self._output( f"Rejecting message {message_id} in subscription {subscription_id}" ) @staticmethod - def _mangle_for_sending(message): + def _mangle_for_sending(message: Any) -> Any: return json.dumps(message, default=json_serializer) diff --git a/src/workflows/transport/pika_transport.py b/src/workflows/transport/pika_transport.py index 84a30d48..4cbd7063 100644 --- a/src/workflows/transport/pika_transport.py +++ b/src/workflows/transport/pika_transport.py @@ -5,17 +5,23 @@ import functools import json import logging +import optparse +import os import random import sys import threading import time import uuid +from argparse import ArgumentParser, Namespace from collections.abc import Callable, Iterable from concurrent.futures import Future from enum import Enum, auto -from typing import Any +from optparse import OptionParser +from typing import Any, Mapping +import pika.channel import pika.exceptions +import pika.spec from bidict import bidict from pika.adapters.blocking_connection import BlockingChannel @@ -63,7 +69,7 @@ class PikaTransport(CommonTransport): config: dict[Any, Any] = {} def __init__( - self, middleware: list[type[middleware.BaseTransportMiddleware]] | None = None + self, middleware: list[middleware.BaseTransportMiddleware] | None = None ): self._channel = None self._conn = None @@ -80,7 +86,7 @@ def get_namespace(self) -> str: return self._vhost @classmethod - def load_configuration_file(cls, filename): + def load_configuration_file(cls, filename: str | os.PathLike[str]) -> None: cfgparser = configparser.ConfigParser(allow_no_value=True) if not cfgparser.read(filename): raise workflows.Error( @@ -99,15 +105,15 @@ def load_configuration_file(cls, filename): pass @classmethod - def add_command_line_options(cls, parser): + def add_command_line_options(cls, parser: ArgumentParser | OptionParser) -> None: """Function to inject command line parameters""" - if "add_argument" in dir(parser): + if isinstance(parser, ArgumentParser): return cls.add_command_line_options_argparse(parser) else: return cls.add_command_line_options_optparse(parser) @classmethod - def add_command_line_options_argparse(cls, argparser): + def add_command_line_options_argparse(cls, argparser: ArgumentParser) -> None: """Function to inject command line parameters into a Python ArgumentParser.""" import argparse @@ -115,7 +121,13 @@ def add_command_line_options_argparse(cls, argparser): class SetParameter(argparse.Action): """callback object for ArgumentParser""" - def __call__(self, parser, namespace, value, option_string=None): + def __call__( + self, + parser: ArgumentParser, + namespace: Namespace, + value: Any, + option_string: str | None = None, + ) -> None: cls.config[option_string] = value if option_string == "--rabbit-conf": cls.load_configuration_file(value) @@ -170,11 +182,13 @@ def __call__(self, parser, namespace, value, option_string=None): ) @classmethod - def add_command_line_options_optparse(cls, optparser): + def add_command_line_options_optparse(cls, optparser: OptionParser) -> None: """function to inject command line parameters into a Python OptionParser.""" - def set_parameter(option, opt, value, parser): + def set_parameter( + option: optparse.Option, opt: str, value: Any, parser: optparse.OptionParser + ) -> None: """callback function for OptionParser""" cls.config[opt] = value if opt == "--rabbit-conf": @@ -242,17 +256,15 @@ def set_parameter(option, opt, value, parser): ) def _generate_connection_parameters(self) -> list[pika.ConnectionParameters]: - username = self.config.get("--rabbit-user", self.defaults.get("--rabbit-user")) - password = self.config.get("--rabbit-pass", self.defaults.get("--rabbit-pass")) + username = self.config.get("--rabbit-user", self.defaults["--rabbit-user"]) + password = self.config.get("--rabbit-pass", self.defaults["--rabbit-pass"]) credentials = pika.PlainCredentials(username, password) - host_string = self.config.get( - "--rabbit-host", self.defaults.get("--rabbit-host") - ) + host_string = self.config.get("--rabbit-host", self.defaults["--rabbit-host"]) port_string = str( - self.config.get("--rabbit-port", self.defaults.get("--rabbit-port")) + self.config.get("--rabbit-port", self.defaults["--rabbit-port"]) ) - vhost = self.config.get("--rabbit-vhost", self.defaults.get("--rabbit-vhost")) + vhost = self.config.get("--rabbit-vhost", self.defaults["--rabbit-vhost"]) if "," in host_string: host = host_string.split(",") else: @@ -305,11 +317,11 @@ def is_connected(self) -> bool: # Surely .connection_alive is (slightly) better? return hasattr(self, "_pika_thread") and self._pika_thread.connection_alive - def disconnect(self): + def disconnect(self) -> None: """Gracefully close connection to pika server""" self._pika_thread.join(stop=True) - def broadcast_status(self, status): + def broadcast_status(self, status: Mapping) -> None: """Broadcast transient status information to all listeners""" # Basic status checks - this is based on behaviour of status_monitor @@ -324,11 +336,11 @@ def broadcast_status(self, status): def _call_message_callback( self, subscription_id: int, - _channel: pika.channel.Channel, + _channel: BlockingChannel, method: pika.spec.Basic.Deliver, properties: pika.spec.BasicProperties, body: bytes, - ): + ) -> None: """Rewrite and redirect a pika callback to the subscription function""" merged_headers = dict(properties.headers or {}) merged_headers.update( @@ -357,8 +369,8 @@ def _subscribe( acknowledgement: bool = False, prefetch_count: int = 1, reconnectable: bool = False, - **_kwargs, - ): + **_kwargs: Any, + ) -> None: """ Listen to a queue, notify via callback function. @@ -395,7 +407,9 @@ def _subscribe( try: return self._pika_thread.subscribe_queue( queue=channel, - callback=functools.partial(self._call_message_callback, sub_id), + callback=lambda ch, m, p, b: self._call_message_callback( + sub_id, ch, m, p, b + ), auto_ack=not acknowledgement, subscription_id=sub_id, reconnectable=reconnectable, @@ -414,8 +428,8 @@ def _subscribe_broadcast( callback: MessageCallback, *, reconnectable: bool = False, - **_kwargs, - ): + **_kwargs: Any, + ) -> None: """ Listen to a FANOUT exchange, notify via callback function. @@ -445,7 +459,7 @@ def _subscribe_temporary( callback: MessageCallback, *, acknowledgement: bool = False, - **kwargs, + **kwargs: Any, ) -> str: """ Create and then listen to a temporary queue, notify via callback function. @@ -480,9 +494,11 @@ def _subscribe_temporary( ) as e: raise workflows.Disconnected(e) - def _unsubscribe(self, sub_id: int, **kwargs): - """Stop listening to a queue - :param sub_id: Consumer Tag to cancel + def _unsubscribe(self, sub_id: int, **kwargs: Any) -> None: + """Stop listening to a queue. + + Args: + sub_id: Consumer Tag to cancel """ self._pika_thread.unsubscribe(sub_id) # self._channel.basic_cancel(consumer_tag=consumer_tag, callback=None) @@ -490,15 +506,15 @@ def _unsubscribe(self, sub_id: int, **kwargs): def _send( self, - destination, - message, - headers=None, - delay=None, - expiration=None, + destination: str, + message: Any, + headers: dict[str, Any] | None = None, + delay: float | None = None, + expiration: float | None = None, transaction: int | None = None, exchange: str | None = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: """ Send a message to a queue. @@ -537,14 +553,14 @@ def _send( def _broadcast( self, - destination, - message, - headers=None, - delay=None, + destination: str, + message: Any, + headers: dict[str, Any] | None = None, + delay: float | None = None, expiration: int | None = None, transaction: int | None = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: """Send a message to a fanout exchange. Args: @@ -580,29 +596,40 @@ def _broadcast( ).result() def _transaction_begin( - self, transaction_id: int, *, subscription_id: int | None = None, **kwargs + self, transaction_id: int, *, subscription_id: int | None = None, **kwargs: Any ) -> None: """Start a new transaction. - :param transaction_id: ID for this transaction in the transport layer. - :param subscription_id: Tie the transaction to a specific channel containing this subscription. + + Args: + transaction_id: ID for this transaction in the transport layer. + subscription_id: Tie the transaction to a specific channel containing this subscription. """ self._pika_thread.tx_select(transaction_id, subscription_id).result() - def _transaction_abort(self, transaction_id: int, **kwargs) -> None: + def _transaction_abort(self, transaction_id: int, **kwargs: Any) -> None: """Abort a transaction and roll back all operations. - :param transaction_id: ID of transaction to be aborted. + + Args: + transaction_id: ID of transaction to be aborted. """ self._pika_thread.tx_rollback(transaction_id).result() - def _transaction_commit(self, transaction_id: int, **kwargs) -> None: + def _transaction_commit(self, transaction_id: int, **kwargs: Any) -> None: """Commit a transaction. - :param transaction_id: ID of transaction to be committed. + + Args: + transaction_id: ID of transaction to be committed. """ self._pika_thread.tx_commit(transaction_id).result() def _ack( - self, message_id, subscription_id: int, *, multiple: bool = False, **_kwargs - ): + self, + message_id: int, + subscription_id: int, + *, + multiple: bool = False, + **_kwargs: Any, + ) -> None: """ Acknowledge receipt of a message. @@ -614,8 +641,7 @@ def _ack( subscription_id: Internal id for the subscription this message came from multiple: Should multiple messages be acknowledged? - - :param **kwargs: Further parameters for the transport layer. + **kwargs: Further parameters for the transport layer. """ self._pika_thread.ack( message_id, @@ -626,13 +652,13 @@ def _ack( def _nack( self, - message_id, + message_id: int, subscription_id: int, *, multiple: bool = False, requeue: bool = True, - **_kwargs, - ): + **_kwargs: Any, + ) -> None: """ Reject receipt of a message. @@ -655,7 +681,7 @@ def _nack( ) @staticmethod - def _mangle_for_sending(message): + def _mangle_for_sending(message: Any) -> str: """Function that any message will pass through before it being forwarded to the actual _send* functions. Pika only deals with serialized strings, so serialize message as json. @@ -663,7 +689,7 @@ def _mangle_for_sending(message): return json.dumps(message, default=json_serializer) @staticmethod - def _mangle_for_receiving(message): + def _mangle_for_receiving(message: str | bytes | bytearray) -> Any: """Function that any message will pass through before it being forwarded to the receiving subscribed callback functions. This transport class only deals with serialized strings, so decode @@ -684,11 +710,11 @@ class _PikaThreadStatus(Enum): STOPPED = auto() @property - def is_new(self): + def is_new(self) -> bool: return self is self.NEW @property - def is_end_of_life(self): + def is_end_of_life(self) -> bool: return self in {self.STOPPING, self.STOPPED} @@ -757,7 +783,7 @@ class _PikaThread(threading.Thread): def __init__( self, connection_parameters: Iterable[pika.ConnectionParameters], - reconnection_attempts=5, + reconnection_attempts: int = 5, ): super().__init__(name="workflows pika_transport", daemon=True, target=self._run) self._state: _PikaThreadStatus = _PikaThreadStatus.NEW @@ -803,7 +829,7 @@ def state(self) -> _PikaThreadStatus: """Read the current connection state""" return self._state - def stop(self): + def stop(self) -> None: """ Request termination, including disconnection and cleanup if necessary. @@ -830,7 +856,13 @@ def stop(self): except pika.exceptions.ConnectionWrongStateError: pass - def join(self, timeout: float | None = None, *, re_raise: bool = False, stop=False): + def join( + self, + timeout: float | None = None, + *, + re_raise: bool = False, + stop: bool = False, + ) -> None: """Wait until the thread terminates. Args: @@ -852,7 +884,7 @@ def join(self, timeout: float | None = None, *, re_raise: bool = False, stop=Fal if re_raise: self.raise_if_exception() - def wait_for_connection(self, timeout=None): + def wait_for_connection(self, timeout: float | None = None) -> None: """ Safely wait until the thread has connected and is communicating with the server. @@ -865,7 +897,7 @@ def wait_for_connection(self, timeout=None): self._connected.wait(timeout) self.raise_if_exception() - def raise_if_exception(self): + def raise_if_exception(self) -> None: """If the thread has failed with an exception, raise it in the callers thread.""" exception = self._exc_info if exception: @@ -996,7 +1028,7 @@ def subscribe_temporary( result: Future[str] = Future() - def _declare_subscribe_queue_in_thread(): + def _declare_subscribe_queue_in_thread() -> None: try: if result.set_running_or_notify_cancel(): assert subscription_id not in self._subscriptions, ( @@ -1038,7 +1070,7 @@ def unsubscribe(self, subscription_id: int) -> Future[None]: result: Future[None] = Future() - def _unsubscribe(): + def _unsubscribe() -> None: try: if result.set_running_or_notify_cancel(): logger.debug("Unsubscribing from subscription %d", subscription_id) @@ -1067,7 +1099,7 @@ def send( exchange: str, routing_key: str, body: str | bytes, - properties: pika.spec.BasicProperties = None, + properties: pika.spec.BasicProperties | None = None, mandatory: bool = True, transaction_id: int | None = None, ) -> Future[None]: @@ -1078,7 +1110,7 @@ def send( future: Future[None] = Future() - def _send(): + def _send() -> None: if future.set_running_or_notify_cancel(): try: if transaction_id: @@ -1106,9 +1138,9 @@ def ack( delivery_tag: int, subscription_id: int, *, - multiple=False, + multiple: bool = False, transaction_id: int | None, - ): + ) -> None: if subscription_id not in self._subscriptions: raise KeyError(f"Could not find subscription {subscription_id} to ACK") @@ -1125,7 +1157,7 @@ def ack( ) if transaction_id is None and not self._channel_has_active_tx.get(channel): - def _ack_callback(): + def _ack_callback() -> None: channel.basic_ack(delivery_tag, multiple=multiple) if self._channel_is_transactional.get(channel): channel.tx_commit() @@ -1144,10 +1176,10 @@ def nack( delivery_tag: int, subscription_id: int, *, - multiple=False, - requeue=True, + multiple: bool = False, + requeue: bool = True, transaction_id: int | None, - ): + ) -> None: if subscription_id not in self._subscriptions: raise KeyError(f"Could not find subscription {subscription_id} to NACK") @@ -1164,7 +1196,7 @@ def nack( ) if transaction_id is None and not self._channel_has_active_tx.get(channel): - def _nack_callback(): + def _nack_callback() -> None: channel.basic_nack(delivery_tag, multiple=multiple, requeue=requeue) if self._channel_is_transactional.get(channel): channel.tx_commit() @@ -1184,8 +1216,10 @@ def tx_select( self, transaction_id: int, subscription_id: int | None ) -> Future[None]: """Set a channel to transaction mode. Thread-safe. - :param transaction_id: ID for this transaction in the transport layer. - :param subscription_id: Tie the transaction to a specific channel containing this subscription. + + Args: + transaction_id: ID for this transaction in the transport layer. + subscription_id: Tie the transaction to a specific channel containing this subscription. """ if not self._connection: @@ -1193,7 +1227,7 @@ def tx_select( future: Future[None] = Future() - def _tx_select(): + def _tx_select() -> None: if future.set_running_or_notify_cancel(): try: if subscription_id: @@ -1228,14 +1262,16 @@ def _tx_select(): def tx_rollback(self, transaction_id: int) -> Future[None]: """Abort a transaction and roll back all operations. Thread-safe. - :param transaction_id: ID of transaction to be aborted. + + Args: + transaction_id: ID of transaction to be aborted. """ if not self._connection: raise RuntimeError("Cannot transact on unstarted connection") future: Future[None] = Future() - def _tx_rollback(): + def _tx_rollback() -> None: if future.set_running_or_notify_cancel(): try: channel = self._transaction_on_channel.inverse.pop( @@ -1256,15 +1292,17 @@ def _tx_rollback(): return future def tx_commit(self, transaction_id: int) -> Future[None]: - """Commit a transaction. - :param transaction_id: ID of transaction to be committed. Thread-safe.. + """Commit a transaction. Thread-safe. + + Args: + transaction_id: ID of transaction to be committed. """ if not self._connection: raise RuntimeError("Cannot transact on unstarted connection") future: Future[None] = Future() - def _tx_commit(): + def _tx_commit() -> None: if future.set_running_or_notify_cancel(): try: channel = self._transaction_on_channel.inverse.pop( @@ -1317,8 +1355,10 @@ def connection_alive(self) -> bool: #################################################################### # PikaThread Internal methods - def _debug_close_connection(self): - self._connection.add_callback_threadsafe(lambda: self._connection.close()) + def _debug_close_connection(self) -> None: + connection = self._connection + assert connection is not None + connection.add_callback_threadsafe(lambda: connection.close()) def _get_shared_channel(self) -> BlockingChannel: """Get the shared (no prefetch) channel. Create if necessary.""" @@ -1330,7 +1370,7 @@ def _get_shared_channel(self) -> BlockingChannel: ##### self._pika_shared_channel.confirm_delivery() return self._pika_shared_channel - def _recreate_subscriptions(self): + def _recreate_subscriptions(self) -> None: """Resubscribe all existing subscriptions""" old_subscriptions = self._subscriptions self._subscriptions = {} @@ -1349,7 +1389,9 @@ def _recreate_subscriptions(self): f"Subscriptions recreated. Reconnections allowed? - {'Yes' if self._reconnection_allowed else 'No.'}" ) - def _add_subscription(self, subscription_id: int, subscription: _PikaSubscription): + def _add_subscription( + self, subscription_id: int, subscription: _PikaSubscription + ) -> None: assert self._connection is not None assert subscription_id not in self._subscriptions @@ -1386,10 +1428,10 @@ def _add_subscription(self, subscription_id: int, subscription: _PikaSubscriptio self._subscriptions[subscription_id] = subscription logger.debug("Consuming (%d) on %s", subscription_id, subscription.queue) - def _run(self): + def _run(self) -> None: if self._please_stop.is_set(): # stop() was called before start()... so quit - self._state == _PikaThreadStatus.STOPPED + self._state == _PikaThreadStatus.STOPPED # type: ignore return assert self._state == _PikaThreadStatus.NEW assert self._reconnection_allowed, "Should be true until first synchronize" @@ -1437,7 +1479,7 @@ def _run(self): connection_counter += 1 # Clear the channels because this might be a reconnect - self._pika_channels = {} + self._pika_channels = bidict() self._pika_shared_channel = None self._transaction_on_channel = bidict() self._channel_has_active_tx = {} @@ -1461,7 +1503,7 @@ def _run(self): # Run until we are asked to stop, or fail while not self._please_stop.is_set(): - self._connection.process_data_events(None) + self._connection.process_data_events(None) # type: ignore except pika.exceptions.ConnectionClosed: self._exc_info = sys.exc_info() if self._please_stop.is_set(): @@ -1488,6 +1530,7 @@ def _run(self): self._exc_info = sys.exc_info() break # Make sure our connection is closed before reconnecting + assert self._connection is not None if not self._connection.is_closed: logger.info("Connection not closed. Closing.") self._connection.close() @@ -1535,7 +1578,7 @@ def _add_subscription_in_thread( subscription_id: int, subscription: _PikaSubscription, result: Future, - ): + ) -> None: """ Add a subscription to the pika connection. diff --git a/src/workflows/transport/stomp_transport.py b/src/workflows/transport/stomp_transport.py index 072cacad..5618c456 100644 --- a/src/workflows/transport/stomp_transport.py +++ b/src/workflows/transport/stomp_transport.py @@ -1,13 +1,18 @@ from __future__ import annotations +import argparse import configparser import json +import optparse import threading import time import uuid +from collections.abc import Mapping from typing import Any import stomp +import stomp.exception +import stomp.utils import workflows.util from workflows.transport import middleware @@ -34,7 +39,7 @@ class StompTransport(CommonTransport): config: dict[Any, Any] = {} def __init__( - self, middleware: list[type[middleware.BaseTransportMiddleware]] | None = None + self, middleware: list[middleware.BaseTransportMiddleware] | None = None ): self._connected = False self._namespace = "" @@ -44,9 +49,9 @@ def __init__( # self._stomp_listener = stomp.PrintingListener() self._stomp_listener.on_message = self._on_message self._stomp_listener.on_before_message = lambda frame: frame - super().__init__() + super().__init__(middleware) - def get_namespace(self): + def get_namespace(self) -> str: """Return the stomp namespace. This is a prefix used for all topic and queue names.""" if self._namespace.endswith("."): @@ -54,7 +59,7 @@ def get_namespace(self): return self._namespace @classmethod - def load_configuration_file(cls, filename): + def load_configuration_file(cls, filename: str) -> None: cfgparser = configparser.ConfigParser(allow_no_value=True) if not cfgparser.read(filename): raise workflows.Error( @@ -73,23 +78,32 @@ def load_configuration_file(cls, filename): pass @classmethod - def add_command_line_options(cls, parser): + def add_command_line_options( + cls, parser: argparse.ArgumentParser | optparse.OptionParser + ) -> None: """function to inject command line parameters""" - if "add_argument" in dir(parser): + if isinstance(parser, argparse.ArgumentParser): return cls.add_command_line_options_argparse(parser) else: return cls.add_command_line_options_optparse(parser) @classmethod - def add_command_line_options_argparse(cls, argparser): + def add_command_line_options_argparse( + cls, argparser: argparse.ArgumentParser + ) -> None: """function to inject command line parameters into a Python ArgumentParser.""" - import argparse class SetParameter(argparse.Action): """callback object for ArgumentParser""" - def __call__(self, parser, namespace, value, option_string=None): + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + value: Any, + option_string: str | None = None, + ) -> None: cls.config[option_string] = value if option_string == "--stomp-conf": cls.load_configuration_file(value) @@ -144,11 +158,18 @@ def __call__(self, parser, namespace, value, option_string=None): ) @classmethod - def add_command_line_options_optparse(cls, optparser): + def add_command_line_options_optparse( + cls, optparser: optparse.OptionParser + ) -> None: """function to inject command line parameters into a Python OptionParser.""" - def set_parameter(option, opt, value, parser): + def set_parameter( + option: optparse.Option, + opt: str, + value: Any, + parser: optparse.OptionParser, + ) -> None: """callback function for OptionParser""" cls.config[opt] = value if opt == "--stomp-conf": @@ -215,7 +236,7 @@ def set_parameter(option, opt, value, parser): callback=set_parameter, ) - def connect(self): + def connect(self) -> bool: with self._lock: if self._connected: return True @@ -227,7 +248,7 @@ def connect(self): ), int( self.config.get( - "--stomp-port", self.defaults.get("--stomp-port") + "--stomp-port", self.defaults["--stomp-port"] ) ), ) @@ -265,25 +286,25 @@ def connect(self): "Could not initiate connection to stomp host" ) self._namespace = self.config.get( - "--stomp-prfx", self.defaults.get("--stomp-prfx") + "--stomp-prfx", self.defaults["--stomp-prfx"] ) if self._namespace and not self._namespace.endswith("."): self._namespace = self._namespace + "." self._connected = True return True - def is_connected(self): + def is_connected(self) -> bool: """Return connection status""" self._connected = self._connected and self._conn.is_connected() return self._connected - def disconnect(self): + def disconnect(self) -> None: """Gracefully close connection to stomp server.""" if self._connected: self._connected = False self._conn.disconnect() - def broadcast_status(self, status): + def broadcast_status(self, status: Mapping) -> None: """Broadcast transient status information to all listeners""" self._broadcast( "transient.status", @@ -291,21 +312,29 @@ def broadcast_status(self, status): headers={"expires": str(int((15 + time.time()) * 1000))}, ) - def _subscribe(self, sub_id, channel, callback, **kwargs): + def _subscribe( + self, + sub_id: int, + channel: str, + callback: MessageCallback, + **kwargs: Any, + ) -> None: """Listen to a queue, notify via callback function. - :param sub_id: ID for this subscription in the transport layer - :param channel: Queue name to subscribe to - :param callback: Function to be called when messages are received - :param **kwargs: Further parameters for the transport layer. For example - acknowledgement: If true receipt of each message needs to be - acknowledged. - exclusive: Attempt to become exclusive subscriber to the queue. - ignore_namespace: Do not apply namespace to the destination name - priority: Consumer priority, messages are sent to higher - priority consumers whenever possible. - selector: Only receive messages filtered by a selector. See - https://activemq.apache.org/activemq-message-properties.html - for potential filter criteria. Uses SQL 92 syntax. + + Args: + sub_id: ID for this subscription in the transport layer + channel: Queue name to subscribe to + callback: Function to be called when messages are received + **kwargs: + Further parameters for the transport layer. For example: + acknowledgement: If true receipt of each message needs to be acknowledged. + exclusive: Attempt to become exclusive subscriber to the queue. + ignore_namespace: Do not apply namespace to the destination name. + priority: Consumer priority, messages are sent to higher priority consumers + whenever possible. + selector: Only receive messages filtered by a selector. See + https://activemq.apache.org/activemq-message-properties.html + for potential filter criteria. Uses SQL 92 syntax. """ headers = {} if kwargs.get("exclusive"): @@ -327,14 +356,23 @@ def _subscribe(self, sub_id, channel, callback, **kwargs): self._conn.subscribe(destination, sub_id, headers=headers, ack=ack) - def _subscribe_broadcast(self, sub_id, channel, callback, **kwargs): + def _subscribe_broadcast( + self, + sub_id: int, + channel: str, + callback: MessageCallback, + **kwargs: Any, + ) -> None: """Listen to a broadcast topic, notify via callback function. - :param sub_id: ID for this subscription in the transport layer - :param channel: Topic name to subscribe to - :param callback: Function to be called when messages are received - :param **kwargs: Further parameters for the transport layer. For example - ignore_namespace: Do not apply namespace to the destination name - retroactive: Ask broker to send old messages if possible + + Args: + sub_id: ID for this subscription in the transport layer + channel: Topic name to subscribe to + callback: Function to be called when messages are received + **kwargs: + Further parameters for the transport layer. For example: + ignore_namespace: Do not apply namespace to the destination name. + retroactive: Ask broker to send old messages if possible. """ headers = {} if kwargs.get("ignore_namespace"): @@ -350,15 +388,18 @@ def _subscribe_temporary( sub_id: int, channel_hint: str | None, callback: MessageCallback, - **kwargs, + **kwargs: Any, ) -> str: """Create and then listen to a temporary queue, notify via callback function. - :param sub_id: ID for this subscription in the transport layer - :param channel_hint: Name suggestion for the temporary queue - :param callback: Function to be called when messages are received - :param **kwargs: Further parameters for the transport layer. - See _subscribe() above. - :returns: The name of the temporary queue + + Args: + sub_id: ID for this subscription in the transport layer + channel_hint: Name suggestion for the temporary queue + callback: Function to be called when messages are received + **kwargs: Further parameters for the transport layer. See _subscribe(). + + Returns: + The name of the temporary queue """ channel = channel_hint or workflows.util.generate_unique_host_id() @@ -370,28 +411,38 @@ def _subscribe_temporary( return channel - def _unsubscribe(self, subscription, **kwargs): - """Stop listening to a queue or a broadcast - :param subscription: Subscription ID to cancel + def _unsubscribe(self, sub_id: int, **kwargs: Any) -> None: + """Stop listening to a queue or a broadcast. + + Args: + sub_id: Subscription ID to cancel """ - self._conn.unsubscribe(id=subscription) + self._conn.unsubscribe(id=sub_id) # Callback reference is kept as further messages may already have been received def _send( - self, destination, message, headers=None, delay=None, expiration=None, **kwargs - ): + self, + destination: str, + message: Any, + headers: dict | None = None, + delay: float | None = None, + expiration: int | None = None, + **kwargs: Any, + ) -> None: """Send a message to a queue. - :param destination: Queue name to send to - :param message: A string to be sent - :param **kwargs: Further parameters for the transport layer. For example - delay: Delay transport of message by this many seconds - expiration: Optional expiration time, relative to sending time - headers: Optional dictionary of header entries - ignore_namespace: Do not apply namespace to the destination name - persistent: Whether to mark messages as persistent, to be kept - between broker restarts. Default is 'true'. - transaction: Transaction ID if message should be part of a - transaction + + Args: + destination: Queue name to send to + message: A string to be sent + **kwargs: + Further parameters for the transport layer. For example: + delay: Delay transport of message by this many seconds. + expiration: Optional expiration time, relative to sending time. + headers: Optional dictionary of header entries. + ignore_namespace: Do not apply namespace to the destination name. + persistent: Whether to mark messages as persistent, to be kept + between broker restarts. Default is 'true'. + transaction: Transaction ID if message should be part of a transaction. """ if not headers: headers = {} @@ -414,18 +465,26 @@ def _send( raise workflows.Disconnected("No connection to stomp host") def _broadcast( - self, destination, message, headers=None, delay=None, expiration=None, **kwargs - ): + self, + destination: str, + message: Any, + headers: dict | None = None, + delay: float | None = None, + expiration: int | None = None, + **kwargs: Any, + ) -> None: """Broadcast a message. - :param destination: Topic name to send to - :param message: A string to be broadcast - :param **kwargs: Further parameters for the transport layer. For example - delay: Delay transport of message by this many seconds - expiration: Optional expiration time, relative to sending time - headers: Optional dictionary of header entries - ignore_namespace: Do not apply namespace to the destination name - transaction: Transaction ID if message should be part of a - transaction + + Args: + destination: Topic name to send to + message: A string to be broadcast + **kwargs: + Further parameters for the transport layer. For example: + delay: Delay transport of message by this many seconds. + expiration: Optional expiration time, relative to sending time. + headers: Optional dictionary of header entries. + ignore_namespace: Do not apply namespace to the destination name. + transaction: Transaction ID if message should be part of a transaction. """ if not headers: headers = {} @@ -443,48 +502,59 @@ def _broadcast( self._connected = False raise workflows.Disconnected("No connection to stomp host") - def _transaction_begin(self, transaction_id, **kwargs): + def _transaction_begin(self, transaction_id: int, **kwargs: Any) -> None: """Start a new transaction. - :param transaction_id: ID for this transaction in the transport layer. + + Args: + transaction_id: ID for this transaction in the transport layer. """ self._conn.begin(transaction=transaction_id) - def _transaction_abort(self, transaction_id, **kwargs): + def _transaction_abort(self, transaction_id: int, **kwargs: Any) -> None: """Abort a transaction and roll back all operations. - :param transaction_id: ID of transaction to be aborted. + + Args: + transaction_id: ID of transaction to be aborted. """ self._conn.abort(transaction_id) - def _transaction_commit(self, transaction_id, **kwargs): + def _transaction_commit(self, transaction_id: int, **kwargs: Any) -> None: """Commit a transaction. - :param transaction_id: ID of transaction to be committed. + + Args: + transaction_id: ID of transaction to be committed. """ self._conn.commit(transaction_id) - def _ack(self, message_id, subscription_id, **kwargs): + def _ack(self, message_id: Any, subscription_id: int, **kwargs: Any) -> None: """Acknowledge receipt of a message. This only makes sense when the 'acknowledgement' flag was set for the relevant subscription. - :param message_id: ID of the message to be acknowledged - :param subscription: ID of the relevant subscriptiong - :param **kwargs: Further parameters for the transport layer. For example - transaction: Transaction ID if acknowledgement should be part of - a transaction + + Args: + message_id: ID of the message to be acknowledged + subscription_id: ID of the relevant subscription + **kwargs: + Further parameters for the transport layer. For example: + transaction: Transaction ID if acknowledgement should be part of + a transaction. """ self._conn.ack(message_id, subscription_id, **kwargs) - def _nack(self, message_id, subscription_id, **kwargs): + def _nack(self, message_id: Any, subscription_id: int, **kwargs: Any) -> None: """Reject receipt of a message. This only makes sense when the 'acknowledgement' flag was set for the relevant subscription. - :param message_id: ID of the message to be rejected - :param subscription: ID of the relevant subscriptiong - :param **kwargs: Further parameters for the transport layer. For example - transaction: Transaction ID if rejection should be part of a - transaction + + Args: + message_id: ID of the message to be rejected + subscription_id: ID of the relevant subscription + **kwargs: + Further parameters for the transport layer. For example: + transaction: Transaction ID if rejection should be part of a transaction. """ self._conn.nack(message_id, subscription_id, **kwargs) @staticmethod - def _mangle_for_sending(message): + def _mangle_for_sending(message: Any) -> str: """Function that any message will pass through before it being forwarded to the actual _send* functions. Stomp only deals with serialized strings, so serialize message as json. @@ -492,7 +562,7 @@ def _mangle_for_sending(message): return json.dumps(message, default=json_serializer) @staticmethod - def _mangle_for_receiving(message): + def _mangle_for_receiving(message: Any) -> Any: """Function that any message will pass through before it being forwarded to the receiving subscribed callback functions. This transport class only deals with serialized strings, so decode @@ -506,12 +576,9 @@ def _mangle_for_receiving(message): ## Stomp listener methods ##################################################### - def _on_message(self, frame): + def _on_message(self, frame: stomp.utils.Frame) -> None: headers = frame.headers body = frame.body - subscription_id = int(headers.get("subscription")) + subscription_id = int(headers["subscription"]) target_function = self.subscription_callback(subscription_id) - if target_function: - target_function(headers, body) - else: - raise workflows.Error(f"Unhandled message {headers!r} {body!r}") + target_function(headers, body) diff --git a/src/workflows/util/zocalo/configuration.py b/src/workflows/util/zocalo/configuration.py index 33585c18..01ae7e49 100644 --- a/src/workflows/util/zocalo/configuration.py +++ b/src/workflows/util/zocalo/configuration.py @@ -1,9 +1,15 @@ +""" +Zocalo configuration for workflows objects + +Only imported if Zocalo is present in the environment. +""" + from __future__ import annotations -from typing import ClassVar, TypedDict +from typing import Any, ClassVar, TypedDict -from marshmallow import fields -from zocalo.configuration import PluginSchema +from marshmallow import fields # type: ignore +from zocalo.configuration import PluginSchema # type: ignore import workflows.transport from workflows.transport.pika_transport import PikaTransport @@ -26,7 +32,7 @@ class Schema(PluginSchema): timeout = fields.Int(required=False, load_default=10) @staticmethod - def activate(configuration): + def activate(configuration: dict[str, Any]) -> Any: # Build the full endpoint URL endpoint = f"https://{configuration['host']}:{configuration['port']}/v1/traces" OTEL.config["endpoint"] = endpoint @@ -45,7 +51,7 @@ class Schema(PluginSchema): prefix = fields.Str(required=True) @staticmethod - def activate(configuration): + def activate(configuration: dict[str, Any]) -> Any: for cfgoption, target in [ ("host", "--stomp-host"), ("port", "--stomp-port"), @@ -68,7 +74,7 @@ class Schema(PluginSchema): vhost = fields.Str(required=True) @staticmethod - def activate(configuration): + def activate(configuration: dict[str, Any]) -> Any: for cfgoption, target in [ ("host", "--rabbit-host"), ("port", "--rabbit-port"), @@ -87,5 +93,5 @@ class Schema(PluginSchema): default = fields.Str(required=True) @staticmethod - def activate(configuration): + def activate(configuration: dict[str, Any]) -> None: workflows.transport.default_transport = configuration["default"] diff --git a/tests/frontend/test_frontend.py b/tests/frontend/test_frontend.py index 15cf6d29..ec8ce531 100644 --- a/tests/frontend/test_frontend.py +++ b/tests/frontend/test_frontend.py @@ -14,7 +14,7 @@ class ServiceCrashingOnInit(CommonService): """A service that raises an unhandled exception.""" @staticmethod - def initializing(): + def initializing(): # type: ignore """Raise AssertionError. This should set the error state, kill the service and cause the frontend to leave its main loop.""" diff --git a/tests/transport/test_common.py b/tests/transport/test_common.py index 2ac337c3..1bbea536 100644 --- a/tests/transport/test_common.py +++ b/tests/transport/test_common.py @@ -212,7 +212,9 @@ def test_simple_send_message(): ct.send(mock.sentinel.destination, mock.sentinel.message) - ct._send.assert_called_with(mock.sentinel.destination, mock.sentinel.message) + ct._send.assert_called_with( + mock.sentinel.destination, mock.sentinel.message, headers=None + ) def test_simple_broadcast_message(): diff --git a/tests/transport/test_pika.py b/tests/transport/test_pika.py index 4745cc00..516c21ea 100644 --- a/tests/transport/test_pika.py +++ b/tests/transport/test_pika.py @@ -70,11 +70,10 @@ def test_lookup_and_initialize_pika_transport_layer(): def test_add_command_line_help_optparse(): """Check that command line parameters are registered in the parser.""" - parser = mock.MagicMock() + parser = mock.MagicMock(spec=optparse.OptionParser) PikaTransport().add_command_line_options(parser) - parser.add_argument.assert_not_called() parser.add_option.assert_called() assert parser.add_option.call_count > 4 for call in parser.add_option.call_args_list: @@ -83,13 +82,11 @@ def test_add_command_line_help_optparse(): def test_add_command_line_help_argparse(): """Check that command line parameters are registered in the parser.""" - parser = mock.MagicMock() - parser.add_argument = mock.Mock() + parser = mock.MagicMock(spec=argparse.ArgumentParser) PikaTransport().add_command_line_options(parser) parser.add_argument.assert_called() - parser.add_option.assert_not_called() assert parser.add_argument.call_count > 4 for call in parser.add_argument.call_args_list: assert inspect.isclass(call[1]["action"]) diff --git a/tests/transport/test_stomp.py b/tests/transport/test_stomp.py index 7363b59e..dd2291ba 100644 --- a/tests/transport/test_stomp.py +++ b/tests/transport/test_stomp.py @@ -32,11 +32,10 @@ def test_lookup_and_initialize_stomp_transport_layer(): def test_add_command_line_help_optparse(): """Check that command line parameters are registered in the parser.""" - parser = mock.MagicMock() + parser = mock.MagicMock(spec=optparse.OptionParser) StompTransport().add_command_line_options(parser) - parser.add_argument.assert_not_called() parser.add_option.assert_called() assert parser.add_option.call_count > 4 for call in parser.add_option.call_args_list: @@ -45,13 +44,11 @@ def test_add_command_line_help_optparse(): def test_add_command_line_help_argparse(): """Check that command line parameters are registered in the parser.""" - parser = mock.MagicMock() - parser.add_argument = mock.Mock() + parser = mock.MagicMock(spec=argparse.ArgumentParser) StompTransport().add_command_line_options(parser) parser.add_argument.assert_called() - parser.add_option.assert_not_called() assert parser.add_argument.call_count > 4 for call in parser.add_argument.call_args_list: assert inspect.isclass(call[1]["action"])