diff --git a/src/labthings_fastapi/actions.py b/src/labthings_fastapi/actions.py index 888acfef..0ae46b72 100644 --- a/src/labthings_fastapi/actions.py +++ b/src/labthings_fastapi/actions.py @@ -22,7 +22,7 @@ from collections import deque from functools import partial, wraps import inspect -from threading import Thread, Lock +from threading import Thread, Lock, RLock import uuid from typing import ( TYPE_CHECKING, @@ -39,10 +39,9 @@ ) from weakref import WeakSet import weakref -from fastapi import APIRouter, FastAPI, HTTPException, Request, Body, BackgroundTasks +from fastapi import APIRouter, FastAPI, HTTPException, Response, Body, BackgroundTasks from pydantic import BaseModel, create_model - from .middleware.url_for import URLFor from .base_descriptor import ( BaseDescriptor, @@ -50,13 +49,20 @@ DescriptorInfoCollection, ) from .logs import add_thing_log_destination -from .utilities import model_to_dict, wrap_plain_types_in_rootmodel -from .invocations import InvocationModel, InvocationStatus +from .utilities import ( + RootModelWrapper, + model_to_dict, + serialize_from_user_code, + validate_from_user_code, +) +from .invocations import InvocationSummary, InvocationModel, InvocationStatus from .exceptions import ( GlobalLockBusyError, + InvalidReturnValueError, InvocationCancelledError, InvocationError, NotConnectedToServerError, + UnserializableTypeError, ) from . import invocation_contexts from .utilities.introspection import ( @@ -140,9 +146,9 @@ def __init__( self.expiry_time: Optional[datetime.datetime] = None # Private state properties - self._status_lock = Lock() # This Lock protects properties below + self._status_lock = RLock() # This Lock protects properties below self._status: InvocationStatus = InvocationStatus.PENDING # Task status - self._return_value: Optional[Any] = None # Return value + self._output_model_instance: Optional[BaseModel] = None # Return value self._request_time: datetime.datetime = datetime.datetime.now() self._start_time: Optional[datetime.datetime] = None # Task start time self._end_time: Optional[datetime.datetime] = None # Task end time @@ -158,7 +164,13 @@ def id(self) -> uuid.UUID: def output(self) -> Any: """Return value of the Action. If the Action is still running, returns None.""" with self._status_lock: - return self._return_value + return RootModelWrapper.unwrap(self._output_model_instance) + + @property + def output_model_instance(self) -> BaseModel | None: + """Return value of the Action, as a model, or None.""" + with self._status_lock: + return self._output_model_instance @property def log(self) -> list[logging.LogRecord]: @@ -216,6 +228,29 @@ def cancel(self) -> None: """ self.cancel_hook.set() + def summary_model(self) -> InvocationSummary: + """Generate a summary of the invocation suitable for HTTP. + + :return: a `InvocationSummary` representing this `Invocation`. + """ + links = [ + LinkElement(rel="self", href=URLFor("action_invocation", id=self.id)), + LinkElement( + rel="output", href=URLFor("action_invocation_output", id=self.id) + ), + ] + with self._status_lock: + return InvocationSummary( + status=self.status, + id=self.id, + action=self.thing.path + self.action.name, # type: ignore[attr-defined] + href=URLFor("action_invocation", id=self.id), + timeStarted=self._start_time, + timeCompleted=self._end_time, + timeRequested=self._request_time, + links=links, + ) + def response(self) -> InvocationModel: """Generate a representation of the invocation suitable for HTTP. @@ -225,27 +260,15 @@ def response(self) -> InvocationModel: :return: an `.InvocationModel` representing this `.Invocation`. """ - links = [ - LinkElement(rel="self", href=URLFor("action_invocation", id=self.id)), - LinkElement( - rel="output", href=URLFor("action_invocation_output", id=self.id) - ), - ] # The line below confuses MyPy because self.action **evaluates to** a Descriptor # object (i.e. we don't call __get__ on the descriptor). - return self.action.invocation_model( # type: ignore[attr-defined] - status=self.status, - id=self.id, - action=self.thing.path + self.action.name, # type: ignore[attr-defined] - href=URLFor("action_invocation", id=self.id), - timeStarted=self._start_time, - timeCompleted=self._end_time, - timeRequested=self._request_time, - input=self.input, - output=self.output, - links=links, - log=self.log, - ) + with self._status_lock: + return self.action.invocation_model( # type: ignore[attr-defined] + **dict(self.summary_model()), + input=self.input, + output=self.output_model_instance, + log=self.log, + ) def run(self) -> None: """Run the action and track progress. @@ -303,11 +326,17 @@ def run(self) -> None: # Actually run the action ret = action.func(thing, **kwargs, **self.dependencies) - with self._status_lock: - self._return_value = ret - self._status = InvocationStatus.COMPLETED - action.emit_changed_event(self.thing, self._status.value) + output_model_instance = validate_from_user_code( + model=action.output_model, + value=ret, + description=f"the output of '{self.thing.name}.{action.name}'", + code=action.func, + ) + with self._status_lock: + self._output_model_instance = output_model_instance + self._status = InvocationStatus.COMPLETED + action.emit_changed_event(self.thing, self._status.value) except InvocationCancelledError: logger.info(f"Invocation {self.id} was cancelled.") with self._status_lock: @@ -315,7 +344,7 @@ def run(self) -> None: action.emit_changed_event(self.thing, self._status.value) except Exception as e: # skipcq: PYL-W0703 # First log - if isinstance(e, InvocationError): + if isinstance(e, (InvocationError, InvalidReturnValueError)): # Log without traceback for anticipated errors logger.error(e) elif ( @@ -412,11 +441,10 @@ def list_invocations( self, action: Optional[ActionDescriptor] = None, thing: Optional[Thing] = None, - request: Optional[Request] = None, - ) -> list[InvocationModel]: + ) -> list[InvocationSummary]: """All of the invocations currently managed. - Returns a list of `.InvocationModel` instances representing all the + Returns a list of `InvocationSummary` instances representing all the invocations that are currently running, or have recently completed and not yet expired. @@ -427,16 +455,11 @@ def list_invocations( :param thing: returns only invocations of actions on a particular `~lt.Thing`. This will often be combined with filtering by ``action`` to give the list of invocations returned by a GET request on an action endpoint. - :param request: is used to pass a `fastapi.Request` object to the - `.Invocation.response` method. Doing so ensures the URL returned as - ``href`` in the response matches the address used to communicate with - the server (i.e. it uses `fastapi.Request.url_for` instead of a path - generated from a string). :return: A list of invocations, optionally filtered by Thing and/or Action. """ return [ - i.response() + i.summary_model() for i in self.invocations if thing is None or i.thing == thing if action is None or i.action == action @@ -461,20 +484,19 @@ def router(self) -> APIRouter: """ router = APIRouter() - @router.get(ACTION_INVOCATIONS_PATH, response_model=list[InvocationModel]) - def list_all_invocations(request: Request) -> list[InvocationModel]: - return self.list_invocations(request=request) + @router.get(ACTION_INVOCATIONS_PATH) + def list_all_invocations() -> list[InvocationSummary]: + return self.list_invocations() @router.get( ACTION_INVOCATIONS_PATH + "/{id}", + response_model=InvocationModel, responses={404: {"description": "Invocation ID not found"}}, ) - def action_invocation(id: uuid.UUID, request: Request) -> InvocationModel: + def action_invocation(id: uuid.UUID) -> Response: """Return a description of a specific action. :param id: The action's ID (from the path). - :param request: FastAPI dependency for the request object, used to - find URLs via ``url_for``. :return: Details of the invocation. @@ -482,13 +504,23 @@ def action_invocation(id: uuid.UUID, request: Request) -> InvocationModel: found. """ try: - with self._invocations_lock: - return self._invocations[id].response() + invocation = self.get_invocation(id) + return serialize_from_user_code( + model_instance=invocation.response(), + description=f"invocation '{id}' of ", + code=invocation.action.func, # type: ignore[attr-defined] + ) except KeyError as e: raise HTTPException( status_code=404, detail="No action invocation found with ID {id}", ) from e + except InvalidReturnValueError as e: + invocation.thing.logger.error(e) + raise HTTPException( + status_code=500, + detail=str(e), + ) from e @router.get( ACTION_INVOCATIONS_PATH + "/{id}/output", @@ -504,7 +536,7 @@ def action_invocation(id: uuid.UUID, request: Request) -> InvocationModel: 503: {"description": "No result is available for this invocation"}, }, ) - def action_invocation_output(id: uuid.UUID) -> Any: + def action_invocation_output(id: uuid.UUID) -> Response: """Get the output of an action invocation. This returns just the "output" component of the action invocation. If the @@ -536,7 +568,18 @@ def action_invocation_output(id: uuid.UUID) -> Any: ): # TODO: honour "accept" header return invocation.output.response() - return invocation.output + try: + return serialize_from_user_code( + model_instance=invocation.output_model_instance, + description=f"the output of {invocation}", + code=invocation.action.func, + ) + except InvalidReturnValueError as e: + invocation.thing.logger.error(e) + raise HTTPException( + status_code=500, + detail=str(e), + ) from e @router.delete( ACTION_INVOCATIONS_PATH + "/{id}", @@ -710,6 +753,8 @@ def __init__( if more nuanced locking behaviour is required meaning the lock is acquired directly in the action code, for example using `~lt.ThingServerInterface.hold_global_lock`\ . + :raises UnserializableTypeError: if the return type of the action cannot + be serialised to JSON by `pydantic`\ . """ super().__init__() self.func = func @@ -726,7 +771,13 @@ def __init__( remove_first_positional_arg=True, ignore=[p.name for p in self.dependency_params], ) - self.output_model = wrap_plain_types_in_rootmodel(return_type(func)) + try: + self.output_model = RootModelWrapper.wrap_type( + return_type(func), name=f"{name.title()}Output" + ) + except UnserializableTypeError as e: + e.set_source_function(func) + raise self.invocation_model = create_model( f"{name}_invocation", __base__=InvocationModel, @@ -810,7 +861,7 @@ def instance_get(self, obj: OwnerT) -> Callable[ActionParams, ActionReturn]: """ @wraps(self.func) - def wrapped(*args: Any, **kwargs: Any) -> Any: # noqa: DOC + def wrapped(*args: Any, **kwargs: Any) -> Any: # noqa: DOC101, DOC103, DOC201 """Acquire the lock then run `func` with supplied arguments.""" with self.context_for_func(obj): return self.func(*args, **kwargs) @@ -900,16 +951,25 @@ def start_action( body: Any, # This annotation will be overwritten below. background_tasks: BackgroundTasks, **dependencies: Any, - ) -> InvocationModel: + ) -> Response: action_manager = thing._thing_server_interface._action_manager - action = action_manager.invoke_action( + invocation = action_manager.invoke_action( action=self, thing=thing, input=body, dependencies=dependencies, ) background_tasks.add_task(action_manager.expire_invocations) - return action.response() + try: + return serialize_from_user_code( + model_instance=invocation.response(), + description=f"{invocation}", + status_code=201, + code=self.func, + ) + except InvalidReturnValueError as e: + thing.logger.error(e) + raise HTTPException(status_code=500, detail=str(e)) from e if issubclass(self.input_model, EmptyInput): annotation = Body(default_factory=StrictEmptyInput) @@ -952,9 +1012,9 @@ def start_action( ) app.post( thing.path + self.name, - response_model=self.invocation_model, status_code=201, response_description="Action has been invoked (and may still be running).", + response_model=self.invocation_model, description=f"## {self.title}\n\n {self.description} {ACTION_POST_NOTICE}", summary=self.title, responses=responses, @@ -962,7 +1022,7 @@ def start_action( @app.get( thing.path + self.name, - response_model=list[self.invocation_model], # type: ignore + response_model=list[InvocationSummary], # MyPy doesn't like the line above - but it works for FastAPI # to generate a response model. response_description=f"A list of every invocation of {self.name}.", @@ -971,7 +1031,7 @@ def start_action( ), summary=f"All invocations of {self.name}.", ) - def list_invocations() -> list[InvocationModel]: + def list_invocations() -> list[InvocationSummary]: action_manager = thing._thing_server_interface._action_manager return action_manager.list_invocations(self, thing) diff --git a/src/labthings_fastapi/client/__init__.py b/src/labthings_fastapi/client/__init__.py index a2ddc247..566852c5 100644 --- a/src/labthings_fastapi/client/__init__.py +++ b/src/labthings_fastapi/client/__init__.py @@ -101,15 +101,24 @@ def poll_invocation( :param first_interval: sets how long we wait before the first polling request. Often, it makes sense for this to be a short interval, in case the action fails (or returns) immediately. - + :raises ServerActionError: if an HTTP error is found during polling. :return: the completed invocation as a dictionary. """ first_time = True while invocation["status"] in ACTION_RUNNING_KEYWORDS: time.sleep(first_interval if first_time else interval) - r = client.get(invocation_href(invocation)) - r.raise_for_status() - invocation = r.json() + response = client.get(invocation_href(invocation)) + if response.is_error: + try: + message = response.json()["detail"] + except KeyError: + message = response.text + raise ServerActionError( + f"The server returned error {response.status_code} while polling " + f"action '{invocation['action']}' with id '{invocation['id']}'. " + f"The error message was:\n{message}." + ) + invocation = response.json() first_time = False return invocation diff --git a/src/labthings_fastapi/exceptions.py b/src/labthings_fastapi/exceptions.py index 472d5269..6dffe269 100644 --- a/src/labthings_fastapi/exceptions.py +++ b/src/labthings_fastapi/exceptions.py @@ -4,6 +4,8 @@ # An __all__ for this module is less than helpful, unless we have an # automated check that everything's included. +from collections.abc import Callable + class NotConnectedToServerError(RuntimeError): """The Thing is not connected to a server. @@ -254,6 +256,74 @@ class NoInvocationContextError(RuntimeError): """ +class CausedByUserCodeError(Exception): + """A mixin to allow exceptions to refer to downstream code.""" + + def _append_to_args(self, message: str) -> None: + """Add a message to the exception's arguments. + + :param message: the message to append. + """ + if len(self.args) == 1: + # If there's only one string, assume it's a message and append + self.args = (self.args[0] + "\n" + message,) + else: + # If there are multiple arguments, add this as a further one + self.args += (message,) + + def set_source_function(self, func: Callable) -> None: + """Add the location of a user-supplied function to the error message. + + :param func: the function that caused this error. + """ + code = func.__code__ + self._append_to_args( + f"This was likely caused by function '{code.co_name}' " + f"at {code.co_filename}:{code.co_firstlineno}" + ) + + def set_source_class(self, cls: type, attr: str | None = None) -> None: + """Add a reference to a class (and optionally attribute). + + :param cls: the class that caused this error. + :param attr: the attribute name that caused this error. + """ + self._append_to_args( + f"This was likely caused by '{cls.__module__}.{cls.__qualname__}.{attr}" + if attr + else "'." + ) + + +class InvalidReturnValueError(CausedByUserCodeError, RuntimeError): + r"""The return value from a method cannot be serialised by LabThings. + + This error is raised when an action returns a value that can't be serialised. + This usually means that either it doesn't match the declared return type of + the function, or the declared return type permits un-serialisable values. + + If an action's return type is missing or `Any`\ , it's possible to return a + value that can't be serialised, which will cause this error. + + The solution is usually to ensure that the return type of your action is + either a simple type that can be serialised to JSON, or a Pydantic model. + You should also check that the function's return value matches the declared + type, ideally by regularly running a type checker like `mypy` on your code. + """ + + +class UnserializableTypeError(CausedByUserCodeError, TypeError): + r"""A type has been specified that can't be serialized to JSON. + + This error generally means a property or action has a type that cannot be + serialized to JSON. This might be an instance of a custom class, or another + datatype that doesn't have a ready representation using JSON-compatible types. + + This error can often be fixed using `pydantic` annotations, or by using simple + Python types instead of custom ones. + """ + + class LogConfigurationError(RuntimeError): """There is a problem with logging configuration. diff --git a/src/labthings_fastapi/invocations.py b/src/labthings_fastapi/invocations.py index 403d1dd5..c9512a1f 100644 --- a/src/labthings_fastapi/invocations.py +++ b/src/labthings_fastapi/invocations.py @@ -11,7 +11,11 @@ from typing import Optional, Any, Sequence, TypeVar, Generic import uuid -from pydantic import BaseModel, ConfigDict, model_validator +from pydantic import ( + BaseModel, + ConfigDict, + model_validator, +) from labthings_fastapi.middleware.url_for import URLFor @@ -80,11 +84,31 @@ def generate_message(cls, data: Any) -> Any: return data +class InvocationSummary(BaseModel): + """A model to represent `.Invocation` objects over HTTP. + + This version of the model does not include logs our action outputs, and is intended + for use in endpoints that might list several invocations. + + See `GenericInvocationModel` for the full representation, to be used in + endpoints referring to one specific invocation. + """ + + status: InvocationStatus + id: uuid.UUID + action: str + href: URLFor + timeStarted: Optional[datetime] + timeRequested: Optional[datetime] + timeCompleted: Optional[datetime] + links: Links = None + + InputT = TypeVar("InputT") OutputT = TypeVar("OutputT") -class GenericInvocationModel(BaseModel, Generic[InputT, OutputT]): +class GenericInvocationModel(InvocationSummary, Generic[InputT, OutputT]): """A model to serialise `.Invocation` objects when they are polled over HTTP. The input and output models are generic parameters, to allow this model to @@ -93,17 +117,9 @@ class GenericInvocationModel(BaseModel, Generic[InputT, OutputT]): are not known in advance. """ - status: InvocationStatus - id: uuid.UUID - action: str - href: URLFor - timeStarted: Optional[datetime] - timeRequested: Optional[datetime] - timeCompleted: Optional[datetime] input: InputT output: OutputT log: Sequence[LogRecordModel] - links: Links = None InvocationModel = GenericInvocationModel[Any, Any] diff --git a/src/labthings_fastapi/properties.py b/src/labthings_fastapi/properties.py index e3de1b6b..a51295ba 100644 --- a/src/labthings_fastapi/properties.py +++ b/src/labthings_fastapi/properties.py @@ -62,7 +62,7 @@ class attribute. Documentation is in strings immediately following the from typing_extensions import Self, TypedDict from weakref import WeakSet -from fastapi import Body, FastAPI +from fastapi import Body, FastAPI, Response, HTTPException from pydantic import ( BaseModel, ConfigDict, @@ -81,9 +81,10 @@ class attribute. Documentation is in strings immediately following the PropertyOp, ) from .utilities import ( - LabThingsRootModelWrapper, + RootModelWrapper, labthings_data, - wrap_plain_types_in_rootmodel, + serialize_from_user_code, + validate_from_user_code, ) from .utilities.introspection import return_type from .base_descriptor import ( @@ -93,10 +94,12 @@ class attribute. Documentation is in strings immediately following the ) from .exceptions import ( FeatureNotAvailableError, + InvalidReturnValueError, NotConnectedToServerError, PropertyRedefinitionError, ReadOnlyPropertyError, MissingTypeError, + UnserializableTypeError, UnsupportedConstraintError, ) from .thing_class_settings import get_validate_properties_on_set @@ -464,12 +467,19 @@ def model(self) -> type[BaseModel]: subclass, this returns it unchanged. :return: a Pydantic model for the property's type. + :raises UnserializableTypeError: if the property can't be serialized + by `pydantic` to JSON. """ if self._model is None: - self._model = wrap_plain_types_in_rootmodel( - self.value_type, - constraints=self.constraints, - ) + try: + self._model = RootModelWrapper.wrap_type( + self.value_type, + constraints=self.constraints, + name=f"{self.name.title()}Value", + ) + except UnserializableTypeError as e: + e.set_source_class(self.owning_class, self.name) + raise return self._model def get_default(self, obj: Owner | None) -> Value: @@ -559,8 +569,25 @@ def set_property(body: Any) -> None: summary=self.title, description=f"## {self.title}\n\n{self.description or ''}", ) - def get_property() -> Any: - return self.__get__(thing) + def get_property() -> Response: + try: + instance = validate_from_user_code( + model=self.model, + value=self.__get__(thing), + description=f"{thing.name}.{self.name}", + code=(self.owning_class, self.name), + ) + return serialize_from_user_code( + model_instance=instance, + description=f"{thing.name}.{self.name}", + code=(self.owning_class, self.name), + ) + except InvalidReturnValueError as e: + thing.logger.error(e) + raise HTTPException( + status_code=500, + detail=str(e), + ) from e if self.is_resettable(thing): @@ -1270,7 +1297,7 @@ def validate(self, value: Any) -> Value: with its value type. This should never happen. """ try: - if issubclass(self.model, LabThingsRootModelWrapper): + if issubclass(self.model, RootModelWrapper): # If a plain type has been wrapped in a RootModel, use that to validate # and then set the property to the root value. model = self.model.model_validate(value) @@ -1283,7 +1310,7 @@ def validate(self, value: Any) -> Value: return self.value_type.model_validate(value) # This should be unreachable, because `model` is a - # `LabThingsRootModelWrapper` wrapping the value type, or the value type + # `RootModelWrapper` wrapping the value type, or the value type # should be a BaseModel. msg = f"Property {self.name} has an inconsistent model. This is " msg += f"most likely a LabThings bug. {self.model=}, {self.value_type=}" diff --git a/src/labthings_fastapi/server/__init__.py b/src/labthings_fastapi/server/__init__.py index 4b46e4d8..4e7d502b 100644 --- a/src/labthings_fastapi/server/__init__.py +++ b/src/labthings_fastapi/server/__init__.py @@ -9,6 +9,7 @@ import warnings from fastapi.testclient import TestClient from pydantic import ValidationError +from pydantic_core import PydanticSerializationError from typing import Any, AsyncGenerator, Optional, TypeVar, overload from fastapi.responses import JSONResponse from typing_extensions import Self @@ -50,6 +51,9 @@ ThingSubclass = TypeVar("ThingSubclass", bound=Thing) +LOGGER = logging.getLogger(__name__) + + class ThingServer: """Use FastAPI to serve `~lt.Thing` instances. @@ -141,7 +145,7 @@ def __init__( self._config = ThingServerConfig(**kwargs) if self._config.settings_folder is None: self._config.settings_folder = "./settings" - self.app = FastAPI(lifespan=self.lifespan) + self.app = FastAPI(lifespan=self.lifespan, separate_input_output_schemas=False) self._set_cors_middleware() self._set_url_for_middleware() self._add_exception_handlers() @@ -248,6 +252,16 @@ async def global_lock_exception_handler( content={"detail": repr(exc)}, ) + @self.app.exception_handler(PydanticSerializationError) + async def serialization_error_handler( + request: Request, exc: PydanticSerializationError + ) -> JSONResponse: + LOGGER.error( + f"Couldn't serialize response to {request.url} because of error: \n" + f"{exc}" + ) + return JSONResponse(status_code=500, content={"detail": str(exc)}) + @property def debug(self) -> bool: """Whether the server is in debug mode.""" diff --git a/src/labthings_fastapi/utilities/__init__.py b/src/labthings_fastapi/utilities/__init__.py index 51515601..91a9937e 100644 --- a/src/labthings_fastapi/utilities/__init__.py +++ b/src/labthings_fastapi/utilities/__init__.py @@ -1,13 +1,28 @@ """Utility functions used by LabThings-FastAPI.""" from __future__ import annotations -from collections.abc import Mapping +from collections.abc import Callable, Mapping +from types import FunctionType from typing import Any, Dict, Generic, Iterable, TYPE_CHECKING, Optional, TypeVar from weakref import WeakSet -from pydantic import BaseModel, ConfigDict, Field, RootModel, create_model +from pydantic import ( + BaseModel, + ConfigDict, + Field, + RootModel, + create_model, + PydanticSchemaGenerationError, + ValidationError, +) from pydantic.dataclasses import dataclass - -from labthings_fastapi.exceptions import UnsupportedConstraintError +from pydantic_core import PydanticSerializationError +from fastapi import Response + +from labthings_fastapi.exceptions import ( + InvalidReturnValueError, + UnsupportedConstraintError, + UnserializableTypeError, +) from .introspection import EmptyObject if TYPE_CHECKING: @@ -17,9 +32,9 @@ __all__ = [ "class_attributes", "attributes", + "RootModelWrapper", "LabThingsObjectData", "labthings_data", - "wrap_plain_types_in_rootmodel", "model_to_dict", ] @@ -97,7 +112,7 @@ def labthings_data(obj: Thing) -> LabThingsObjectData: WrappedT = TypeVar("WrappedT") -class LabThingsRootModelWrapper(RootModel[WrappedT], Generic[WrappedT]): +class RootModelWrapper(RootModel[WrappedT], Generic[WrappedT]): """A RootModel subclass for automatically-wrapped types. There are several places where LabThings needs a model, but may only @@ -105,52 +120,198 @@ class LabThingsRootModelWrapper(RootModel[WrappedT], Generic[WrappedT]): a type has been automatically wrapped, and will need to be unwrapped in order for the value to have the correct type. - It has no additional functionality. + It also provides methods to automatically wrap types if they are not + already `pydantic.BaseModel` subclasses, and to unwrap them again. + """ + + @classmethod + def wrap_type( + cls, + model: type, + constraints: Mapping[str, Any] | None = None, + name: str | None = None, + ) -> type[BaseModel]: + r"""Ensure a type is a subclass of BaseModel. + + If a `pydantic.BaseModel` subclass is passed to this function, we will pass it + through unchanged. Otherwise, we wrap the type in a `pydantic.RootModel`. + In the future, we may explicitly check that the argument is a type + and not a model instance. + + :param model: A Python type or `pydantic` model. + :param constraints: is passed as keyword arguments to `pydantic.Field` + to add validation constraints to the property. + :param name: the name to use for the dynamically created model. + + :return: A `pydantic` model, wrapping Python types in a ``RootModel`` if needed. + + :raises UnsupportedConstraintError: if constraints are provided for an + unsuitable type, for example `allow_inf_nan` for an `int` property, or + any constraints for a `BaseModel` subclass. + :raises UnserializableTypeError: if the type being wrapped is not able + to be serialized by `pydantic`\ . + :raises RuntimeError: if other errors prevent Pydantic from creating a schema + for the generated model. + """ + try: # This needs to be a `try` as basic types are not classes + if issubclass(model, BaseModel): + if constraints: + raise UnsupportedConstraintError( + "Constraints may only be applied to plain types, not Models." + ) + return model + except TypeError: + pass # some types aren't classes and that's OK - they still get wrapped. + constraints = constraints or {} + try: + # Dynamically create a subclass of RootModelWrapper for the supplied type. + return create_model( + f"{name or cls.__name__}[{model!r}]", + root=(model, Field(**constraints)), + __base__=cls, + ) + except PydanticSchemaGenerationError as e: + raise UnserializableTypeError( + f"LabThings does not know how to serialize {model!r} to JSON." + ) from e + except RuntimeError as e: + if "Unable to apply constraint" in str(e): + raise UnsupportedConstraintError(str(e)) from e + raise e + + @classmethod + def unwrap(cls, value: BaseModel | None) -> Any: + r"""If the supplied value is a `RootModelWrapper`, unwrap it. + + :param value: a model instance. + :return: the root value, if ``value`` is a `RootModelWrapper`\ , or ``value`` + if not. + """ + if value is None: + return None + if isinstance(value, cls): + return value.root + return value + + +def refer_to_user_code( + code: Callable | tuple[type, str] | None = None, suffix: str = "\n" +) -> str: + r"""Refer to a user-supplied function or property. + + This function generates a human-readable error string that should enable someone + to find a problem in downstream code. + + If `code` is `None` the empty string will be returned. This is intended to simplify + the construction of error messages that may or may not include a code location. + + :param code: the code that generated `value`\ . This may be either a function, + a tuple consisting of a class and an attribute name, or a string (which + should describe how to find the user code that generated the value). + :param suffix: a string that terminates the message, by default a newline. This + is not used if `code` is None, and instead the empty string is returned. + :return: a string referring to the user code, for use in an error message, or + the empty string if no user code is specified. """ + if callable(code): + if isinstance(code, FunctionType): + # As a rule, we'll pass a function and this code works. + co = code.__code__ + return ( + f"This value was returned by '{co.co_name}' " + f"at {co.co_filename}:{co.co_firstlineno}.{suffix}" + ) + else: + # As a fallback (not currently used), just dump the object to string. + return f"This value was returned by {repr(code)}.{suffix}" + elif isinstance(code, tuple) and len(code) == 2: + cls, attr = code + return ( + "You may want to check the definition of " + f"{cls.__module__}.{cls.__qualname__}.{attr}.{suffix}" + ) + else: + return "" + +ModelT = TypeVar("ModelT", bound=BaseModel) -def wrap_plain_types_in_rootmodel( - model: type, constraints: Mapping[str, Any] | None = None -) -> type[BaseModel]: - """Ensure a type is a subclass of BaseModel. - If a `pydantic.BaseModel` subclass is passed to this function, we will pass it - through unchanged. Otherwise, we wrap the type in a `pydantic.RootModel`. - In the future, we may explicitly check that the argument is a type - and not a model instance. +def validate_from_user_code( + model: type[ModelT], + value: Any, + description: str, + code: Callable | tuple[type, str] | None = None, +) -> ModelT: + r"""Validate a return value from user code, with error handling. - :param model: A Python type or `pydantic` model. - :param constraints: is passed as keyword arguments to `pydantic.Field` - to add validation constraints to the property. + This wraps ``return model.model_validate(value)`` in error handling code. + It is intended to help LabThings generate better errors when models fail + to validate, in particular making clear where in the user's code the value + was generated, and why it didn't validate. - :return: A `pydantic` model, wrapping Python types in a ``RootModel`` if needed. + :param model: the `pydantic` model to use for validation. + :param value: the value passed to ``model.model_validate()``\ . + :param description: a description of the value, e.g. "the output of {action}". + :param code: the code that generated `value`\ . + This will be passed to `refer_to_user_code` - see that function for details. - :raises UnsupportedConstraintError: if constraints are provided for an - unsuitable type, for example `allow_inf_nan` for an `int` property, or - any constraints for a `BaseModel` subclass. - :raises RuntimeError: if other errors prevent Pydantic from creating a schema - for the generated model. + :return: a validated model instance. + :raises InvalidReturnValueError: if the model failed to validate. """ - try: # This needs to be a `try` as basic types are not classes - if issubclass(model, BaseModel): - if constraints: - raise UnsupportedConstraintError( - "Constraints may only be applied to plain types, not Models." - ) - return model - except TypeError: - pass # some plain types aren't classes and that's OK - they still get wrapped. - constraints = constraints or {} try: - return create_model( - f"{model!r}", - root=(model, Field(**constraints)), - __base__=LabThingsRootModelWrapper, + return model.model_validate(value) + except ValidationError as e: + msg = ( + f"Error validating {description} against its model.\n" + f"The value was '{value}' and the model was {model}.\n" + f"{refer_to_user_code(code)}" + f"The validation error was:\n{e}\n" + ) + raise InvalidReturnValueError(msg) from e + + +def serialize_from_user_code( + model_instance: BaseModel, + description: str, + status_code: int = 200, + code: Callable | tuple[type, str] | None = None, +) -> Response: + r"""Return a value from a model instance, with appropriate error handling. + + This function implements very similar logic to FastAPI's default behaviour when + an endpoint function is typed as returning a `pydantic.BaseModel` instance. + The validated model instance is serialised to JSON by calling + ``model_dump_json()`` on the model instance, and the resulting string is returned + in a `Response` object. This uses `pydantic` serialization, written in Rust, + and outperforms the native `json` library significantly. + + If the model can't be serialized, we raise an exception with information about + the place in the user code where the problem occurred. + + :param model_instance: the `pydantic` model to use for validation. + :param description: a description of the value, e.g. "the output of {action}". + :param status_code: an HTTP status code to use. + :param code: the code that generated `value`\ . + This will be passed to `refer_to_user_code` - see that function for details. + :return: a `fastapi.Response` object containing a 200 code and the serialised + value or a 500 code and the error message. + :raises InvalidReturnValueError: if the model can't be serialised. + """ + try: + return Response( + content=model_instance.model_dump_json(), + status_code=status_code, + media_type="application/json", + ) + except PydanticSerializationError as exc: + msg = ( + f"Error serializing {description} to JSON.\n" + f"The value was validated as {repr(model_instance)}.\n" + f"The serialization error was '{exc}'.\n" + f"{refer_to_user_code(code)}" ) - except RuntimeError as e: - if "Unable to apply constraint" in str(e): - raise UnsupportedConstraintError(str(e)) from e - raise e + raise InvalidReturnValueError(msg) from exc def model_to_dict(model: Optional[BaseModel]) -> Dict[str, Any]: diff --git a/tests/test_action_manager.py b/tests/test_action_manager.py index 37242d42..4b0603b3 100644 --- a/tests/test_action_manager.py +++ b/tests/test_action_manager.py @@ -74,4 +74,7 @@ def test_actions_list(client): r2 = client.get(ACTION_INVOCATIONS_PATH) r2.raise_for_status() invocations = r2.json() + # Some keys aren't present in the list for performance/safety reasons + for k in ["input", "output", "log"]: + del invocation[k] assert invocations == [invocation] diff --git a/tests/test_actions.py b/tests/test_actions.py index fbf892ba..4798f46a 100644 --- a/tests/test_actions.py +++ b/tests/test_actions.py @@ -1,5 +1,7 @@ +from typing import Any import uuid from fastapi.testclient import TestClient +from labthings_fastapi.exceptions import FailedToInvokeActionError, ServerActionError from pydantic import BaseModel import pytest import functools @@ -333,3 +335,71 @@ def long_docstring(self) -> None: assert actions["long_docstring"].description.startswith( "It has multiple paragraphs." ) + + +def test_invalid_return_values(): + """Test the errors raised when an action's return value can't be serialised.""" + + class NaughtyThing(lt.Thing): + @lt.action + def make_random_int(self) -> int: + """An action that should return an integer, but returns a float.""" + return 4.2 + + @lt.action + def make_unjsonable_any(self) -> Any: + """A vaguely-typed action that won't serialise.""" + return object() + + server = lt.ThingServer.from_things({"naughty": NaughtyThing}) + with server.test_client() as client: + tc = lt.ThingClient.from_url("/naughty/", client=client) + + # Here, the return type doesn't match the type hint, so it should fail + # to validate. + with pytest.raises( + (ServerActionError, FailedToInvokeActionError), + match="Error validating the output of 'naughty.make_random_int'", + ): + tc.make_random_int() + + # The action should still have run, so check that we can get the + # invocation. + actions = client.get("/naughty/make_random_int/").json() + assert len(actions) == 1 + invocation = client.get(actions[0]["href"]).json() + assert invocation["output"] is None + assert invocation["status"] == "error" + first_invocation_id = invocation["id"] + assert "Error validating the output" in invocation["log"][-1]["message"] + response = client.get(invocation["links"][1]["href"]) + assert response.status_code == 503 # There's no output as it failed. + + # Here, the type hint is vague so it validates OK, but it can't + # serialize. + with pytest.raises( + (ServerActionError, FailedToInvokeActionError), + match="Error serializing invocation ", + ) as excinfo: + tc.make_unjsonable_any() + assert "make_unjsonable_any" in str(excinfo) + + # Get the last invocation + actions = client.get("/naughty/make_unjsonable_any/").json() + + # The action should still have run, so check that we can get the + # invocation. + actions = client.get("/naughty/make_unjsonable_any/").json() + assert len(actions) == 1 + second_invocation_id = actions[0]["id"] + response = client.get(actions[0]["href"]) + assert response.status_code == 500 + assert "Error serializing" in response.json()["detail"] + # Try the direct link to the action's output + response = client.get(actions[0]["links"][1]["href"]) + assert response.status_code == 500 # The output won't serialize + assert "Error serializing" in response.json()["detail"] + + # Check the overall invocations endpoint isn't broken + actions = client.get("/action_invocations/").json() + assert {a["id"] for a in actions} == {first_invocation_id, second_invocation_id} diff --git a/tests/test_properties.py b/tests/test_properties.py index a1806b12..792d99ca 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -9,8 +9,10 @@ import labthings_fastapi as lt from labthings_fastapi.exceptions import ( + ClientPropertyError, NotBoundToInstanceError, ServerNotRunningError, + UnserializableTypeError, UnsupportedConstraintError, ) from labthings_fastapi.properties import BaseProperty, PropertyInfo @@ -18,6 +20,10 @@ from .temp_client import poll_task +class Unjsonable: + """A class that pydantic can't serialize.""" + + class PropertyTestThing(lt.Thing): """A Thing with various properties for testing.""" @@ -612,3 +618,43 @@ def test_title_and_description(name, title, description): description = title # If a description is present, ignore any trailing whitespace. assert (prop.description.rstrip() if prop.description else None) == description + + +def test_bad_type(): + """Test an obviously un-serializable type raises an error.""" + + class BrokenThing(lt.Thing): + broken: Unjsonable | None = lt.property(default=None) + + with pytest.raises(UnserializableTypeError, match="BrokenThing.broken"): + _ = BrokenThing.properties["broken"].model + + +def test_bad_values(): + """Ensure bad values in properties generate sensible HTTP errors.""" + + class BrokenThing(lt.Thing): + def __init__(self, **kwargs): + super().__init__(**kwargs) + # Set bad values here because we shouldn't have invalid defaults. + self.__dict__["intprop"] = 4.2 + self.__dict__["anyprop"] = Unjsonable() + + intprop: int = lt.property(default=0) + anyprop: Any = lt.property(default=None) + + server = lt.ThingServer.from_things({"broken": BrokenThing}) + with server.test_client() as client: + tc = lt.ThingClient.from_url("/broken/", client=client) + + # The first property won't validate + with pytest.raises( + ClientPropertyError, match="Error validating broken.intprop" + ): + _ = tc.intprop + + # The second property won't serialize + with pytest.raises( + ClientPropertyError, match="Error serializing broken.anyprop" + ): + _ = tc.anyprop