diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b3d169b80..9529f09e4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,7 +21,7 @@ jobs: strategy: matrix: runs-on: ["ubuntu-latest"] # can add windows-latest, macos-latest - python-version: ["3.11"] + python-version: ["3.11", "3.12"] include: # Include one that runs in the dev environment - runs-on: "ubuntu-latest" diff --git a/pyproject.toml b/pyproject.toml index 3b70143cf..d36507ce6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ classifiers = [ "Development Status :: 3 - Alpha", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", ] description = "Control system agnostic framework for building Device support in Python that will work for both EPICS and Tango" dependencies = [ diff --git a/src/fastcs/backend.py b/src/fastcs/backend.py index a7e8980fd..2330dd26c 100644 --- a/src/fastcs/backend.py +++ b/src/fastcs/backend.py @@ -1,10 +1,12 @@ import asyncio from collections import defaultdict from collections.abc import Callable -from types import MethodType + +from fastcs.cs_methods import Command, Put, Scan from .attributes import AttrR, AttrW, Sender, Updater -from .controller import Controller, SingleMapping +from .controller import BaseController, Controller +from .controller_api import ControllerAPI from .exceptions import FastCSException @@ -14,19 +16,21 @@ def __init__( controller: Controller, loop: asyncio.AbstractEventLoop, ): - self._loop = loop self._controller = controller + self._loop = loop self._initial_coros = [controller.connect] self._scan_tasks: set[asyncio.Task] = set() - loop.run_until_complete(self._controller.initialise()) + # Initialise controller and then build its APIs + loop.run_until_complete(controller.initialise()) + self.controller_api = build_controller_api(controller) self._link_process_tasks() def _link_process_tasks(self): - for single_mapping in self._controller.get_controller_mappings(): - _link_single_controller_put_tasks(single_mapping) - _link_attribute_sender_class(single_mapping) + for controller_api in self.controller_api.walk_api(): + _link_put_tasks(controller_api) + _link_attribute_sender_class(controller_api, self._controller) def __del__(self): self._stop_scan_tasks() @@ -41,7 +45,8 @@ async def _run_initial_coros(self): async def _start_scan_tasks(self): self._scan_tasks = { - self._loop.create_task(coro()) for coro in _get_scan_coros(self._controller) + self._loop.create_task(coro()) + for coro in _get_scan_coros(self.controller_api, self._controller) } def _stop_scan_tasks(self): @@ -53,16 +58,14 @@ def _stop_scan_tasks(self): pass -def _link_single_controller_put_tasks(single_mapping: SingleMapping) -> None: - for name, method in single_mapping.put_methods.items(): +def _link_put_tasks(controller_api: ControllerAPI) -> None: + for name, method in controller_api.put_methods.items(): name = name.removeprefix("put_") - attribute = single_mapping.attributes[name] + attribute = controller_api.attributes[name] match attribute: case AttrW(): - attribute.set_process_callback( - MethodType(method.fn, single_mapping.controller) - ) + attribute.set_process_callback(method.fn) case _: raise FastCSException( f"Mode {attribute.access_mode} does not " @@ -70,15 +73,17 @@ def _link_single_controller_put_tasks(single_mapping: SingleMapping) -> None: ) -def _link_attribute_sender_class(single_mapping: SingleMapping) -> None: - for attr_name, attribute in single_mapping.attributes.items(): +def _link_attribute_sender_class( + controller_api: ControllerAPI, controller: Controller +) -> None: + for attr_name, attribute in controller_api.attributes.items(): match attribute: case AttrW(sender=Sender()): assert not attribute.has_process_callback(), ( f"Cannot assign both put method and Sender object to {attr_name}" ) - callback = _create_sender_callback(attribute, single_mapping.controller) + callback = _create_sender_callback(attribute, controller) attribute.set_process_callback(callback) @@ -89,35 +94,35 @@ async def callback(value): return callback -def _get_scan_coros(controller: Controller) -> list[Callable]: +def _get_scan_coros( + root_controller_api: ControllerAPI, controller: Controller +) -> list[Callable]: scan_dict: dict[float, list[Callable]] = defaultdict(list) - for single_mapping in controller.get_controller_mappings(): - _add_scan_method_tasks(scan_dict, single_mapping) - _add_attribute_updater_tasks(scan_dict, single_mapping) + for controller_api in root_controller_api.walk_api(): + _add_scan_method_tasks(scan_dict, controller_api) + _add_attribute_updater_tasks(scan_dict, controller_api, controller) scan_coros = _get_periodic_scan_coros(scan_dict) return scan_coros def _add_scan_method_tasks( - scan_dict: dict[float, list[Callable]], single_mapping: SingleMapping + scan_dict: dict[float, list[Callable]], controller_api: ControllerAPI ): - for method in single_mapping.scan_methods.values(): - scan_dict[method.period].append( - MethodType(method.fn, single_mapping.controller) - ) + for method in controller_api.scan_methods.values(): + scan_dict[method.period].append(method.fn) def _add_attribute_updater_tasks( - scan_dict: dict[float, list[Callable]], single_mapping: SingleMapping + scan_dict: dict[float, list[Callable]], + controller_api: ControllerAPI, + controller: Controller, ): - for attribute in single_mapping.attributes.values(): + for attribute in controller_api.attributes.values(): match attribute: case AttrR(updater=Updater(update_period=update_period)) as attribute: - callback = _create_updater_callback( - attribute, single_mapping.controller - ) + callback = _create_updater_callback(attribute, controller) if update_period is not None: scan_dict[update_period].append(callback) @@ -155,3 +160,38 @@ async def scan_coro() -> None: await asyncio.gather(*[method() for method in methods]) return scan_coro + + +def build_controller_api(controller: Controller) -> ControllerAPI: + return _build_controller_api(controller, []) + + +def _build_controller_api(controller: BaseController, path: list[str]) -> ControllerAPI: + """Build a `ControllerAPI` for a `BaseController` and its sub controllers""" + scan_methods: dict[str, Scan] = {} + put_methods: dict[str, Put] = {} + command_methods: dict[str, Command] = {} + for attr_name in dir(controller): + attr = getattr(controller, attr_name) + match attr: + case Put(enabled=True): + put_methods[attr_name] = attr + case Scan(enabled=True): + scan_methods[attr_name] = attr + case Command(enabled=True): + command_methods[attr_name] = attr + case _: + pass + + return ControllerAPI( + path=path, + attributes=controller.attributes, + command_methods=command_methods, + put_methods=put_methods, + scan_methods=scan_methods, + sub_apis={ + name: _build_controller_api(sub_controller, path + [name]) + for name, sub_controller in controller.get_sub_controllers().items() + }, + description=controller.description, + ) diff --git a/src/fastcs/controller.py b/src/fastcs/controller.py index 0648b4326..c7fbc45fa 100755 --- a/src/fastcs/controller.py +++ b/src/fastcs/controller.py @@ -1,22 +1,9 @@ from __future__ import annotations -from collections.abc import Iterator from copy import copy -from dataclasses import dataclass from typing import get_type_hints -from .attributes import Attribute -from .cs_methods import Command, Put, Scan -from .wrappers import WrappedMethod - - -@dataclass -class SingleMapping: - controller: BaseController - scan_methods: dict[str, Scan] - put_methods: dict[str, Put] - command_methods: dict[str, Command] - attributes: dict[str, Attribute] +from fastcs.attributes import Attribute class BaseController: @@ -52,9 +39,26 @@ def set_path(self, path: list[str]): self._path = path def _bind_attrs(self) -> None: + """Search for `Attributes` and `Methods` to bind them to this instance. + + This method will search the attributes of this controller class to bind them to + this specific instance. For `Attribute`s, this is just a case of copying and + re-assigning to `self` to make it unique across multiple instances of this + controller class. For `Method`s, this requires creating a bound method from a + class method and a controller instance, so that it can be called from any + context with the controller instance passed as the `self` argument. + + """ + # Lazy import to avoid circular references + from fastcs.cs_methods import UnboundCommand, UnboundPut, UnboundScan + # Using a dictionary instead of a set to maintain order. - class_dir = {key: None for key in dir(type(self))} - class_type_hints = get_type_hints(type(self)) + class_dir = {key: None for key in dir(type(self)) if not key.startswith("_")} + class_type_hints = { + key: value + for key, value in get_type_hints(type(self)).items() + if not key.startswith("_") + } for attr_name in {**class_dir, **class_type_hints}: if attr_name == "root_attribute": @@ -73,6 +77,8 @@ def _bind_attrs(self) -> None: new_attribute = copy(attr) setattr(self, attr_name, new_attribute) self.attributes[attr_name] = new_attribute + elif isinstance(attr, UnboundPut | UnboundScan | UnboundCommand): + setattr(self, attr_name, attr.bind(self)) def register_sub_controller(self, name: str, sub_controller: SubController): if name in self.__sub_controller_tree.keys(): @@ -95,40 +101,6 @@ def register_sub_controller(self, name: str, sub_controller: SubController): def get_sub_controllers(self) -> dict[str, SubController]: return self.__sub_controller_tree - def get_controller_mappings(self) -> list[SingleMapping]: - return list(_walk_mappings(self)) - - -def _walk_mappings(controller: BaseController) -> Iterator[SingleMapping]: - yield _get_single_mapping(controller) - for sub_controller in controller.get_sub_controllers().values(): - yield from _walk_mappings(sub_controller) - - -def _get_single_mapping(controller: BaseController) -> SingleMapping: - scan_methods: dict[str, Scan] = {} - put_methods: dict[str, Put] = {} - command_methods: dict[str, Command] = {} - for attr_name in dir(controller): - attr = getattr(controller, attr_name) - match attr: - case WrappedMethod(fastcs_method=Put(enabled=True) as put_method): - put_methods[attr_name] = put_method - case WrappedMethod(fastcs_method=Scan(enabled=True) as scan_method): - scan_methods[attr_name] = scan_method - case WrappedMethod(fastcs_method=Command(enabled=True) as command_method): - command_methods[attr_name] = command_method - - enabled_attributes = { - name: attribute - for name, attribute in controller.attributes.items() - if attribute.enabled - } - - return SingleMapping( - controller, scan_methods, put_methods, command_methods, enabled_attributes - ) - class Controller(BaseController): """Top-level controller for a device. diff --git a/src/fastcs/controller_api.py b/src/fastcs/controller_api.py new file mode 100644 index 000000000..a9afa47f8 --- /dev/null +++ b/src/fastcs/controller_api.py @@ -0,0 +1,31 @@ +from collections.abc import Iterator +from dataclasses import dataclass, field + +from fastcs.attributes import Attribute +from fastcs.cs_methods import Command, Put, Scan + + +@dataclass +class ControllerAPI: + """Attributes, bound methods and sub APIs of a `Controller` / `SubController`""" + + path: list[str] = field(default_factory=list) + """Path within controller tree (empty if this is the root)""" + attributes: dict[str, Attribute] = field(default_factory=dict) + command_methods: dict[str, Command] = field(default_factory=dict) + put_methods: dict[str, Put] = field(default_factory=dict) + scan_methods: dict[str, Scan] = field(default_factory=dict) + sub_apis: dict[str, "ControllerAPI"] = field(default_factory=dict) + """APIs of the sub controllers of the `Controller` this API was built from""" + description: str | None = None + + def walk_api(self) -> Iterator["ControllerAPI"]: + """Walk through all the nested `ControllerAPIs` of this `ControllerAPI` + + yields: `ControllerAPI`s from a depth-first traversal of the tree, including + self. + + """ + yield self + for api in self.sub_apis.values(): + yield from api.walk_api() diff --git a/src/fastcs/cs_methods.py b/src/fastcs/cs_methods.py index 0e861f708..fbab7e1a5 100644 --- a/src/fastcs/cs_methods.py +++ b/src/fastcs/cs_methods.py @@ -1,14 +1,40 @@ from asyncio import iscoroutinefunction -from collections.abc import Awaitable, Callable +from collections.abc import Callable, Coroutine from inspect import Signature, getdoc, signature +from types import MethodType +from typing import Any, Generic, TypeVar -from .exceptions import FastCSException - -ScanCallback = Callable[..., Awaitable[None]] +from fastcs.controller import BaseController +from .exceptions import FastCSException -class Method: - def __init__(self, fn: Callable, *, group: str | None = None) -> None: +MethodCallback = Callable[..., Coroutine[None, None, None]] +"""Generic base class for all `Controller` methods""" +Controller_T = TypeVar("Controller_T", bound=BaseController) +"""Generic `Controller` class that an unbound method must be called with as `self`""" +UnboundCommandCallback = Callable[[Controller_T], Coroutine[None, None, None]] +"""A Command callback that is unbound and must be called with a `Controller` instance""" +UnboundScanCallback = Callable[[Controller_T], Coroutine[None, None, None]] +"""A Scan callback that is unbound and must be called with a `Controller` instance""" +UnboundPutCallback = Callable[[Controller_T, Any], Coroutine[None, None, None]] +"""A Put callback that is unbound and must be called with a `Controller` instance""" +CommandCallback = Callable[[], Coroutine[None, None, None]] +"""A Command callback that is bound and can be called without `self`""" +ScanCallback = Callable[[], Coroutine[None, None, None]] +"""A Scan callback that is bound and can be called withous `self`""" +PutCallback = Callable[[Any], Coroutine[None, None, None]] +"""A Put callback that is bound and can be called without `self`""" + + +method_not_bound_error = NotImplementedError( + "Method must be bound to a controller instance to be callable" +) + + +class Method(Generic[Controller_T]): + """Generic base class for all FastCS Controller methods.""" + + def __init__(self, fn: MethodCallback, *, group: str | None = None) -> None: self._docstring = getdoc(fn) sig = signature(fn, eval_str=True) @@ -20,7 +46,7 @@ def __init__(self, fn: Callable, *, group: str | None = None) -> None: self._group = group self.enabled = True - def _validate(self, fn: Callable) -> None: + def _validate(self, fn: MethodCallback) -> None: if self.return_type not in (None, Signature.empty): raise FastCSException("Method return type must be None or empty") @@ -48,40 +74,142 @@ def group(self): return self._group -class Scan(Method): - def __init__(self, fn: Callable, period) -> None: - super().__init__(fn) +class Command(Method[BaseController]): + """A `Controller` `Method` that performs a single action when called. - self._period = period + This class contains a function that is bound to a specific `Controller` instance and + is callable outside of the class context, without an explicit `self` parameter. + Calling an instance of this class will call the bound `Controller` method. + """ + + def __init__(self, fn: CommandCallback, *, group: str | None = None): + super().__init__(fn, group=group) - def _validate(self, fn: Callable) -> None: + def _validate(self, fn: CommandCallback) -> None: super()._validate(fn) - if not len(self.parameters) == 1: - raise FastCSException("Scan method cannot have arguments") + if not len(self.parameters) == 0: + raise FastCSException(f"Command method cannot have arguments: {fn}") + + async def __call__(self): + return await self._fn() + + +class Scan(Method[BaseController]): + """A `Controller` `Method` that will be called periodically in the background. + + This class contains a function that is bound to a specific `Controller` instance and + is callable outside of the class context, without an explicit `self` parameter. + Calling an instance of this class will call the bound `Controller` method. + """ + + def __init__(self, fn: ScanCallback, period: float): + super().__init__(fn) + + self._period = period @property def period(self): return self._period + def _validate(self, fn: ScanCallback) -> None: + super()._validate(fn) + + if not len(self.parameters) == 0: + raise FastCSException("Scan method cannot have arguments") + + async def __call__(self): + return await self._fn() + -class Put(Method): - def __init__(self, fn: Callable) -> None: +class Put(Method[BaseController]): + def __init__(self, fn: PutCallback): super().__init__(fn) - def _validate(self, fn: Callable) -> None: + def _validate(self, fn: PutCallback) -> None: super()._validate(fn) - if not len(self.parameters) == 2: + if not len(self.parameters) == 1: raise FastCSException("Put method can only take one argument") + async def __call__(self, value: Any): + return await self._fn(value) + + +class UnboundCommand(Method[Controller_T]): + """A wrapper of an unbound `Controller` method to be bound into a `Command`. -class Command(Method): - def __init__(self, fn: Callable, *, group: str | None = None) -> None: + This generic class stores an unbound `Controller` method - effectively a function + that takes an instance of a specific `Controller` type (`Controller_T`). Instances + of this class can be added at `Controller` definition, either manually or with use + of the `@command` wrapper, to register the method to be included in the API of the + `Controller`. When the `Controller` is instantiated, these instances will be bound + to the instance, creating a `Command` instance. + """ + + def __init__( + self, fn: UnboundCommandCallback[Controller_T], *, group: str | None = None + ) -> None: super().__init__(fn, group=group) - def _validate(self, fn: Callable) -> None: + def _validate(self, fn: UnboundCommandCallback[Controller_T]) -> None: super()._validate(fn) if not len(self.parameters) == 1: raise FastCSException("Command method cannot have arguments") + + def bind(self, controller: Controller_T) -> Command: + return Command(MethodType(self.fn, controller)) + + def __call__(self): + raise method_not_bound_error + + +class UnboundScan(Method[Controller_T]): + """A wrapper of an unbound `Controller` method to be bound into a `Scan`. + + This generic class stores an unbound `Controller` method - effectively a function + that takes an instance of a specific `Controller` type (`Controller_T`). Instances + of this class can be added at `Controller` definition, either manually or with use + of the `@scan` wrapper, to register the method to be included in the API of the + `Controller`. When the `Controller` is instantiated, these instances will be bound + to the instance, creating a `Scan` instance. + """ + + def __init__(self, fn: UnboundScanCallback[Controller_T], period: float) -> None: + super().__init__(fn) + + self._period = period + + @property + def period(self): + return self._period + + def _validate(self, fn: UnboundScanCallback[Controller_T]) -> None: + super()._validate(fn) + + if not len(self.parameters) == 1: + raise FastCSException("Scan method cannot have arguments") + + def bind(self, controller: Controller_T) -> Scan: + return Scan(MethodType(self.fn, controller), self._period) + + def __call__(self): + raise method_not_bound_error + + +class UnboundPut(Method[Controller_T]): + def __init__(self, fn: UnboundPutCallback[Controller_T]) -> None: + super().__init__(fn) + + def _validate(self, fn: UnboundPutCallback[Controller_T]) -> None: + super()._validate(fn) + + if not len(self.parameters) == 2: + raise FastCSException("Put method can only take one argument") + + def bind(self, controller: Controller_T) -> Put: + return Put(MethodType(self.fn, controller)) + + def __call__(self): + raise method_not_bound_error diff --git a/src/fastcs/launch.py b/src/fastcs/launch.py index 25b387e31..3f4451115 100644 --- a/src/fastcs/launch.py +++ b/src/fastcs/launch.py @@ -43,14 +43,14 @@ def __init__( from .transport.epics.pva.adapter import EpicsPVATransport transport = EpicsPVATransport( - controller, + self._backend.controller_api, option, ) case EpicsCAOptions(): from .transport.epics.ca.adapter import EpicsCATransport transport = EpicsCATransport( - controller, + self._backend.controller_api, self._loop, option, ) @@ -58,7 +58,7 @@ def __init__( from .transport.tango.adapter import TangoTransport transport = TangoTransport( - controller, + self._backend.controller_api, self._loop, option, ) @@ -66,14 +66,14 @@ def __init__( from .transport.rest.adapter import RestTransport transport = RestTransport( - controller, + self._backend.controller_api, option, ) case GraphQLOptions(): from .transport.graphQL.adapter import GraphQLTransport transport = GraphQLTransport( - controller, + self._backend.controller_api, option, ) diff --git a/src/fastcs/transport/epics/ca/adapter.py b/src/fastcs/transport/epics/ca/adapter.py index ec490bc7a..00b8c97a5 100644 --- a/src/fastcs/transport/epics/ca/adapter.py +++ b/src/fastcs/transport/epics/ca/adapter.py @@ -1,6 +1,6 @@ import asyncio -from fastcs.controller import Controller +from fastcs.controller_api import ControllerAPI from fastcs.transport.adapter import TransportAdapter from fastcs.transport.epics.ca.ioc import EpicsCAIOC from fastcs.transport.epics.ca.options import EpicsCAOptions @@ -11,17 +11,17 @@ class EpicsCATransport(TransportAdapter): def __init__( self, - controller: Controller, + controller_api: ControllerAPI, loop: asyncio.AbstractEventLoop, options: EpicsCAOptions | None = None, ) -> None: - self._controller = controller + self._controller_api = controller_api self._loop = loop self._options = options or EpicsCAOptions() self._pv_prefix = self.options.ioc.pv_prefix self._ioc = EpicsCAIOC( self.options.ioc.pv_prefix, - controller, + controller_api, self._options.ioc, ) @@ -30,10 +30,10 @@ def options(self) -> EpicsCAOptions: return self._options def create_docs(self) -> None: - EpicsDocs(self._controller).create_docs(self.options.docs) + EpicsDocs(self._controller_api).create_docs(self.options.docs) def create_gui(self) -> None: - EpicsGUI(self._controller, self._pv_prefix).create_gui(self.options.gui) + EpicsGUI(self._controller_api, self._pv_prefix).create_gui(self.options.gui) async def serve(self) -> None: print(f"Running FastCS IOC: {self._pv_prefix}") diff --git a/src/fastcs/transport/epics/ca/ioc.py b/src/fastcs/transport/epics/ca/ioc.py index 8dc1d891a..86247549f 100644 --- a/src/fastcs/transport/epics/ca/ioc.py +++ b/src/fastcs/transport/epics/ca/ioc.py @@ -1,6 +1,5 @@ import asyncio from collections.abc import Callable -from types import MethodType from typing import Any, Literal from softioc import builder, softioc @@ -8,7 +7,7 @@ from softioc.pythonSoftIoc import RecordWrapper from fastcs.attributes import AttrR, AttrRW, AttrW -from fastcs.controller import BaseController, Controller +from fastcs.controller_api import ControllerAPI from fastcs.datatypes import DataType, T from fastcs.transport.epics.ca.util import ( builder_callable_from_attribute, @@ -26,16 +25,16 @@ class EpicsCAIOC: def __init__( self, pv_prefix: str, - controller: Controller, + controller_api: ControllerAPI, options: EpicsIOCOptions | None = None, ): self._options = options or EpicsIOCOptions() - self._controller = controller + self._controller_api = controller_api _add_pvi_info(f"{pv_prefix}:PVI") - _add_sub_controller_pvi_info(pv_prefix, controller) + _add_sub_controller_pvi_info(pv_prefix, controller_api) - _create_and_link_attribute_pvs(pv_prefix, controller) - _create_and_link_command_pvs(pv_prefix, controller) + _create_and_link_attribute_pvs(pv_prefix, controller_api) + _create_and_link_command_pvs(pv_prefix, controller_api) def run( self, @@ -91,7 +90,7 @@ def _add_pvi_info( record.add_info("Q:group", q_group) -def _add_sub_controller_pvi_info(pv_prefix: str, parent: BaseController): +def _add_sub_controller_pvi_info(pv_prefix: str, parent: ControllerAPI): """Add PVI references from controller to its sub controllers, recursively. Args: @@ -101,7 +100,7 @@ def _add_sub_controller_pvi_info(pv_prefix: str, parent: BaseController): """ parent_pvi = ":".join([pv_prefix] + parent.path + ["PVI"]) - for child in parent.get_sub_controllers().values(): + for child in parent.sub_apis.values(): child_pvi = ":".join([pv_prefix] + child.path + ["PVI"]) child_name = child.path[-1].lower() @@ -110,10 +109,12 @@ def _add_sub_controller_pvi_info(pv_prefix: str, parent: BaseController): _add_sub_controller_pvi_info(pv_prefix, child) -def _create_and_link_attribute_pvs(pv_prefix: str, controller: Controller) -> None: - for single_mapping in controller.get_controller_mappings(): - path = single_mapping.controller.path - for attr_name, attribute in single_mapping.attributes.items(): +def _create_and_link_attribute_pvs( + pv_prefix: str, root_controller_api: ControllerAPI +) -> None: + for controller_api in root_controller_api.walk_api(): + path = controller_api.path + for attr_name, attribute in controller_api.attributes.items(): pv_name = attr_name.title().replace("_", "") _pv_prefix = ":".join([pv_prefix] + path) full_pv_name_length = len(f"{_pv_prefix}:{pv_name}") @@ -122,7 +123,7 @@ def _create_and_link_attribute_pvs(pv_prefix: str, controller: Controller) -> No attribute.enabled = False print( f"Not creating PV for {attr_name} for controller" - f" {single_mapping.controller.path} as full name would exceed" + f" {controller_api.path} as full name would exceed" f" {EPICS_MAX_NAME_LENGTH} characters" ) continue @@ -202,10 +203,12 @@ async def async_write_display(value: T): attribute.set_write_display_callback(async_write_display) -def _create_and_link_command_pvs(pv_prefix: str, controller: Controller) -> None: - for single_mapping in controller.get_controller_mappings(): - path = single_mapping.controller.path - for attr_name, method in single_mapping.command_methods.items(): +def _create_and_link_command_pvs( + pv_prefix: str, root_controller_api: ControllerAPI +) -> None: + for controller_api in root_controller_api.walk_api(): + path = controller_api.path + for attr_name, method in controller_api.command_methods.items(): pv_name = attr_name.title().replace("_", "") _pv_prefix = ":".join([pv_prefix] + path) if len(f"{_pv_prefix}:{pv_name}") > EPICS_MAX_NAME_LENGTH: @@ -219,7 +222,7 @@ def _create_and_link_command_pvs(pv_prefix: str, controller: Controller) -> None _pv_prefix, pv_name, attr_name, - MethodType(method.fn, single_mapping.controller), + method.fn, ) diff --git a/src/fastcs/transport/epics/docs.py b/src/fastcs/transport/epics/docs.py index bec5469d2..1a087d159 100644 --- a/src/fastcs/transport/epics/docs.py +++ b/src/fastcs/transport/epics/docs.py @@ -1,11 +1,11 @@ -from fastcs.controller import Controller +from fastcs.controller_api import ControllerAPI from .options import EpicsDocsOptions class EpicsDocs: - def __init__(self, controller: Controller) -> None: - self._controller = controller + def __init__(self, controller_apis: ControllerAPI) -> None: + self._controller_apis = controller_apis def create_docs(self, options: EpicsDocsOptions | None = None) -> None: if options is None: diff --git a/src/fastcs/transport/epics/gui.py b/src/fastcs/transport/epics/gui.py index 4fb2a4099..5c44e5c10 100644 --- a/src/fastcs/transport/epics/gui.py +++ b/src/fastcs/transport/epics/gui.py @@ -23,7 +23,7 @@ from pydantic import ValidationError from fastcs.attributes import Attribute, AttrR, AttrRW, AttrW -from fastcs.controller import Controller, SingleMapping, _get_single_mapping +from fastcs.controller_api import ControllerAPI from fastcs.cs_methods import Command from fastcs.datatypes import Bool, Enum, Float, Int, String, Waveform from fastcs.exceptions import FastCSException @@ -33,8 +33,8 @@ class EpicsGUI: - def __init__(self, controller: Controller, pv_prefix: str) -> None: - self._controller = controller + def __init__(self, controller_api: ControllerAPI, pv_prefix: str) -> None: + self._controller_api = controller_api self._pv_prefix = pv_prefix def _get_pv(self, attr_path: list[str], name: str): @@ -128,33 +128,29 @@ def create_gui(self, options: EpicsGUIOptions | None = None) -> None: assert options.output_path.suffix == options.file_format.value options.output_path.parent.mkdir(parents=True, exist_ok=True) - controller_mapping = self._controller.get_controller_mappings()[0] - components = self.extract_mapping_components(controller_mapping) + components = self.extract_api_components(self._controller_api) device = Device(label=options.title, children=components) formatter = DLSFormatter() formatter.format(device, options.output_path.resolve()) - def extract_mapping_components(self, mapping: SingleMapping) -> Tree: + def extract_api_components(self, controller_api: ControllerAPI) -> Tree: components: Tree = [] - attr_path = mapping.controller.path - for name, sub_controller in mapping.controller.get_sub_controllers().items(): + for name, api in controller_api.sub_apis.items(): components.append( Group( name=snake_to_pascal(name), layout=SubScreen(), - children=self.extract_mapping_components( - _get_single_mapping(sub_controller) - ), + children=self.extract_api_components(api), ) ) groups: dict[str, list[ComponentUnion]] = {} - for attr_name, attribute in mapping.attributes.items(): + for attr_name, attribute in controller_api.attributes.items(): try: signal = self._get_attribute_component( - attr_path, + controller_api.path, attr_name, attribute, ) @@ -177,8 +173,8 @@ def extract_mapping_components(self, mapping: SingleMapping) -> Tree: case _: components.append(signal) - for name, command in mapping.command_methods.items(): - signal = self._get_command_component(attr_path, name) + for name, command in controller_api.command_methods.items(): + signal = self._get_command_component(controller_api.path, name) match command: case Command(group=group) if group is not None: diff --git a/src/fastcs/transport/epics/pva/_pv_handlers.py b/src/fastcs/transport/epics/pva/_pv_handlers.py index 141942e63..0e8195542 100644 --- a/src/fastcs/transport/epics/pva/_pv_handlers.py +++ b/src/fastcs/transport/epics/pva/_pv_handlers.py @@ -1,5 +1,3 @@ -from collections.abc import Callable - import numpy as np from p4p import Value from p4p.nt import NTEnum, NTNDArray, NTScalar, NTTable @@ -9,6 +7,7 @@ from p4p.server.asyncio import SharedPV from fastcs.attributes import Attribute, AttrR, AttrRW, AttrW +from fastcs.cs_methods import CommandCallback from fastcs.datatypes import Table from .types import ( @@ -52,7 +51,7 @@ async def put(self, pv: SharedPV, op: ServerOperation): class CommandPvHandler: - def __init__(self, command: Callable): + def __init__(self, command: CommandCallback): self._command = command self._task_in_progress = False @@ -125,7 +124,7 @@ async def on_update(value): return shared_pv -def make_command_pv(command: Callable) -> SharedPV: +def make_command_pv(command: CommandCallback) -> SharedPV: type_ = NTScalar.buildType("?", display=True, control=True) initial = Value(type_, {"value": False, **p4p_alarm_states()}) diff --git a/src/fastcs/transport/epics/pva/adapter.py b/src/fastcs/transport/epics/pva/adapter.py index c50ab0175..b3d34f3f1 100644 --- a/src/fastcs/transport/epics/pva/adapter.py +++ b/src/fastcs/transport/epics/pva/adapter.py @@ -1,4 +1,4 @@ -from fastcs.controller import Controller +from fastcs.controller_api import ControllerAPI from fastcs.transport.adapter import TransportAdapter from fastcs.transport.epics.docs import EpicsDocs from fastcs.transport.epics.gui import EpicsGUI @@ -10,13 +10,13 @@ class EpicsPVATransport(TransportAdapter): def __init__( self, - controller: Controller, + controller_api: ControllerAPI, options: EpicsPVAOptions | None = None, ) -> None: - self._controller = controller + self._controller_api = controller_api self._options = options or EpicsPVAOptions() self._pv_prefix = self.options.ioc.pv_prefix - self._ioc = P4PIOC(self.options.ioc.pv_prefix, controller) + self._ioc = P4PIOC(self.options.ioc.pv_prefix, controller_api) @property def options(self) -> EpicsPVAOptions: @@ -27,7 +27,7 @@ async def serve(self) -> None: await self._ioc.run() def create_docs(self) -> None: - EpicsDocs(self._controller).create_docs(self.options.docs) + EpicsDocs(self._controller_api).create_docs(self.options.docs) def create_gui(self) -> None: - EpicsGUI(self._controller, self._pv_prefix).create_gui(self.options.gui) + EpicsGUI(self._controller_api, self._pv_prefix).create_gui(self.options.gui) diff --git a/src/fastcs/transport/epics/pva/ioc.py b/src/fastcs/transport/epics/pva/ioc.py index 5c26fbcf4..18655211a 100644 --- a/src/fastcs/transport/epics/pva/ioc.py +++ b/src/fastcs/transport/epics/pva/ioc.py @@ -1,11 +1,10 @@ import asyncio import re -from types import MethodType from p4p.server import Server, StaticProvider from fastcs.attributes import Attribute, AttrR, AttrRW, AttrW -from fastcs.controller import Controller +from fastcs.controller_api import ControllerAPI from ._pv_handlers import make_command_pv, make_shared_pv from .pvi_tree import AccessModeType, PviTree @@ -36,31 +35,25 @@ def get_pv_name(pv_prefix: str, *attribute_names: str) -> str: async def parse_attributes( - root_pv_prefix: str, controller: Controller + root_pv_prefix: str, root_controller_api: ControllerAPI ) -> list[StaticProvider]: pvi_tree = PviTree(root_pv_prefix) provider = StaticProvider(root_pv_prefix) - for single_mapping in controller.get_controller_mappings(): - path = single_mapping.controller.path - pv_prefix = get_pv_name(root_pv_prefix, *path) + for controller_api in root_controller_api.walk_api(): + pv_prefix = get_pv_name(root_pv_prefix, *controller_api.path) - pvi_tree.add_sub_device( - pv_prefix, - single_mapping.controller.description, - ) + pvi_tree.add_sub_device(pv_prefix, controller_api.description) - for attr_name, attribute in single_mapping.attributes.items(): + for attr_name, attribute in controller_api.attributes.items(): pv_name = get_pv_name(pv_prefix, attr_name) attribute_pv = make_shared_pv(attribute) provider.add(pv_name, attribute_pv) pvi_tree.add_signal(pv_name, _attribute_to_access(attribute)) - for attr_name, method in single_mapping.command_methods.items(): + for attr_name, method in controller_api.command_methods.items(): pv_name = get_pv_name(pv_prefix, attr_name) - command_pv = make_command_pv( - MethodType(method.fn, single_mapping.controller) - ) + command_pv = make_command_pv(method.fn) provider.add(pv_name, command_pv) pvi_tree.add_signal(pv_name, "x") @@ -68,16 +61,12 @@ async def parse_attributes( class P4PIOC: - def __init__( - self, - pv_prefix: str, - controller: Controller, - ): + def __init__(self, pv_prefix: str, controller_api: ControllerAPI): self.pv_prefix = pv_prefix - self.controller = controller + self.controller_api = controller_api async def run(self): - providers = await parse_attributes(self.pv_prefix, self.controller) + providers = await parse_attributes(self.pv_prefix, self.controller_api) endless_event = asyncio.Event() with Server(providers): diff --git a/src/fastcs/transport/graphQL/adapter.py b/src/fastcs/transport/graphQL/adapter.py index 2c7eaec7c..79853f2f7 100644 --- a/src/fastcs/transport/graphQL/adapter.py +++ b/src/fastcs/transport/graphQL/adapter.py @@ -1,4 +1,4 @@ -from fastcs.controller import Controller +from fastcs.controller_api import ControllerAPI from fastcs.transport.adapter import TransportAdapter from .graphQL import GraphQLServer @@ -8,11 +8,11 @@ class GraphQLTransport(TransportAdapter): def __init__( self, - controller: Controller, + controller_api: ControllerAPI, options: GraphQLOptions | None = None, ): self._options = options or GraphQLOptions() - self._server = GraphQLServer(controller) + self._server = GraphQLServer(controller_api) @property def options(self) -> GraphQLOptions: diff --git a/src/fastcs/transport/graphQL/graphQL.py b/src/fastcs/transport/graphQL/graphQL.py index 106edc47d..1be19d3b3 100644 --- a/src/fastcs/transport/graphQL/graphQL.py +++ b/src/fastcs/transport/graphQL/graphQL.py @@ -8,24 +8,19 @@ from strawberry.types.field import StrawberryField from fastcs.attributes import AttrR, AttrRW, AttrW, T -from fastcs.controller import ( - BaseController, - Controller, - SingleMapping, - _get_single_mapping, -) +from fastcs.controller_api import ControllerAPI from fastcs.exceptions import FastCSException from .options import GraphQLServerOptions class GraphQLServer: - def __init__(self, controller: Controller): - self._controller = controller + def __init__(self, controller_api: ControllerAPI): + self._controller_api = controller_api self._app = self._create_app() def _create_app(self) -> GraphQL: - api = GraphQLAPI(self._controller) + api = GraphQLAPI(self._controller_api) schema = api.create_schema() app = GraphQL(schema) @@ -45,19 +40,17 @@ async def serve(self, options: GraphQLServerOptions | None = None) -> None: class GraphQLAPI: - """A Strawberry API built dynamically from a Controller""" + """A Strawberry API built dynamically from a `ControllerAPI`""" - def __init__(self, controller: BaseController): + def __init__(self, controller_api: ControllerAPI): self.queries: list[StrawberryField] = [] self.mutations: list[StrawberryField] = [] - api = _get_single_mapping(controller) + self._process_attributes(controller_api) + self._process_commands(controller_api) + self._process_sub_apis(controller_api) - self._process_attributes(api) - self._process_commands(api) - self._process_sub_controllers(api) - - def _process_attributes(self, api: SingleMapping): + def _process_attributes(self, api: ControllerAPI): """Create queries and mutations from api attributes.""" for attr_name, attribute in api.attributes.items(): match attribute: @@ -78,18 +71,16 @@ def _process_attributes(self, api: SingleMapping): strawberry.mutation(_wrap_attr_set(attr_name, attribute)) ) - def _process_commands(self, api: SingleMapping): + def _process_commands(self, controller_api: ControllerAPI): """Create mutations from api commands""" - for cmd_name, method in api.command_methods.items(): - self.mutations.append( - strawberry.mutation(_wrap_command(cmd_name, method.fn, api.controller)) - ) - - def _process_sub_controllers(self, api: SingleMapping): - """Recursively add fields from the queries and mutations of sub controllers""" - for sub_controller in api.controller.get_sub_controllers().values(): - name = "".join(sub_controller.path) - child_tree = GraphQLAPI(sub_controller) + for name, method in controller_api.command_methods.items(): + self.mutations.append(strawberry.mutation(_wrap_command(name, method.fn))) + + def _process_sub_apis(self, root_controller_api: ControllerAPI): + """Recursively add fields from the queries and mutations of sub apis""" + for controller_api in root_controller_api.sub_apis.values(): + name = "".join(controller_api.path) + child_tree = GraphQLAPI(controller_api) if child_tree.queries: self.queries.append( _wrap_as_field( @@ -107,7 +98,8 @@ def create_schema(self) -> strawberry.Schema: """Create a Strawberry Schema to load into a GraphQL application.""" if not self.queries: raise FastCSException( - "Can't create GraphQL transport from Controller with no read attributes" + "Can't create GraphQL transport from ControllerAPI with no read " + "attributes" ) query = create_type("Query", self.queries) @@ -159,13 +151,11 @@ def _dynamic_field(): return strawberry.field(_dynamic_field) -def _wrap_command( - method_name: str, method: Callable, controller: BaseController -) -> Callable[..., Awaitable[bool]]: +def _wrap_command(method_name: str, method: Callable) -> Callable[..., Awaitable[bool]]: """Wrap a command in a function with annotations for strawberry""" async def _dynamic_f() -> bool: - await getattr(controller, method.__name__)() + await method() return True _dynamic_f.__name__ = method_name diff --git a/src/fastcs/transport/rest/adapter.py b/src/fastcs/transport/rest/adapter.py index 7e98c5da9..cac64ae49 100644 --- a/src/fastcs/transport/rest/adapter.py +++ b/src/fastcs/transport/rest/adapter.py @@ -1,4 +1,4 @@ -from fastcs.controller import Controller +from fastcs.controller_api import ControllerAPI from fastcs.transport.adapter import TransportAdapter from .options import RestOptions @@ -8,11 +8,11 @@ class RestTransport(TransportAdapter): def __init__( self, - controller: Controller, + controller_api: ControllerAPI, options: RestOptions | None = None, ): self._options = options or RestOptions() - self._server = RestServer(controller) + self._server = RestServer(controller_api) @property def options(self) -> RestOptions: diff --git a/src/fastcs/transport/rest/rest.py b/src/fastcs/transport/rest/rest.py index 10b36a392..5fcadc238 100644 --- a/src/fastcs/transport/rest/rest.py +++ b/src/fastcs/transport/rest/rest.py @@ -1,4 +1,4 @@ -from collections.abc import Awaitable, Callable, Coroutine +from collections.abc import Callable, Coroutine from typing import Any import uvicorn @@ -6,7 +6,8 @@ from pydantic import create_model from fastcs.attributes import AttrR, AttrRW, AttrW, T -from fastcs.controller import BaseController, Controller +from fastcs.controller_api import ControllerAPI +from fastcs.cs_methods import CommandCallback from .options import RestServerOptions from .util import ( @@ -17,14 +18,14 @@ class RestServer: - def __init__(self, controller: Controller): - self._controller = controller + def __init__(self, controller_api: ControllerAPI): + self._controller_api = controller_api self._app = self._create_app() def _create_app(self): app = FastAPI() - _add_attribute_api_routes(app, self._controller) - _add_command_api_routes(app, self._controller) + _add_attribute_api_routes(app, self._controller_api) + _add_command_api_routes(app, self._controller_api) return app @@ -91,11 +92,11 @@ async def attr_get() -> Any: # Must be any as response_model is set return attr_get -def _add_attribute_api_routes(app: FastAPI, controller: Controller) -> None: - for single_mapping in controller.get_controller_mappings(): - path = single_mapping.controller.path +def _add_attribute_api_routes(app: FastAPI, root_controller_api: ControllerAPI) -> None: + for controller_api in root_controller_api.walk_api(): + path = controller_api.path - for attr_name, attribute in single_mapping.attributes.items(): + for attr_name, attribute in controller_api.attributes.items(): attr_name = attr_name.replace("_", "-") route = f"{'/'.join(path)}/{attr_name}" if path else attr_name @@ -133,27 +134,24 @@ def _add_attribute_api_routes(app: FastAPI, controller: Controller) -> None: def _wrap_command( - method: Callable, controller: BaseController -) -> Callable[..., Awaitable[None]]: + method: CommandCallback, +) -> Callable[..., Coroutine[None, None, None]]: async def command() -> None: - await getattr(controller, method.__name__)() + await method() return command -def _add_command_api_routes(app: FastAPI, controller: Controller) -> None: - for single_mapping in controller.get_controller_mappings(): - path = single_mapping.controller.path +def _add_command_api_routes(app: FastAPI, root_controller_api: ControllerAPI) -> None: + for controller_api in root_controller_api.walk_api(): + path = controller_api.path - for name, method in single_mapping.command_methods.items(): + for name, method in root_controller_api.command_methods.items(): cmd_name = name.replace("_", "-") route = f"/{'/'.join(path)}/{cmd_name}" if path else cmd_name app.add_api_route( f"/{route}", - _wrap_command( - method.fn, - single_mapping.controller, - ), + _wrap_command(method.fn), methods=["PUT"], status_code=204, ) diff --git a/src/fastcs/transport/tango/adapter.py b/src/fastcs/transport/tango/adapter.py index 40018694b..96938f7b9 100644 --- a/src/fastcs/transport/tango/adapter.py +++ b/src/fastcs/transport/tango/adapter.py @@ -1,6 +1,6 @@ import asyncio -from fastcs.controller import Controller +from fastcs.controller_api import ControllerAPI from fastcs.transport.adapter import TransportAdapter from .dsr import TangoDSR @@ -10,12 +10,12 @@ class TangoTransport(TransportAdapter): def __init__( self, - controller: Controller, + controller_api: ControllerAPI, loop: asyncio.AbstractEventLoop, options: TangoOptions | None = None, ): self._options = options or TangoOptions() - self._dsr = TangoDSR(controller, loop) + self._dsr = TangoDSR(controller_api, loop) @property def options(self) -> TangoOptions: diff --git a/src/fastcs/transport/tango/dsr.py b/src/fastcs/transport/tango/dsr.py index 64dae60cb..cfefc9cbe 100644 --- a/src/fastcs/transport/tango/dsr.py +++ b/src/fastcs/transport/tango/dsr.py @@ -7,7 +7,8 @@ from tango.server import Device from fastcs.attributes import AttrR, AttrRW, AttrW -from fastcs.controller import BaseController +from fastcs.controller_api import ControllerAPI +from fastcs.cs_methods import CommandCallback from .options import TangoDSROptions from .util import ( @@ -21,7 +22,7 @@ def _wrap_updater_fget( attr_name: str, attribute: AttrR, - controller: BaseController, + controller_api: ControllerAPI, ) -> Callable[[Any], Any]: async def fget(tango_device: Device): tango_device.info_stream(f"called fget method: {attr_name}") @@ -44,7 +45,7 @@ async def _run_threadsafe_blocking( def _wrap_updater_fset( attr_name: str, attribute: AttrW, - controller: BaseController, + controller_api: ControllerAPI, loop: asyncio.AbstractEventLoop, ) -> Callable[[Any, Any], Any]: async def fset(tango_device: Device, value): @@ -56,13 +57,13 @@ async def fset(tango_device: Device, value): def _collect_dev_attributes( - controller: BaseController, loop: asyncio.AbstractEventLoop + root_controller_api: ControllerAPI, loop: asyncio.AbstractEventLoop ) -> dict[str, Any]: collection: dict[str, Any] = {} - for single_mapping in controller.get_controller_mappings(): - path = single_mapping.controller.path + for controller_api in root_controller_api.walk_api(): + path = controller_api.path - for attr_name, attribute in single_mapping.attributes.items(): + for attr_name, attribute in controller_api.attributes.items(): attr_name = attr_name.title().replace("_", "") d_attr_name = f"{'_'.join(path)}_{attr_name}" if path else attr_name @@ -70,11 +71,9 @@ def _collect_dev_attributes( case AttrRW(): collection[d_attr_name] = server.attribute( label=d_attr_name, - fget=_wrap_updater_fget( - attr_name, attribute, single_mapping.controller - ), + fget=_wrap_updater_fget(attr_name, attribute, controller_api), fset=_wrap_updater_fset( - attr_name, attribute, single_mapping.controller, loop + attr_name, attribute, controller_api, loop ), access=AttrWriteType.READ_WRITE, **get_server_metadata_from_attribute(attribute), @@ -84,9 +83,7 @@ def _collect_dev_attributes( collection[d_attr_name] = server.attribute( label=d_attr_name, access=AttrWriteType.READ, - fget=_wrap_updater_fget( - attr_name, attribute, single_mapping.controller - ), + fget=_wrap_updater_fget(attr_name, attribute, controller_api), **get_server_metadata_from_attribute(attribute), **get_server_metadata_from_datatype(attribute.datatype), ) @@ -95,7 +92,7 @@ def _collect_dev_attributes( label=d_attr_name, access=AttrWriteType.WRITE, fset=_wrap_updater_fset( - attr_name, attribute, single_mapping.controller, loop + attr_name, attribute, controller_api, loop ), **get_server_metadata_from_attribute(attribute), **get_server_metadata_from_datatype(attribute.datatype), @@ -106,15 +103,16 @@ def _collect_dev_attributes( def _wrap_command_f( method_name: str, - method: Callable, - controller: BaseController, + method: CommandCallback, + controller_api: ControllerAPI, loop: asyncio.AbstractEventLoop, ) -> Callable[..., Awaitable[None]]: async def _dynamic_f(tango_device: Device) -> None: tango_device.info_stream( - f"called {'_'.join(controller.path)} f method: {method_name}" + f"called {'_'.join(controller_api.path)} f method: {method_name}" ) - coro = getattr(controller, method.__name__)() + + coro = method() await _run_threadsafe_blocking(coro, loop) _dynamic_f.__name__ = method_name @@ -122,31 +120,29 @@ async def _dynamic_f(tango_device: Device) -> None: def _collect_dev_commands( - controller: BaseController, + root_controller_api: ControllerAPI, loop: asyncio.AbstractEventLoop, ) -> dict[str, Any]: collection: dict[str, Any] = {} - for single_mapping in controller.get_controller_mappings(): - path = single_mapping.controller.path + for controller_api in root_controller_api.walk_api(): + path = controller_api.path - for name, method in single_mapping.command_methods.items(): + for name, method in controller_api.command_methods.items(): cmd_name = name.title().replace("_", "") d_cmd_name = f"{'_'.join(path)}_{cmd_name}" if path else cmd_name collection[d_cmd_name] = server.command( - f=_wrap_command_f( - d_cmd_name, method.fn, single_mapping.controller, loop - ) + f=_wrap_command_f(d_cmd_name, method.fn, controller_api, loop) ) return collection -def _collect_dev_properties(controller: BaseController) -> dict[str, Any]: +def _collect_dev_properties(controller_api: ControllerAPI) -> dict[str, Any]: collection: dict[str, Any] = {} return collection -def _collect_dev_init(controller: BaseController) -> dict[str, Callable]: +def _collect_dev_init(controller_api: ControllerAPI) -> dict[str, Callable]: async def init_device(tango_device: Device): await server.Device.init_device(tango_device) # type: ignore tango_device.set_state(DevState.ON) @@ -154,7 +150,7 @@ async def init_device(tango_device: Device): return {"init_device": init_device} -def _collect_dev_flags(controller: BaseController) -> dict[str, Any]: +def _collect_dev_flags(controller_api: ControllerAPI) -> dict[str, Any]: collection: dict[str, Any] = {} collection["green_mode"] = tango.GreenMode.Asyncio @@ -174,21 +170,21 @@ def _collect_dsr_args(options: TangoDSROptions) -> list[str]: class TangoDSR: def __init__( self, - controller: BaseController, + controller_api: ControllerAPI, loop: asyncio.AbstractEventLoop, ): - self._controller = controller + self._controller_api = controller_api self._loop = loop - self.dev_class = self._controller.__class__.__name__ + self.dev_class = self._controller_api.__class__.__name__ self._device = self._create_device() def _create_device(self): class_dict: dict = { - **_collect_dev_attributes(self._controller, self._loop), - **_collect_dev_commands(self._controller, self._loop), - **_collect_dev_properties(self._controller), - **_collect_dev_init(self._controller), - **_collect_dev_flags(self._controller), + **_collect_dev_attributes(self._controller_api, self._loop), + **_collect_dev_commands(self._controller_api, self._loop), + **_collect_dev_properties(self._controller_api), + **_collect_dev_init(self._controller_api), + **_collect_dev_flags(self._controller_api), } class_bases = (server.Device,) diff --git a/src/fastcs/wrappers.py b/src/fastcs/wrappers.py index 5451672d9..41b3a5c02 100644 --- a/src/fastcs/wrappers.py +++ b/src/fastcs/wrappers.py @@ -1,41 +1,46 @@ -from typing import Any, Protocol, runtime_checkable - -from .cs_methods import Command, Method, Put, Scan +from collections.abc import Callable + +from .cs_methods import ( + Controller_T, + UnboundCommand, + UnboundCommandCallback, + UnboundPut, + UnboundPutCallback, + UnboundScan, + UnboundScanCallback, +) from .exceptions import FastCSException -@runtime_checkable -class WrappedMethod(Protocol): - fastcs_method: Method - - -# TODO: Consider type hints with the use of typing.Protocol -def scan(period: float) -> Any: +def scan( + period: float, +) -> Callable[[UnboundScanCallback[Controller_T]], UnboundScan[Controller_T]]: if period <= 0: raise FastCSException("Scan method must have a positive scan period") - def wrapper(fn): - fn.fastcs_method = Scan(fn, period) - return fn + def wrapper(fn: UnboundScanCallback[Controller_T]) -> UnboundScan[Controller_T]: + return UnboundScan(fn, period) return wrapper -def put(fn) -> Any: - fn.fastcs_method = Put(fn) - return fn +def put(fn: UnboundPutCallback[Controller_T]) -> UnboundPut[Controller_T]: + return UnboundPut(fn) -def command(*, group: str | None = None) -> Any: - """Decorator to map a `Controller` method into a `Command`. +def command( + *, group: str | None = None +) -> Callable[[UnboundCommandCallback[Controller_T]], UnboundCommand[Controller_T]]: + """Decorator to tag a `Controller` method to be turned into a `Command`. Args: - group: Group to display the widget for this command in on the UI + group: Group to display this command under in the transport layer """ - def wrapper(fn): - fn.fastcs_method = Command(fn, group=group) - return fn + def wrapper( + fn: UnboundCommandCallback[Controller_T], + ) -> UnboundCommand[Controller_T]: + return UnboundCommand(fn, group=group) return wrapper diff --git a/tests/assertable_controller.py b/tests/assertable_controller.py index 9000e70f5..36f683cde 100644 --- a/tests/assertable_controller.py +++ b/tests/assertable_controller.py @@ -2,10 +2,12 @@ from contextlib import contextmanager from typing import Literal -from pytest_mock import MockerFixture +from pytest_mock import MockerFixture, MockType from fastcs.attributes import AttrR, Handler, Sender, Updater +from fastcs.backend import build_controller_api from fastcs.controller import Controller, SubController +from fastcs.controller_api import ControllerAPI from fastcs.datatypes import Int from fastcs.wrappers import command, scan @@ -30,7 +32,7 @@ class TestSubController(SubController): read_int: AttrR = AttrR(Int(), handler=TestUpdater()) -class TestController(Controller): +class MyTestController(Controller): def __init__(self) -> None: super().__init__() @@ -59,11 +61,28 @@ async def counter(self): self.count += 1 -class AssertableController(TestController): - def __init__(self, mocker: MockerFixture) -> None: - self.mocker = mocker +class AssertableControllerAPI(ControllerAPI): + def __init__(self, controller: Controller, mocker: MockerFixture) -> None: super().__init__() + self.mocker = mocker + self.command_method_spys: dict[str, MockType] = {} + + # Build a ControllerAPI from the given Controller + controller_api = build_controller_api(controller) + # Copy its fields + self.attributes = controller_api.attributes + self.command_methods = controller_api.command_methods + self.put_methods = controller_api.put_methods + self.scan_methods = controller_api.scan_methods + self.sub_apis = controller_api.sub_apis + + # Create spys for command methods before they are passed to the transport + for command_name in self.command_methods.keys(): + self.command_method_spys[command_name] = mocker.spy( + self.command_methods[command_name], "_fn" + ) + @contextmanager def assert_read_here(self, path: list[str]): yield from self._assert_method(path, "get") @@ -85,20 +104,20 @@ def _assert_method(self, path: list[str], method: Literal["get", "process", ""]) queue = copy.deepcopy(path) # Navigate to subcontroller - controller = self + controller_api = self item_name = queue.pop(-1) for item in queue: - controllers = controller.get_sub_controllers() - controller = controllers[item] + controller_api = controller_api.sub_apis[item] - # create probe + # Get spy if method: - attr = getattr(controller, item_name) + attr = controller_api.attributes[item_name] spy = self.mocker.spy(attr, method) else: - spy = self.mocker.spy(controller, item_name) - initial = spy.call_count + # Lookup pre-defined spy for method + spy = self.command_method_spys[item_name] + initial = spy.call_count try: yield # Enter context except Exception as e: diff --git a/tests/benchmarking/controller.py b/tests/benchmarking/controller.py index cb1c895bb..72606895a 100644 --- a/tests/benchmarking/controller.py +++ b/tests/benchmarking/controller.py @@ -8,7 +8,7 @@ from fastcs.transport.tango.options import TangoDSROptions, TangoOptions -class TestController(Controller): +class MyTestController(Controller): read_int: AttrR = AttrR(Int(), initial_value=0) write_bool: AttrW = AttrW(Bool()) @@ -22,7 +22,7 @@ def run(): TangoOptions(dsr=TangoDSROptions(dev_name="MY/BENCHMARK/DEVICE")), ] instance = FastCS( - TestController(), + MyTestController(), transport_options, ) instance.run() diff --git a/tests/conftest.py b/tests/conftest.py index a28c8af89..74feaa3e0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,10 +17,11 @@ from aioca import purge_channel_caches from fastcs.attributes import AttrR, AttrRW, AttrW +from fastcs.backend import build_controller_api from fastcs.datatypes import Bool, Float, Int, String from fastcs.transport.tango.dsr import register_dev from tests.assertable_controller import ( - TestController, + MyTestController, TestHandler, TestSender, TestUpdater, @@ -31,7 +32,7 @@ DATA_PATH = Path(__file__).parent / "data" -class BackendTestController(TestController): +class BackendTestController(MyTestController): read_int: AttrR = AttrR(Int(), handler=TestUpdater()) read_write_int: AttrRW = AttrRW(Int(), handler=TestHandler()) read_write_float: AttrRW = AttrRW(Float()) @@ -45,6 +46,11 @@ def controller(): return BackendTestController() +@pytest.fixture +def controller_api(controller): + return build_controller_api(controller) + + @pytest.fixture def data() -> Path: return DATA_PATH diff --git a/tests/test_backend.py b/tests/test_backend.py index 0a446a522..2d578c547 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -1,6 +1,11 @@ import asyncio -from fastcs.backend import Backend +from fastcs.attributes import AttrRW +from fastcs.backend import Backend, build_controller_api +from fastcs.controller import Controller +from fastcs.cs_methods import Command +from fastcs.datatypes import Int +from fastcs.wrappers import command, scan def test_backend(controller): @@ -29,3 +34,58 @@ async def test_wrapper(): backend._stop_scan_tasks() loop.run_until_complete(test_wrapper()) + + +def test_controller_api(): + class MyTestController(Controller): + attr1: AttrRW[int] = AttrRW(Int()) + + def __init__(self): + super().__init__(description="Controller for testing") + + self.attributes["attr2"] = AttrRW(Int()) + + @command() + async def do_nothing(self): + pass + + @scan(1.0) + async def scan_nothing(self): + pass + + controller = MyTestController() + api = build_controller_api(controller) + + assert api.description == controller.description + assert list(api.attributes) == ["attr1", "attr2"] + assert list(api.command_methods) == ["do_nothing"] + assert list(api.scan_methods) == ["scan_nothing"] + + +def test_controller_api_methods(): + class MyTestController(Controller): + def __init__(self): + super().__init__() + + async def initialise(self): + async def do_nothing_dynamic() -> None: + pass + + self.do_nothing_dynamic = Command(do_nothing_dynamic) + + @command() + async def do_nothing_static(self): + pass + + controller = MyTestController() + loop = asyncio.get_event_loop() + backend = Backend(controller, loop) + + async def test_wrapper(): + await controller.do_nothing_static() + await controller.do_nothing_dynamic() + + await backend.controller_api.command_methods["do_nothing_static"]() + await backend.controller_api.command_methods["do_nothing_dynamic"]() + + loop.run_until_complete(test_wrapper()) diff --git a/tests/test_controller.py b/tests/test_controller.py index b404f6416..b0e87fd07 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -1,12 +1,7 @@ import pytest from fastcs.attributes import AttrR -from fastcs.controller import ( - Controller, - SubController, - _get_single_mapping, - _walk_mappings, -) +from fastcs.controller import Controller, SubController from fastcs.datatypes import Int @@ -20,11 +15,8 @@ def test_controller_nesting(): assert sub_controller.path == ["a"] assert sub_sub_controller.path == ["a", "b"] - assert list(_walk_mappings(controller)) == [ - _get_single_mapping(controller), - _get_single_mapping(sub_controller), - _get_single_mapping(sub_sub_controller), - ] + assert controller.get_sub_controllers() == {"a": sub_controller} + assert sub_controller.get_sub_controllers() == {"b": sub_sub_controller} with pytest.raises( ValueError, match=r"Controller .* already has a SubController registered as .*" @@ -69,10 +61,7 @@ def test_attribute_parsing(): sub_controller = SomeSubController() controller = SomeController(sub_controller) - mapping_walk = _walk_mappings(controller) - - controller_mapping = next(mapping_walk) - assert set(controller_mapping.attributes.keys()) == { + assert set(controller.attributes.keys()) == { "_attributes_attr", "annotated_attr", "_attributes_attr_equal", @@ -87,8 +76,7 @@ def test_attribute_parsing(): is not controller.annotated_and_equal_attr ) - sub_controller_mapping = next(mapping_walk) - assert sub_controller_mapping.attributes == { + assert sub_controller.attributes == { "sub_attribute": sub_controller.sub_attribute, } @@ -123,4 +111,4 @@ class FailingController(SomeController): "has an attribute of that name." ), ): - next(_walk_mappings(FailingController(SomeSubController()))) + FailingController(SomeSubController()) diff --git a/tests/test_cs_methods.py b/tests/test_cs_methods.py new file mode 100644 index 000000000..7b05bb924 --- /dev/null +++ b/tests/test_cs_methods.py @@ -0,0 +1,115 @@ +import pytest + +from fastcs.controller import Controller +from fastcs.cs_methods import ( + Command, + Method, + Put, + Scan, + UnboundCommand, + UnboundPut, + UnboundScan, +) +from fastcs.exceptions import FastCSException + + +def test_method(): + def sync_do_nothing(): + pass + + with pytest.raises(FastCSException): + Method(sync_do_nothing) # type: ignore + + async def do_nothing_with_return() -> int: + return 1 + + with pytest.raises(FastCSException): + Method(do_nothing_with_return) # type: ignore + + async def do_nothing(): + """Do nothing.""" + pass + + method = Method(do_nothing, group="Nothing") + + assert method.docstring == "Do nothing." + assert method.group == "Nothing" + + +@pytest.mark.asyncio +async def test_unbound_command(): + class TestController(Controller): + async def do_nothing(self): + pass + + async def do_nothing_with_arg(self, arg): + pass + + unbound_command = UnboundCommand(TestController.do_nothing) + + with pytest.raises(NotImplementedError): + await unbound_command() + + with pytest.raises(FastCSException): + UnboundCommand(TestController.do_nothing_with_arg) # type: ignore + + with pytest.raises(FastCSException): + Command(TestController().do_nothing_with_arg) # type: ignore + + command = unbound_command.bind(TestController()) + + await command() + + +@pytest.mark.asyncio +async def test_unbound_scan(): + class TestController(Controller): + async def update_nothing(self): + pass + + async def update_nothing_with_arg(self, arg): + pass + + unbound_scan = UnboundScan(TestController.update_nothing, 1.0) + + assert unbound_scan.period == 1.0 + + with pytest.raises(NotImplementedError): + await unbound_scan() + + with pytest.raises(FastCSException): + UnboundScan(TestController.update_nothing_with_arg, 1.0) # type: ignore + + with pytest.raises(FastCSException): + Scan(TestController().update_nothing_with_arg, 1.0) # type: ignore + + scan = unbound_scan.bind(TestController()) + + assert scan.period == 1.0 + + await scan() + + +@pytest.mark.asyncio +async def test_unbound_put(): + class TestController(Controller): + async def put_value(self, value): + pass + + async def put_no_value(self): + pass + + unbound_put = UnboundPut(TestController.put_value) + + with pytest.raises(NotImplementedError): + await unbound_put() + + with pytest.raises(FastCSException): + UnboundPut(TestController.put_no_value) # type: ignore + + with pytest.raises(FastCSException): + Put(TestController().put_no_value) # type: ignore + + put = unbound_put.bind(TestController()) + + await put(1) diff --git a/tests/transport/epics/ca/test_gui.py b/tests/transport/epics/ca/test_gui.py index de7fb9399..cf1d47b8c 100644 --- a/tests/transport/epics/ca/test_gui.py +++ b/tests/transport/epics/ca/test_gui.py @@ -18,13 +18,13 @@ from tests.util import ColourEnum from fastcs.attributes import AttrR, AttrRW, AttrW -from fastcs.controller import Controller +from fastcs.controller_api import ControllerAPI from fastcs.datatypes import Bool, Enum, Float, Int, String, Waveform from fastcs.transport.epics.gui import EpicsGUI -def test_get_pv(controller): - gui = EpicsGUI(controller, "DEVICE") +def test_get_pv(controller_api): + gui = EpicsGUI(controller_api, "DEVICE") assert gui._get_pv([], "A") == "DEVICE:A" assert gui._get_pv(["B"], "C") == "DEVICE:B:C" @@ -42,8 +42,8 @@ def test_get_pv(controller): # (Waveform(array_dtype=np.int32), None), ], ) -def test_get_attribute_component_r(datatype, widget, controller): - gui = EpicsGUI(controller, "DEVICE") +def test_get_attribute_component_r(datatype, widget, controller_api): + gui = EpicsGUI(controller_api, "DEVICE") assert gui._get_attribute_component([], "Attr", AttrR(datatype)) == SignalR( name="Attr", read_pv="Attr", read_widget=widget @@ -60,16 +60,16 @@ def test_get_attribute_component_r(datatype, widget, controller): (Enum(ColourEnum), ComboBox(choices=["RED", "GREEN", "BLUE"])), ], ) -def test_get_attribute_component_w(datatype, widget, controller): - gui = EpicsGUI(controller, "DEVICE") +def test_get_attribute_component_w(datatype, widget, controller_api): + gui = EpicsGUI(controller_api, "DEVICE") assert gui._get_attribute_component([], "Attr", AttrW(datatype)) == SignalW( name="Attr", write_pv="Attr", write_widget=widget ) -def test_get_attribute_component_none(mocker, controller): - gui = EpicsGUI(controller, "DEVICE") +def test_get_attribute_component_none(mocker, controller_api): + gui = EpicsGUI(controller_api, "DEVICE") mocker.patch.object(gui, "_get_read_widget", return_value=None) mocker.patch.object(gui, "_get_write_widget", return_value=None) @@ -86,10 +86,10 @@ def test_get_write_widget_none(): assert EpicsGUI._get_write_widget(AttrW(Waveform(np.int32))) is None -def test_get_components(controller): - gui = EpicsGUI(controller, "DEVICE") +def test_get_components(controller_api): + gui = EpicsGUI(controller_api, "DEVICE") - components = gui.extract_mapping_components(controller.get_controller_mappings()[0]) + components = gui.extract_api_components(controller_api) assert components == [ Group( name="SubController01", @@ -155,13 +155,10 @@ def test_get_components(controller): def test_get_components_none(mocker): """Test that if _get_attribute_component returns none it is skipped""" - class TestController(Controller): - attr = AttrR(Int()) - - controller = TestController() - gui = EpicsGUI(controller, "DEVICE") + controller_api = ControllerAPI() + gui = EpicsGUI(controller_api, "DEVICE") mocker.patch.object(gui, "_get_attribute_component", return_value=None) - components = gui.extract_mapping_components(controller.get_controller_mappings()[0]) + components = gui.extract_api_components(controller_api) assert components == [] diff --git a/tests/transport/epics/ca/test_softioc.py b/tests/transport/epics/ca/test_softioc.py index 59f4acca4..41a8ccc35 100644 --- a/tests/transport/epics/ca/test_softioc.py +++ b/tests/transport/epics/ca/test_softioc.py @@ -5,7 +5,8 @@ import pytest from pytest_mock import MockerFixture from tests.assertable_controller import ( - AssertableController, + AssertableControllerAPI, + MyTestController, TestHandler, TestSender, TestUpdater, @@ -14,6 +15,7 @@ from fastcs.attributes import AttrR, AttrRW, AttrW from fastcs.controller import Controller +from fastcs.controller_api import ControllerAPI from fastcs.cs_methods import Command from fastcs.datatypes import Bool, Enum, Float, Int, String, Waveform from fastcs.exceptions import FastCSException @@ -180,7 +182,7 @@ def test_get_output_record_raises(mocker: MockerFixture): _make_record("PV", mocker.MagicMock(), on_update=mocker.MagicMock()) -class EpicsAssertableController(AssertableController): +class EpicsController(MyTestController): read_int = AttrR(Int(), handler=TestUpdater()) read_write_int = AttrRW(Int(), handler=TestHandler()) read_write_float = AttrRW(Float()) @@ -192,11 +194,11 @@ class EpicsAssertableController(AssertableController): @pytest.fixture() -def controller(class_mocker: MockerFixture): - return EpicsAssertableController(class_mocker) +def epics_controller_api(class_mocker: MockerFixture): + return AssertableControllerAPI(EpicsController(), class_mocker) -def test_ioc(mocker: MockerFixture, controller: Controller): +def test_ioc(mocker: MockerFixture, epics_controller_api: ControllerAPI): ioc_builder = mocker.patch("fastcs.transport.epics.ca.ioc.builder") builder = mocker.patch("fastcs.transport.epics.ca.util.builder") add_pvi_info = mocker.patch("fastcs.transport.epics.ca.ioc._add_pvi_info") @@ -204,75 +206,93 @@ def test_ioc(mocker: MockerFixture, controller: Controller): "fastcs.transport.epics.ca.ioc._add_sub_controller_pvi_info" ) - EpicsCAIOC(DEVICE, controller) + EpicsCAIOC(DEVICE, epics_controller_api) # Check records are created builder.boolIn.assert_called_once_with( f"{DEVICE}:ReadBool", - **record_metadata_from_attribute(controller.attributes["read_bool"]), - **record_metadata_from_datatype(controller.attributes["read_bool"].datatype), + **record_metadata_from_attribute(epics_controller_api.attributes["read_bool"]), + **record_metadata_from_datatype( + epics_controller_api.attributes["read_bool"].datatype + ), ) builder.longIn.assert_any_call( f"{DEVICE}:ReadInt", - **record_metadata_from_attribute(controller.attributes["read_int"]), - **record_metadata_from_datatype(controller.attributes["read_int"].datatype), + **record_metadata_from_attribute(epics_controller_api.attributes["read_int"]), + **record_metadata_from_datatype( + epics_controller_api.attributes["read_int"].datatype + ), ) builder.aIn.assert_called_once_with( f"{DEVICE}:ReadWriteFloat_RBV", - **record_metadata_from_attribute(controller.attributes["read_write_float"]), + **record_metadata_from_attribute( + epics_controller_api.attributes["read_write_float"] + ), **record_metadata_from_datatype( - controller.attributes["read_write_float"].datatype + epics_controller_api.attributes["read_write_float"].datatype ), ) builder.aOut.assert_any_call( f"{DEVICE}:ReadWriteFloat", always_update=True, on_update=mocker.ANY, - **record_metadata_from_attribute(controller.attributes["read_write_float"]), + **record_metadata_from_attribute( + epics_controller_api.attributes["read_write_float"] + ), **record_metadata_from_datatype( - controller.attributes["read_write_float"].datatype + epics_controller_api.attributes["read_write_float"].datatype ), ) builder.longIn.assert_any_call( f"{DEVICE}:ReadWriteInt_RBV", - **record_metadata_from_attribute(controller.attributes["read_write_int"]), + **record_metadata_from_attribute( + epics_controller_api.attributes["read_write_int"] + ), **record_metadata_from_datatype( - controller.attributes["read_write_int"].datatype + epics_controller_api.attributes["read_write_int"].datatype ), ) builder.longOut.assert_called_with( f"{DEVICE}:ReadWriteInt", always_update=True, on_update=mocker.ANY, - **record_metadata_from_attribute(controller.attributes["read_write_int"]), + **record_metadata_from_attribute( + epics_controller_api.attributes["read_write_int"] + ), **record_metadata_from_datatype( - controller.attributes["read_write_int"].datatype + epics_controller_api.attributes["read_write_int"].datatype ), ) builder.mbbIn.assert_called_once_with( f"{DEVICE}:Enum_RBV", - **record_metadata_from_attribute(controller.attributes["enum"]), - **record_metadata_from_datatype(controller.attributes["enum"].datatype), + **record_metadata_from_attribute(epics_controller_api.attributes["enum"]), + **record_metadata_from_datatype( + epics_controller_api.attributes["enum"].datatype + ), ) builder.mbbOut.assert_called_once_with( f"{DEVICE}:Enum", always_update=True, on_update=mocker.ANY, - **record_metadata_from_attribute(controller.attributes["enum"]), - **record_metadata_from_datatype(controller.attributes["enum"].datatype), + **record_metadata_from_attribute(epics_controller_api.attributes["enum"]), + **record_metadata_from_datatype( + epics_controller_api.attributes["enum"].datatype + ), ) builder.boolOut.assert_called_once_with( f"{DEVICE}:WriteBool", always_update=True, on_update=mocker.ANY, - **record_metadata_from_attribute(controller.attributes["write_bool"]), - **record_metadata_from_datatype(controller.attributes["write_bool"].datatype), + **record_metadata_from_attribute(epics_controller_api.attributes["write_bool"]), + **record_metadata_from_datatype( + epics_controller_api.attributes["write_bool"].datatype + ), ) ioc_builder.Action.assert_any_call(f"{DEVICE}:Go", on_update=mocker.ANY) # Check info tags are added add_pvi_info.assert_called_once_with(f"{DEVICE}:PVI") - add_sub_controller_pvi_info.assert_called_once_with(DEVICE, controller) + add_sub_controller_pvi_info.assert_called_once_with(DEVICE, epics_controller_api) def test_add_pvi_info(mocker: MockerFixture): @@ -341,13 +361,13 @@ def test_add_pvi_info_with_parent(mocker: MockerFixture): def test_add_sub_controller_pvi_info(mocker: MockerFixture): add_pvi_info = mocker.patch("fastcs.transport.epics.ca.ioc._add_pvi_info") - controller = mocker.MagicMock() - controller.path = [] - child = mocker.MagicMock() - child.path = ["Child"] - controller.get_sub_controllers.return_value = {"d": child} + parent_api = mocker.MagicMock() + parent_api.path = [] + child_api = mocker.MagicMock() + child_api.path = ["Child"] + parent_api.sub_apis = {"d": child_api} - _add_sub_controller_pvi_info(DEVICE, controller) + _add_sub_controller_pvi_info(DEVICE, parent_api) add_pvi_info.assert_called_once_with( f"{DEVICE}:Child:PVI", f"{DEVICE}:PVI", "child" @@ -373,33 +393,30 @@ def test_add_attr_pvi_info(mocker: MockerFixture): ) -async def do_nothing(arg): ... - - -class NothingCommand: - def __init__(self): # make fastcs_method instance variable - self.fastcs_method = Command(do_nothing) +async def do_nothing(): ... class ControllerLongNames(Controller): attr_r_with_reallyreallyreallyreallyreallyreallyreally_long_name = AttrR(Int()) attr_rw_with_a_reallyreally_long_name_that_is_too_long_for_RBV = AttrRW(Int()) attr_rw_short_name = AttrRW(Int()) - command_with_reallyreallyreallyreallyreallyreallyreally_long_name = NothingCommand() - command_short_name = NothingCommand() + command_with_reallyreallyreallyreallyreallyreallyreally_long_name = Command( + do_nothing + ) + command_short_name = Command(do_nothing) def test_long_pv_names_discarded(mocker: MockerFixture): ioc_builder = mocker.patch("fastcs.transport.epics.ca.ioc.builder") builder = mocker.patch("fastcs.transport.epics.ca.util.builder") - long_name_controller = ControllerLongNames() + long_name_controller_api = AssertableControllerAPI(ControllerLongNames(), mocker) long_attr_name = "attr_r_with_reallyreallyreallyreallyreallyreallyreally_long_name" long_rw_name = "attr_rw_with_a_reallyreally_long_name_that_is_too_long_for_RBV" - assert long_name_controller.attr_rw_short_name.enabled - assert getattr(long_name_controller, long_attr_name).enabled - EpicsCAIOC(DEVICE, long_name_controller) - assert long_name_controller.attr_rw_short_name.enabled - assert not getattr(long_name_controller, long_attr_name).enabled + assert long_name_controller_api.attributes["attr_rw_short_name"].enabled + assert long_name_controller_api.attributes[long_attr_name].enabled + EpicsCAIOC(DEVICE, long_name_controller_api) + assert long_name_controller_api.attributes["attr_rw_short_name"].enabled + assert not long_name_controller_api.attributes[long_attr_name].enabled short_pv_name = "attr_rw_short_name".title().replace("_", "") builder.longOut.assert_called_once_with( @@ -407,17 +424,23 @@ def test_long_pv_names_discarded(mocker: MockerFixture): always_update=True, on_update=mocker.ANY, **record_metadata_from_datatype( - long_name_controller.attr_rw_short_name.datatype + long_name_controller_api.attributes["attr_rw_short_name"].datatype + ), + **record_metadata_from_attribute( + long_name_controller_api.attributes["attr_rw_short_name"] ), - **record_metadata_from_attribute(long_name_controller.attr_rw_short_name), ) builder.longIn.assert_called_once_with( f"{DEVICE}:{short_pv_name}_RBV", **record_metadata_from_datatype( - long_name_controller.attr_rw_with_a_reallyreally_long_name_that_is_too_long_for_RBV.datatype + long_name_controller_api.attributes[ + "attr_rw_with_a_reallyreally_long_name_that_is_too_long_for_RBV" + ].datatype ), **record_metadata_from_attribute( - long_name_controller.attr_rw_with_a_reallyreally_long_name_that_is_too_long_for_RBV + long_name_controller_api.attributes[ + "attr_rw_with_a_reallyreally_long_name_that_is_too_long_for_RBV" + ] ), ) @@ -443,11 +466,11 @@ def test_long_pv_names_discarded(mocker: MockerFixture): with pytest.raises(AssertionError): builder.longIn.assert_called_once_with(f"{DEVICE}:{long_rw_pv_name}_RBV") - assert long_name_controller.command_short_name.fastcs_method.enabled + assert long_name_controller_api.command_methods["command_short_name"].enabled long_command_name = ( "command_with_reallyreallyreallyreallyreallyreallyreally_long_name" ) - assert not getattr(long_name_controller, long_command_name).fastcs_method.enabled + assert not long_name_controller_api.command_methods[long_command_name].enabled short_command_pv_name = "command_short_name".title().replace("_", "") ioc_builder.Action.assert_called_once_with( diff --git a/tests/transport/graphQL/test_graphQL.py b/tests/transport/graphQL/test_graphQL.py index 8ba57eebf..c790b6869 100644 --- a/tests/transport/graphQL/test_graphQL.py +++ b/tests/transport/graphQL/test_graphQL.py @@ -6,7 +6,8 @@ from fastapi.testclient import TestClient from pytest_mock import MockerFixture from tests.assertable_controller import ( - AssertableController, + AssertableControllerAPI, + MyTestController, TestHandler, TestSender, TestUpdater, @@ -17,7 +18,7 @@ from fastcs.transport.graphQL.adapter import GraphQLTransport -class RestAssertableController(AssertableController): +class GraphQLController(MyTestController): read_int = AttrR(Int(), handler=TestUpdater()) read_write_int = AttrRW(Int(), handler=TestHandler()) read_write_float = AttrRW(Float()) @@ -27,8 +28,8 @@ class RestAssertableController(AssertableController): @pytest.fixture(scope="class") -def assertable_controller(class_mocker: MockerFixture): - return RestAssertableController(class_mocker) +def gql_controller_api(class_mocker: MockerFixture): + return AssertableControllerAPI(GraphQLController(), class_mocker) def nest_query(path: list[str]) -> str: @@ -53,106 +54,124 @@ def nest_mutation(path: list[str], value: Any) -> str: return f"{field}(value: {json.dumps(value)})" -def nest_responce(path: list[str], value: Any) -> dict: +def nest_response(path: list[str], value: Any) -> dict: queue = copy.deepcopy(path) field = queue.pop(0) if queue: - nesting = nest_responce(queue, value) + nesting = nest_response(queue, value) return {field: nesting} else: return {field: value} +def create_test_client(gql_controller_api: AssertableControllerAPI) -> TestClient: + return TestClient(GraphQLTransport(gql_controller_api)._server._app) + + class TestGraphQLServer: @pytest.fixture(scope="class") - def client(self, assertable_controller): - app = GraphQLTransport( - assertable_controller, - )._server._app - return TestClient(app) + def test_client(self, gql_controller_api) -> TestClient: + return create_test_client(gql_controller_api) - def test_read_int(self, client, assertable_controller): + def test_read_int( + self, gql_controller_api: AssertableControllerAPI, test_client: TestClient + ): expect = 0 path = ["readInt"] query = f"query {{ {nest_query(path)} }}" - with assertable_controller.assert_read_here(["read_int"]): - response = client.post("/graphql", json={"query": query}) + with gql_controller_api.assert_read_here(["read_int"]): + response = test_client.post("/graphql", json={"query": query}) assert response.status_code == 200 - assert response.json()["data"] == nest_responce(path, expect) + assert response.json()["data"] == nest_response(path, expect) - def test_read_write_int(self, client, assertable_controller): + def test_read_write_int( + self, gql_controller_api: AssertableControllerAPI, test_client: TestClient + ): expect = 0 path = ["readWriteInt"] query = f"query {{ {nest_query(path)} }}" - with assertable_controller.assert_read_here(["read_write_int"]): - response = client.post("/graphql", json={"query": query}) + with gql_controller_api.assert_read_here(["read_write_int"]): + response = test_client.post("/graphql", json={"query": query}) assert response.status_code == 200 - assert response.json()["data"] == nest_responce(path, expect) + assert response.json()["data"] == nest_response(path, expect) new = 9 mutation = f"mutation {{ {nest_mutation(path, new)} }}" - with assertable_controller.assert_write_here(["read_write_int"]): - response = client.post("/graphql", json={"query": mutation}) + with gql_controller_api.assert_write_here(["read_write_int"]): + response = test_client.post("/graphql", json={"query": mutation}) assert response.status_code == 200 - assert response.json()["data"] == nest_responce(path, new) + assert response.json()["data"] == nest_response(path, new) - def test_read_write_float(self, client, assertable_controller): + def test_read_write_float( + self, gql_controller_api: AssertableControllerAPI, test_client: TestClient + ): expect = 0 path = ["readWriteFloat"] query = f"query {{ {nest_query(path)} }}" - with assertable_controller.assert_read_here(["read_write_float"]): - response = client.post("/graphql", json={"query": query}) + with gql_controller_api.assert_read_here(["read_write_float"]): + response = test_client.post("/graphql", json={"query": query}) assert response.status_code == 200 - assert response.json()["data"] == nest_responce(path, expect) + assert response.json()["data"] == nest_response(path, expect) new = 0.5 mutation = f"mutation {{ {nest_mutation(path, new)} }}" - with assertable_controller.assert_write_here(["read_write_float"]): - response = client.post("/graphql", json={"query": mutation}) + with gql_controller_api.assert_write_here(["read_write_float"]): + response = test_client.post("/graphql", json={"query": mutation}) assert response.status_code == 200 - assert response.json()["data"] == nest_responce(path, new) + assert response.json()["data"] == nest_response(path, new) - def test_read_bool(self, client, assertable_controller): + def test_read_bool( + self, gql_controller_api: AssertableControllerAPI, test_client: TestClient + ): expect = False path = ["readBool"] query = f"query {{ {nest_query(path)} }}" - with assertable_controller.assert_read_here(["read_bool"]): - response = client.post("/graphql", json={"query": query}) + with gql_controller_api.assert_read_here(["read_bool"]): + response = test_client.post("/graphql", json={"query": query}) assert response.status_code == 200 - assert response.json()["data"] == nest_responce(path, expect) + assert response.json()["data"] == nest_response(path, expect) - def test_write_bool(self, client, assertable_controller): + def test_write_bool( + self, gql_controller_api: AssertableControllerAPI, test_client: TestClient + ): value = True path = ["writeBool"] mutation = f"mutation {{ {nest_mutation(path, value)} }}" - with assertable_controller.assert_write_here(["write_bool"]): - response = client.post("/graphql", json={"query": mutation}) + with gql_controller_api.assert_write_here(["write_bool"]): + response = test_client.post("/graphql", json={"query": mutation}) assert response.status_code == 200 - assert response.json()["data"] == nest_responce(path, value) + assert response.json()["data"] == nest_response(path, value) + + def test_go( + self, gql_controller_api: AssertableControllerAPI, test_client: TestClient + ): + test_client = create_test_client(gql_controller_api) - def test_go(self, client, assertable_controller): path = ["go"] mutation = f"mutation {{ {nest_query(path)} }}" - with assertable_controller.assert_execute_here(path): - response = client.post("/graphql", json={"query": mutation}) + with gql_controller_api.assert_execute_here(["go"]): + response = test_client.post("/graphql", json={"query": mutation}) + assert response.status_code == 200 assert response.json()["data"] == {path[-1]: True} - def test_read_child1(self, client, assertable_controller): + def test_read_child1( + self, gql_controller_api: AssertableControllerAPI, test_client: TestClient + ): expect = 0 path = ["SubController01", "readInt"] query = f"query {{ {nest_query(path)} }}" - with assertable_controller.assert_read_here(["SubController01", "read_int"]): - response = client.post("/graphql", json={"query": query}) + with gql_controller_api.assert_read_here(["SubController01", "read_int"]): + response = test_client.post("/graphql", json={"query": query}) assert response.status_code == 200 - assert response.json()["data"] == nest_responce(path, expect) + assert response.json()["data"] == nest_response(path, expect) - def test_read_child2(self, client, assertable_controller): + def test_read_child2(self, gql_controller_api, test_client: TestClient): expect = 0 path = ["SubController02", "readInt"] query = f"query {{ {nest_query(path)} }}" - with assertable_controller.assert_read_here(["SubController02", "read_int"]): - response = client.post("/graphql", json={"query": query}) + with gql_controller_api.assert_read_here(["SubController02", "read_int"]): + response = test_client.post("/graphql", json={"query": query}) assert response.status_code == 200 - assert response.json()["data"] == nest_responce(path, expect) + assert response.json()["data"] == nest_response(path, expect) diff --git a/tests/transport/rest/test_rest.py b/tests/transport/rest/test_rest.py index 87d016f06..23b5fce99 100644 --- a/tests/transport/rest/test_rest.py +++ b/tests/transport/rest/test_rest.py @@ -5,18 +5,20 @@ from fastapi.testclient import TestClient from pytest_mock import MockerFixture from tests.assertable_controller import ( - AssertableController, + AssertableControllerAPI, + MyTestController, TestHandler, TestSender, TestUpdater, ) from fastcs.attributes import AttrR, AttrRW, AttrW +from fastcs.controller_api import ControllerAPI from fastcs.datatypes import Bool, Enum, Float, Int, String, Waveform from fastcs.transport.rest.adapter import RestTransport -class RestAssertableController(AssertableController): +class RestController(MyTestController): read_int = AttrR(Int(), handler=TestUpdater()) read_write_int = AttrRW(Int(), handler=TestHandler()) read_write_float = AttrRW(Float()) @@ -29,126 +31,155 @@ class RestAssertableController(AssertableController): @pytest.fixture(scope="class") -def assertable_controller(class_mocker: MockerFixture): - return RestAssertableController(class_mocker) +def rest_controller_api(class_mocker: MockerFixture): + return AssertableControllerAPI(RestController(), class_mocker) + + +def create_test_client(rest_controller_api: ControllerAPI) -> TestClient: + return TestClient(RestTransport(rest_controller_api)._server._app) class TestRestServer: @pytest.fixture(scope="class") - def client(self, assertable_controller): - app = RestTransport(assertable_controller)._server._app - with TestClient(app) as client: - yield client + def test_client(self, rest_controller_api): + with create_test_client(rest_controller_api) as test_client: + yield test_client - def test_read_write_int(self, assertable_controller, client): + def test_read_write_int( + self, rest_controller_api: AssertableControllerAPI, test_client: TestClient + ): expect = 0 - with assertable_controller.assert_read_here(["read_write_int"]): - response = client.get("/read-write-int") + with rest_controller_api.assert_read_here(["read_write_int"]): + response = test_client.get("/read-write-int") assert response.status_code == 200 assert response.json()["value"] == expect new = 9 - with assertable_controller.assert_write_here(["read_write_int"]): - response = client.put("/read-write-int", json={"value": new}) - assert client.get("/read-write-int").json()["value"] == new + with rest_controller_api.assert_write_here(["read_write_int"]): + response = test_client.put("/read-write-int", json={"value": new}) + assert test_client.get("/read-write-int").json()["value"] == new - def test_read_int(self, assertable_controller, client): + def test_read_int( + self, rest_controller_api: AssertableControllerAPI, test_client: TestClient + ): expect = 0 - with assertable_controller.assert_read_here(["read_int"]): - response = client.get("/read-int") + with rest_controller_api.assert_read_here(["read_int"]): + response = test_client.get("/read-int") assert response.status_code == 200 assert response.json()["value"] == expect - def test_read_write_float(self, assertable_controller, client): + def test_read_write_float( + self, rest_controller_api: AssertableControllerAPI, test_client: TestClient + ): expect = 0 - with assertable_controller.assert_read_here(["read_write_float"]): - response = client.get("/read-write-float") + with rest_controller_api.assert_read_here(["read_write_float"]): + response = test_client.get("/read-write-float") assert response.status_code == 200 assert response.json()["value"] == expect new = 0.5 - with assertable_controller.assert_write_here(["read_write_float"]): - response = client.put("/read-write-float", json={"value": new}) - assert client.get("/read-write-float").json()["value"] == new + with rest_controller_api.assert_write_here(["read_write_float"]): + response = test_client.put("/read-write-float", json={"value": new}) + assert test_client.get("/read-write-float").json()["value"] == new - def test_read_bool(self, assertable_controller, client): + def test_read_bool( + self, rest_controller_api: AssertableControllerAPI, test_client: TestClient + ): expect = False - with assertable_controller.assert_read_here(["read_bool"]): - response = client.get("/read-bool") + with rest_controller_api.assert_read_here(["read_bool"]): + response = test_client.get("/read-bool") assert response.status_code == 200 assert response.json()["value"] == expect - def test_write_bool(self, assertable_controller, client): - with assertable_controller.assert_write_here(["write_bool"]): - client.put("/write-bool", json={"value": True}) - - def test_enum(self, assertable_controller, client): - enum_attr = assertable_controller.attributes["enum"] + def test_write_bool( + self, rest_controller_api: AssertableControllerAPI, test_client: TestClient + ): + with rest_controller_api.assert_write_here(["write_bool"]): + test_client.put("/write-bool", json={"value": True}) + + def test_enum( + self, rest_controller_api: AssertableControllerAPI, test_client: TestClient + ): + enum_attr = rest_controller_api.attributes["enum"] + assert isinstance(enum_attr, AttrRW) enum_cls = enum_attr.datatype.dtype assert isinstance(enum_attr.get(), enum_cls) assert enum_attr.get() == enum_cls(0) expect = 0 - with assertable_controller.assert_read_here(["enum"]): - response = client.get("/enum") + with rest_controller_api.assert_read_here(["enum"]): + response = test_client.get("/enum") assert response.status_code == 200 assert response.json()["value"] == expect new = 2 - with assertable_controller.assert_write_here(["enum"]): - response = client.put("/enum", json={"value": new}) - assert client.get("/enum").json()["value"] == new + with rest_controller_api.assert_write_here(["enum"]): + response = test_client.put("/enum", json={"value": new}) + assert test_client.get("/enum").json()["value"] == new assert isinstance(enum_attr.get(), enum_cls) assert enum_attr.get() == enum_cls(2) - def test_1d_waveform(self, assertable_controller, client): - attribute = assertable_controller.attributes["one_d_waveform"] + def test_1d_waveform( + self, rest_controller_api: AssertableControllerAPI, test_client: TestClient + ): + attribute = rest_controller_api.attributes["one_d_waveform"] expect = np.zeros((10,), dtype=np.int32) + assert isinstance(attribute, AttrRW) assert np.array_equal(attribute.get(), expect) assert isinstance(attribute.get(), np.ndarray) - with assertable_controller.assert_read_here(["one_d_waveform"]): - response = client.get("one-d-waveform") + with rest_controller_api.assert_read_here(["one_d_waveform"]): + response = test_client.get("one-d-waveform") assert np.array_equal(response.json()["value"], expect) new = [1, 2, 3] - with assertable_controller.assert_write_here(["one_d_waveform"]): - client.put("/one-d-waveform", json={"value": new}) - assert np.array_equal(client.get("/one-d-waveform").json()["value"], new) + with rest_controller_api.assert_write_here(["one_d_waveform"]): + test_client.put("/one-d-waveform", json={"value": new}) + assert np.array_equal(test_client.get("/one-d-waveform").json()["value"], new) - result = client.get("/one-d-waveform") + result = test_client.get("/one-d-waveform") assert np.array_equal(result.json()["value"], new) assert np.array_equal(attribute.get(), new) assert isinstance(attribute.get(), np.ndarray) - def test_2d_waveform(self, assertable_controller, client): - attribute = assertable_controller.attributes["two_d_waveform"] + def test_2d_waveform( + self, rest_controller_api: AssertableControllerAPI, test_client: TestClient + ): + attribute = rest_controller_api.attributes["two_d_waveform"] + assert isinstance(attribute, AttrRW) expect = np.zeros((10, 10), dtype=np.int32) assert np.array_equal(attribute.get(), expect) assert isinstance(attribute.get(), np.ndarray) - with assertable_controller.assert_read_here(["two_d_waveform"]): - result = client.get("/two-d-waveform") + with rest_controller_api.assert_read_here(["two_d_waveform"]): + result = test_client.get("/two-d-waveform") assert np.array_equal(result.json()["value"], expect) new = [[1, 2, 3], [4, 5, 6]] - with assertable_controller.assert_write_here(["two_d_waveform"]): - client.put("/two-d-waveform", json={"value": new}) + with rest_controller_api.assert_write_here(["two_d_waveform"]): + test_client.put("/two-d-waveform", json={"value": new}) - result = client.get("/two-d-waveform") + result = test_client.get("/two-d-waveform") assert np.array_equal(result.json()["value"], new) assert np.array_equal(attribute.get(), new) assert isinstance(attribute.get(), np.ndarray) - def test_go(self, assertable_controller, client): - with assertable_controller.assert_execute_here(["go"]): - response = client.put("/go") - assert response.status_code == 204 + def test_go( + self, rest_controller_api: AssertableControllerAPI, test_client: TestClient + ): + with rest_controller_api.assert_execute_here(["go"]): + response = test_client.put("/go") + + assert response.status_code == 204 - def test_read_child1(self, assertable_controller, client): + def test_read_child1( + self, rest_controller_api: AssertableControllerAPI, test_client: TestClient + ): expect = 0 - with assertable_controller.assert_read_here(["SubController01", "read_int"]): - response = client.get("/SubController01/read-int") + with rest_controller_api.assert_read_here(["SubController01", "read_int"]): + response = test_client.get("/SubController01/read-int") assert response.status_code == 200 assert response.json()["value"] == expect - def test_read_child2(self, assertable_controller, client): + def test_read_child2( + self, rest_controller_api: AssertableControllerAPI, test_client: TestClient + ): expect = 0 - with assertable_controller.assert_read_here(["SubController02", "read_int"]): - response = client.get("/SubController02/read-int") + with rest_controller_api.assert_read_here(["SubController02", "read_int"]): + response = test_client.get("/SubController02/read-int") assert response.status_code == 200 assert response.json()["value"] == expect diff --git a/tests/transport/tango/test_dsr.py b/tests/transport/tango/test_dsr.py index 481fac227..c38de8361 100644 --- a/tests/transport/tango/test_dsr.py +++ b/tests/transport/tango/test_dsr.py @@ -1,14 +1,14 @@ import asyncio import enum -from unittest import mock import numpy as np import pytest from pytest_mock import MockerFixture -from tango import DevState +from tango import DeviceProxy, DevState from tango.test_context import DeviceTestContext from tests.assertable_controller import ( - AssertableController, + AssertableControllerAPI, + MyTestController, TestHandler, TestSender, TestUpdater, @@ -23,7 +23,16 @@ async def patch_run_threadsafe_blocking(coro, loop): await coro -class TangoAssertableController(AssertableController): +@pytest.fixture(scope="module") +def mock_run_threadsafe_blocking(module_mocker: MockerFixture): + m = module_mocker.patch( + "fastcs.transport.tango.dsr._run_threadsafe_blocking", + patch_run_threadsafe_blocking, + ) + yield m + + +class TangoController(MyTestController): read_int = AttrR(Int(), handler=TestUpdater()) read_write_int = AttrRW(Int(), handler=TestHandler()) read_write_float = AttrRW(Float()) @@ -36,25 +45,32 @@ class TangoAssertableController(AssertableController): @pytest.fixture(scope="class") -def assertable_controller(class_mocker: MockerFixture): - return TangoAssertableController(class_mocker) +def tango_controller_api(class_mocker: MockerFixture) -> AssertableControllerAPI: + return AssertableControllerAPI(TangoController(), class_mocker) + + +def create_test_context(tango_controller_api: AssertableControllerAPI): + device = TangoTransport( + tango_controller_api, + # This is passed to enable instantiating the transport, but tests must avoid + # using via patching of functions. It will raise NotImplementedError if used. + asyncio.AbstractEventLoop(), + )._dsr._device + # https://tango-controls.readthedocs.io/projects/pytango/en/v9.5.1/testing/test_context.html + with DeviceTestContext(device, debug=0) as proxy: + yield proxy class TestTangoDevice: @pytest.fixture(scope="class") - def tango_context(self, assertable_controller): - with mock.patch( - "fastcs.transport.tango.dsr._run_threadsafe_blocking", - patch_run_threadsafe_blocking, - ): - device = TangoTransport( - assertable_controller, asyncio.AbstractEventLoop() - )._dsr._device - # https://tango-controls.readthedocs.io/projects/pytango/en/v9.5.1/testing/test_context.html - with DeviceTestContext(device, debug=0) as proxy: - yield proxy - - def test_list_attributes(self, tango_context): + def tango_context( + self, + mock_run_threadsafe_blocking, + tango_controller_api: AssertableControllerAPI, + ): + yield from create_test_context(tango_controller_api) + + def test_list_attributes(self, tango_context: DeviceProxy): assert list(tango_context.get_attribute_list()) == [ "Enum", "OneDWaveform", @@ -86,90 +102,109 @@ def test_status(self, tango_context): expect = "The device is in ON state." assert tango_context.command_inout("Status") == expect - def test_read_int(self, assertable_controller, tango_context): + def test_read_int( + self, tango_controller_api: AssertableControllerAPI, tango_context + ): expect = 0 - with assertable_controller.assert_read_here(["read_int"]): + with tango_controller_api.assert_read_here(["read_int"]): result = tango_context.read_attribute("ReadInt").value assert result == expect - def test_read_write_int(self, assertable_controller, tango_context): + def test_read_write_int( + self, tango_controller_api: AssertableControllerAPI, tango_context + ): expect = 0 - with assertable_controller.assert_read_here(["read_write_int"]): + with tango_controller_api.assert_read_here(["read_write_int"]): result = tango_context.read_attribute("ReadWriteInt").value assert result == expect new = 9 - with assertable_controller.assert_write_here(["read_write_int"]): + with tango_controller_api.assert_write_here(["read_write_int"]): tango_context.write_attribute("ReadWriteInt", new) assert tango_context.read_attribute("ReadWriteInt").value == new - def test_read_write_float(self, assertable_controller, tango_context): + def test_read_write_float( + self, tango_controller_api: AssertableControllerAPI, tango_context + ): expect = 0.0 - with assertable_controller.assert_read_here(["read_write_float"]): + with tango_controller_api.assert_read_here(["read_write_float"]): result = tango_context.read_attribute("ReadWriteFloat").value assert result == expect new = 0.5 - with assertable_controller.assert_write_here(["read_write_float"]): + with tango_controller_api.assert_write_here(["read_write_float"]): tango_context.write_attribute("ReadWriteFloat", new) assert tango_context.read_attribute("ReadWriteFloat").value == new - def test_read_bool(self, assertable_controller, tango_context): + def test_read_bool( + self, tango_controller_api: AssertableControllerAPI, tango_context + ): expect = False - with assertable_controller.assert_read_here(["read_bool"]): + with tango_controller_api.assert_read_here(["read_bool"]): result = tango_context.read_attribute("ReadBool").value assert result == expect - def test_write_bool(self, assertable_controller, tango_context): - with assertable_controller.assert_write_here(["write_bool"]): + def test_write_bool( + self, tango_controller_api: AssertableControllerAPI, tango_context + ): + with tango_controller_api.assert_write_here(["write_bool"]): tango_context.write_attribute("WriteBool", True) - def test_enum(self, assertable_controller, tango_context): - enum_attr = assertable_controller.attributes["enum"] + def test_enum(self, tango_controller_api: AssertableControllerAPI, tango_context): + enum_attr = tango_controller_api.attributes["enum"] + assert isinstance(enum_attr, AttrRW) enum_cls = enum_attr.datatype.dtype assert isinstance(enum_attr.get(), enum_cls) assert enum_attr.get() == enum_cls(0) expect = 0 - with assertable_controller.assert_read_here(["enum"]): + with tango_controller_api.assert_read_here(["enum"]): result = tango_context.read_attribute("Enum").value assert result == expect new = 1 - with assertable_controller.assert_write_here(["enum"]): + with tango_controller_api.assert_write_here(["enum"]): tango_context.write_attribute("Enum", new) assert tango_context.read_attribute("Enum").value == new assert isinstance(enum_attr.get(), enum_cls) assert enum_attr.get() == enum_cls(1) - def test_1d_waveform(self, assertable_controller, tango_context): + def test_1d_waveform( + self, tango_controller_api: AssertableControllerAPI, tango_context + ): expect = np.zeros((10,), dtype=np.int32) - with assertable_controller.assert_read_here(["one_d_waveform"]): + with tango_controller_api.assert_read_here(["one_d_waveform"]): result = tango_context.read_attribute("OneDWaveform").value assert np.array_equal(result, expect) new = np.array([1, 2, 3], dtype=np.int32) - with assertable_controller.assert_write_here(["one_d_waveform"]): + with tango_controller_api.assert_write_here(["one_d_waveform"]): tango_context.write_attribute("OneDWaveform", new) assert np.array_equal(tango_context.read_attribute("OneDWaveform").value, new) - def test_2d_waveform(self, assertable_controller, tango_context): + def test_2d_waveform( + self, tango_controller_api: AssertableControllerAPI, tango_context + ): expect = np.zeros((10, 10), dtype=np.int32) - with assertable_controller.assert_read_here(["two_d_waveform"]): + with tango_controller_api.assert_read_here(["two_d_waveform"]): result = tango_context.read_attribute("TwoDWaveform").value assert np.array_equal(result, expect) new = np.array([[1, 2, 3]], dtype=np.int32) - with assertable_controller.assert_write_here(["two_d_waveform"]): + with tango_controller_api.assert_write_here(["two_d_waveform"]): tango_context.write_attribute("TwoDWaveform", new) assert np.array_equal(tango_context.read_attribute("TwoDWaveform").value, new) - def test_go(self, assertable_controller, tango_context): - with assertable_controller.assert_execute_here(["go"]): - tango_context.command_inout("Go") - - def test_read_child1(self, assertable_controller, tango_context): + def test_read_child1( + self, tango_controller_api: AssertableControllerAPI, tango_context + ): expect = 0 - with assertable_controller.assert_read_here(["SubController01", "read_int"]): + with tango_controller_api.assert_read_here(["SubController01", "read_int"]): result = tango_context.read_attribute("SubController01_ReadInt").value assert result == expect - def test_read_child2(self, assertable_controller, tango_context): + def test_read_child2( + self, tango_controller_api: AssertableControllerAPI, tango_context + ): expect = 0 - with assertable_controller.assert_read_here(["SubController02", "read_int"]): + with tango_controller_api.assert_read_here(["SubController02", "read_int"]): result = tango_context.read_attribute("SubController02_ReadInt").value assert result == expect + + def test_go(self, tango_controller_api: AssertableControllerAPI, tango_context): + with tango_controller_api.assert_execute_here(["go"]): + tango_context.command_inout("Go")