diff --git a/roborock/data/b01_q10/b01_q10_containers.py b/roborock/data/b01_q10/b01_q10_containers.py index 0e805593..562c7062 100644 --- a/roborock/data/b01_q10/b01_q10_containers.py +++ b/roborock/data/b01_q10/b01_q10_containers.py @@ -1,6 +1,26 @@ -from ..containers import RoborockBase +"""Data container classes for Q10 B01 devices. + +Many of these classes use the `field(metadata={"dps": ...})` convention to map +dataclass fields to device Data Points (DPS). This metadata is utilized by the +`update_from_dps` helper in `roborock.devices.traits.b01.q10.common` to +automatically update objects from raw device responses. +""" +from dataclasses import dataclass, field +from ..containers import RoborockBase +from .b01_q10_code_mappings import ( + B01_Q10_DP, + YXBackType, + YXDeviceCleanTask, + YXDeviceState, + YXDeviceWorkMode, + YXFanLevel, + YXWaterLevel, +) + + +@dataclass class dpCleanRecord(RoborockBase): op: str result: int @@ -8,24 +28,28 @@ class dpCleanRecord(RoborockBase): data: list +@dataclass class dpMultiMap(RoborockBase): op: str result: int data: list +@dataclass class dpGetCarpet(RoborockBase): op: str result: int data: str +@dataclass class dpSelfIdentifyingCarpet(RoborockBase): op: str result: int data: str +@dataclass class dpNetInfo(RoborockBase): wifiName: str ipAdress: str @@ -33,6 +57,7 @@ class dpNetInfo(RoborockBase): signal: int +@dataclass class dpNotDisturbExpand(RoborockBase): disturb_dust_enable: int disturb_light: int @@ -40,14 +65,38 @@ class dpNotDisturbExpand(RoborockBase): disturb_voice: int +@dataclass class dpCurrentCleanRoomIds(RoborockBase): room_id_list: list +@dataclass class dpVoiceVersion(RoborockBase): version: int +@dataclass class dpTimeZone(RoborockBase): timeZoneCity: str timeZoneSec: int + + +@dataclass +class Q10Status(RoborockBase): + """Status for Q10 devices. + + Fields are mapped to DPS values using metadata. Objects of this class can be + automatically updated using the `update_from_dps` helper. + """ + + clean_time: int | None = field(default=None, metadata={"dps": B01_Q10_DP.CLEAN_TIME}) + clean_area: int | None = field(default=None, metadata={"dps": B01_Q10_DP.CLEAN_AREA}) + battery: int | None = field(default=None, metadata={"dps": B01_Q10_DP.BATTERY}) + status: YXDeviceState | None = field(default=None, metadata={"dps": B01_Q10_DP.STATUS}) + fan_level: YXFanLevel | None = field(default=None, metadata={"dps": B01_Q10_DP.FAN_LEVEL}) + water_level: YXWaterLevel | None = field(default=None, metadata={"dps": B01_Q10_DP.WATER_LEVEL}) + clean_count: int | None = field(default=None, metadata={"dps": B01_Q10_DP.CLEAN_COUNT}) + clean_mode: YXDeviceWorkMode | None = field(default=None, metadata={"dps": B01_Q10_DP.CLEAN_MODE}) + clean_task_type: YXDeviceCleanTask | None = field(default=None, metadata={"dps": B01_Q10_DP.CLEAN_TASK_TYPE}) + back_type: YXBackType | None = field(default=None, metadata={"dps": B01_Q10_DP.BACK_TYPE}) + cleaning_progress: int | None = field(default=None, metadata={"dps": B01_Q10_DP.CLEANING_PROGRESS}) diff --git a/roborock/data/containers.py b/roborock/data/containers.py index 57d5e6b2..e40f62c4 100644 --- a/roborock/data/containers.py +++ b/roborock/data/containers.py @@ -91,10 +91,10 @@ def from_dict(cls, data: dict[str, Any]): if not isinstance(data, dict): return None field_types = {field.name: field.type for field in dataclasses.fields(cls)} - result: dict[str, Any] = {} + normalized_data: dict[str, Any] = {} for orig_key, value in data.items(): key = _decamelize(orig_key) - if (field_type := field_types.get(key)) is None: + if field_types.get(key) is None: if (log_key := f"{cls.__name__}.{key}") not in RoborockBase._missing_logged: _LOGGER.debug( "Key '%s' (decamelized: '%s') not found in %s fields, skipping", @@ -104,6 +104,23 @@ def from_dict(cls, data: dict[str, Any]): ) RoborockBase._missing_logged.add(log_key) continue + normalized_data[key] = value + + result = RoborockBase.convert_dict(field_types, normalized_data) + return cls(**result) + + @staticmethod + def convert_dict(types_map: dict[Any, type], data: dict[Any, Any]) -> dict[Any, Any]: + """Generic helper to convert a dictionary of values based on a schema map of types. + + This is meant to be used by traits that use dataclass reflection similar to + `Roborock.from_dict` to merge in new data updates. + """ + result: dict[Any, Any] = {} + for key, value in data.items(): + if key not in types_map: + continue + field_type = types_map[key] if value == "None" or value is None: result[key] = None continue @@ -124,7 +141,7 @@ def from_dict(cls, data: dict[str, Any]): _LOGGER.exception(f"Failed to convert {key} with value {value} to type {field_type}") continue - return cls(**result) + return result def as_dict(self) -> dict: return asdict( diff --git a/roborock/devices/device.py b/roborock/devices/device.py index ca1fbf14..29f1fd28 100644 --- a/roborock/devices/device.py +++ b/roborock/devices/device.py @@ -197,12 +197,14 @@ async def connect(self) -> None: if self._unsub: raise ValueError("Already connected to the device") unsub = await self._channel.subscribe(self._on_message) - if self.v1_properties is not None: - try: + try: + if self.v1_properties is not None: await self.v1_properties.discover_features() - except RoborockException: - unsub() - raise + elif self.b01_q10_properties is not None: + await self.b01_q10_properties.start() + except RoborockException: + unsub() + raise self._logger.info("Connected to device") self._unsub = unsub @@ -214,6 +216,8 @@ async def close(self) -> None: await self._connect_task except asyncio.CancelledError: pass + if self.b01_q10_properties is not None: + await self.b01_q10_properties.close() if self._unsub: self._unsub() self._unsub = None diff --git a/roborock/devices/rpc/b01_q10_channel.py b/roborock/devices/rpc/b01_q10_channel.py index a482e109..d27b148b 100644 --- a/roborock/devices/rpc/b01_q10_channel.py +++ b/roborock/devices/rpc/b01_q10_channel.py @@ -3,18 +3,39 @@ from __future__ import annotations import logging +from collections.abc import AsyncGenerator +from typing import Any from roborock.data.b01_q10.b01_q10_code_mappings import B01_Q10_DP from roborock.devices.transport.mqtt_channel import MqttChannel from roborock.exceptions import RoborockException from roborock.protocols.b01_q10_protocol import ( ParamsType, + decode_rpc_response, encode_mqtt_payload, ) _LOGGER = logging.getLogger(__name__) +async def stream_decoded_responses( + mqtt_channel: MqttChannel, +) -> AsyncGenerator[dict[B01_Q10_DP, Any], None]: + """Stream decoded DPS messages received via MQTT.""" + + async for response_message in mqtt_channel.subscribe_stream(): + try: + decoded_dps = decode_rpc_response(response_message) + except RoborockException as ex: + _LOGGER.debug( + "Failed to decode B01 Q10 RPC response: %s: %s", + response_message, + ex, + ) + continue + yield decoded_dps + + async def send_command( mqtt_channel: MqttChannel, command: B01_Q10_DP, diff --git a/roborock/devices/traits/b01/q10/__init__.py b/roborock/devices/traits/b01/q10/__init__.py index ac897259..1cd89bd8 100644 --- a/roborock/devices/traits/b01/q10/__init__.py +++ b/roborock/devices/traits/b01/q10/__init__.py @@ -1,15 +1,24 @@ """Traits for Q10 B01 devices.""" +import asyncio +import logging +from typing import Any + +from roborock.data.b01_q10.b01_q10_code_mappings import B01_Q10_DP +from roborock.devices.rpc.b01_q10_channel import stream_decoded_responses from roborock.devices.traits import Trait from roborock.devices.transport.mqtt_channel import MqttChannel from .command import CommandTrait +from .status import StatusTrait from .vacuum import VacuumTrait __all__ = [ "Q10PropertiesApi", ] +_LOGGER = logging.getLogger(__name__) + class Q10PropertiesApi(Trait): """API for interacting with B01 devices.""" @@ -17,13 +26,49 @@ class Q10PropertiesApi(Trait): command: CommandTrait """Trait for sending commands to Q10 devices.""" + status: StatusTrait + """Trait for managing the status of Q10 devices.""" + vacuum: VacuumTrait """Trait for sending vacuum related commands to Q10 devices.""" def __init__(self, channel: MqttChannel) -> None: """Initialize the B01Props API.""" + self._channel = channel self.command = CommandTrait(channel) self.vacuum = VacuumTrait(self.command) + self.status = StatusTrait() + self._subscribe_task: asyncio.Task[None] | None = None + + async def start(self) -> None: + """Start any necessary subscriptions for the trait.""" + self._subscribe_task = asyncio.create_task(self._subscribe_loop()) + + async def close(self) -> None: + """Close any resources held by the trait.""" + if self._subscribe_task is not None: + self._subscribe_task.cancel() + try: + await self._subscribe_task + except asyncio.CancelledError: + pass # ignore cancellation errors + self._subscribe_task = None + + async def refresh(self) -> None: + """Refresh all traits.""" + # Sending the REQUEST_DPS will cause the device to send all DPS values + # to the device. Updates will be received by the subscribe loop below. + await self.command.send(B01_Q10_DP.REQUEST_DPS, params={}) + + async def _subscribe_loop(self) -> None: + """Persistent loop to listen for status updates.""" + async for decoded_dps in stream_decoded_responses(self._channel): + _LOGGER.debug("Received Q10 status update: %s", decoded_dps) + + # Notify all traits about a new message and each trait will + # only update what fields that it is responsible for. + # More traits can be added here below. + self.status.update_from_dps(decoded_dps) def create(channel: MqttChannel) -> Q10PropertiesApi: diff --git a/roborock/devices/traits/b01/q10/common.py b/roborock/devices/traits/b01/q10/common.py new file mode 100644 index 00000000..ad66e895 --- /dev/null +++ b/roborock/devices/traits/b01/q10/common.py @@ -0,0 +1,82 @@ +"""Common utilities for Q10 traits. + +This module provides infrastructure for mapping Roborock Data Points (DPS) to +Python dataclass fields and handling the lifecycle of data updates from the +device. + +### DPS Metadata Annotation + +Classes extending `RoborockBase` can annotate their fields with DPS IDs using +the `field(metadata={"dps": ...})` convention. This creates a declarative +mapping that `DpsDataConverter` uses to automatically route incoming device +data to the correct attribute. + +Example: + +```python +@dataclass +class MyStatus(RoborockBase): + battery: int = field(metadata={"dps": B01_Q10_DP.BATTERY}) +``` + +### Update Lifecycle +1. **Raw Data**: The device sends encoded DPS updates over MQTT. +2. **Decoding**: The transport layer decodes these into a dictionary (e.g., `{"101": 80}`). +3. **Conversion**: `DpsDataConverter` uses `RoborockBase.convert_dict` to transform + raw values into appropriate Python types (e.g., Enums, ints) based on the + dataclass field types. +4. **Update**: `update_from_dps` maps these converted values to field names and + updates the target object using `setattr`. + +### Usage + +Typically, a trait will instantiate a single `DpsDataConverter` for its status class +and call `update_from_dps` whenever new data is received from the device stream. + +""" + +import dataclasses +from typing import Any + +from roborock.data.b01_q10.b01_q10_code_mappings import B01_Q10_DP +from roborock.data.containers import RoborockBase + + +class DpsDataConverter: + """Utility to handle the transformation and merging of DPS data into models. + + This class pre-calculates the mapping between Data Point IDs and dataclass fields + to optimize repeated updates from device streams. + """ + + def __init__(self, dps_type_map: dict[B01_Q10_DP, type], dps_field_map: dict[B01_Q10_DP, str]): + """Initialize the converter for a specific RoborockBase-derived class.""" + self._dps_type_map = dps_type_map + self._dps_field_map = dps_field_map + + @classmethod + def from_dataclass(cls, dataclass_type: type[RoborockBase]): + """Initialize the converter for a specific RoborockBase-derived class.""" + dps_type_map: dict[B01_Q10_DP, type] = {} + dps_field_map: dict[B01_Q10_DP, str] = {} + for field_obj in dataclasses.fields(dataclass_type): + if field_obj.metadata and "dps" in field_obj.metadata: + dps_id = field_obj.metadata["dps"] + dps_type_map[dps_id] = field_obj.type + dps_field_map[dps_id] = field_obj.name + return cls(dps_type_map, dps_field_map) + + def update_from_dps(self, target: RoborockBase, decoded_dps: dict[B01_Q10_DP, Any]) -> None: + """Convert and merge raw DPS data into the target object. + + Uses the pre-calculated type mapping to ensure values are converted to the + correct Python types before being updated on the target. + + Args: + target: The target object to update. + decoded_dps: The decoded DPS data to convert. + """ + conversions = RoborockBase.convert_dict(self._dps_type_map, decoded_dps) + for dps_id, value in conversions.items(): + field_name = self._dps_field_map[dps_id] + setattr(target, field_name, value) diff --git a/roborock/devices/traits/b01/q10/status.py b/roborock/devices/traits/b01/q10/status.py new file mode 100644 index 00000000..7f44a526 --- /dev/null +++ b/roborock/devices/traits/b01/q10/status.py @@ -0,0 +1,26 @@ +"""Status trait for Q10 B01 devices.""" + +import logging + +from roborock.data.b01_q10.b01_q10_containers import Q10Status + +from .common import DpsDataConverter + +_LOGGER = logging.getLogger(__name__) + +_CONVERTER = DpsDataConverter.from_dataclass(Q10Status) + + +class StatusTrait(Q10Status): + """Trait for managing the status of Q10 Roborock devices. + + This is a thin wrapper around Q10Status that provides the Trait interface. + The current values reflect the most recently received data from the device. + New values can be requited through the `Q10PropertiesApi`'s `refresh` method. + """ + + def update_from_dps(self, decoded_dps: dict) -> None: + """Update the trait from raw DPS data.""" + _CONVERTER.update_from_dps(self, decoded_dps) + # In the future we can register listeners and notify them here on update + # if `update_from_dps` performed any updates. diff --git a/roborock/devices/transport/mqtt_channel.py b/roborock/devices/transport/mqtt_channel.py index 498cef13..5ff0ab08 100644 --- a/roborock/devices/transport/mqtt_channel.py +++ b/roborock/devices/transport/mqtt_channel.py @@ -1,7 +1,8 @@ """Modules for communicating with specific Roborock devices over MQTT.""" +import asyncio import logging -from collections.abc import Callable +from collections.abc import AsyncGenerator, Callable from roborock.callbacks import decoder_callback from roborock.data import HomeDataDevice, RRiot, UserData @@ -73,6 +74,21 @@ async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callab dispatch = decoder_callback(self._decoder, callback, _LOGGER) return await self._mqtt_session.subscribe(self._subscribe_topic, dispatch) + async def subscribe_stream(self) -> AsyncGenerator[RoborockMessage, None]: + """Subscribe to the device's message stream. + + This is useful for processing all incoming messages in an async for loop, + when they are not necessarily associated with a specific request. + """ + message_queue: asyncio.Queue[RoborockMessage] = asyncio.Queue() + unsub = await self.subscribe(message_queue.put_nowait) + try: + while True: + message = await message_queue.get() + yield message + finally: + unsub() + async def publish(self, message: RoborockMessage) -> None: """Publish a command message. diff --git a/tests/devices/traits/b01/q10/__init__.py b/tests/devices/traits/b01/q10/__init__.py new file mode 100644 index 00000000..78977420 --- /dev/null +++ b/tests/devices/traits/b01/q10/__init__.py @@ -0,0 +1 @@ +"""Tests for the Q10 B01 traits.""" diff --git a/tests/devices/traits/b01/q10/test_status.py b/tests/devices/traits/b01/q10/test_status.py new file mode 100644 index 00000000..b2c56ee9 --- /dev/null +++ b/tests/devices/traits/b01/q10/test_status.py @@ -0,0 +1,141 @@ +"""Tests for the Q10 B01 status trait.""" + +import asyncio +import json +import pathlib +from collections.abc import AsyncGenerator +from typing import Any +from unittest.mock import AsyncMock, Mock + +import pytest + +from roborock.data.b01_q10.b01_q10_code_mappings import ( + YXDeviceCleanTask, + YXDeviceState, + YXFanLevel, +) +from roborock.devices.traits.b01.q10 import Q10PropertiesApi, create +from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol + +TEST_DATA_DIR = pathlib.Path("tests/protocols/testdata/b01_q10_protocol") + +TESTDATA_DP_STATUS_DP_CLEAN_TASK_TYPE = (TEST_DATA_DIR / "dpStatus-dpCleanTaskType.json").read_bytes() +TESTDATA_DP_REQUEST_DPS = (TEST_DATA_DIR / "dpRequetdps.json").read_bytes() + + +@pytest.fixture +def mock_channel(): + """Fixture for a mocked MQTT channel.""" + mock = AsyncMock() + return mock + + +@pytest.fixture +def message_queue() -> asyncio.Queue[RoborockMessage]: + """Fixture for a message queue used by the mock stream.""" + return asyncio.Queue() + + +@pytest.fixture +def mock_subscribe_stream(mock_channel: AsyncMock, message_queue: asyncio.Queue[RoborockMessage]) -> Mock: + """Fixture to mock the subscribe_stream method to yield from a queue.""" + + async def mock_stream() -> AsyncGenerator[RoborockMessage, None]: + while True: + yield await message_queue.get() + + mock = Mock(return_value=mock_stream()) + mock_channel.subscribe_stream = mock + return mock + + +@pytest.fixture +async def q10_api(mock_channel: AsyncMock, mock_subscribe_stream: Mock) -> AsyncGenerator[Q10PropertiesApi, None]: + """Fixture to create and manage the Q10PropertiesApi.""" + api = create(mock_channel) + await api.start() + yield api + await api.close() + + +def build_message(payload: bytes) -> RoborockMessage: + """Helper to build a RoborockMessage for testing.""" + return RoborockMessage( + protocol=RoborockMessageProtocol.RPC_RESPONSE, + payload=payload, + version=b"B01", + ) + + +async def wait_for_attribute_value(obj: Any, attribute: str, value: Any, timeout: float = 2.0) -> None: + """Wait for an attribute on an object to reach a specific value. + + This is a temporary polling solution until listeners are implemented. + """ + for _ in range(int(timeout / 0.1)): + if getattr(obj, attribute) == value: + return + await asyncio.sleep(0.1) + pytest.fail(f"Timeout waiting for {attribute} to become {value} on {obj}") + + +async def test_status_trait_streaming( + q10_api: Q10PropertiesApi, + message_queue: asyncio.Queue[RoborockMessage], +) -> None: + """Test that the StatusTrait updates its state from streaming messages.""" + # status (121) = 8 (CHARGING_STATE) + # clean_task_type (138) = 0 (IDLE) + message = build_message(TESTDATA_DP_STATUS_DP_CLEAN_TASK_TYPE) + + assert q10_api.status.status is None + assert q10_api.status.clean_task_type is None + + # Push the message into the queue + message_queue.put_nowait(message) + + # Wait for the update + await wait_for_attribute_value(q10_api.status, "status", YXDeviceState.CHARGING_STATE) + + # Verify trait attributes are updated + assert q10_api.status.status == YXDeviceState.CHARGING_STATE + assert q10_api.status.clean_task_type == YXDeviceCleanTask.IDLE + + +async def test_status_trait_refresh( + q10_api: Q10PropertiesApi, + mock_channel: AsyncMock, + message_queue: asyncio.Queue[RoborockMessage], +) -> None: + """Test that the StatusTrait sends a refresh command and updates state.""" + assert q10_api.status.battery is None + assert q10_api.status.status is None + assert q10_api.status.fan_level is None + + # Mock the response to refresh + # battery (122) = 100 + # status (121) = 8 (CHARGING_STATE) + # fun_level (123) = 2 (NORMAL) + message = build_message(TESTDATA_DP_REQUEST_DPS) + + # Send a refresh command + await q10_api.refresh() + mock_channel.publish.assert_called_once() + sent_message = mock_channel.publish.call_args[0][0] + assert sent_message.protocol == RoborockMessageProtocol.RPC_REQUEST + # Verify refresh payload + data = json.loads(sent_message.payload) + assert data + assert data.get("dps") + assert data.get("dps").get("102") == {} # REQUEST_DPS code is 102 + + # Push the response message into the queue + message_queue.put_nowait(message) + + # Wait for the update + await wait_for_attribute_value(q10_api.status, "battery", 100) + + # Verify trait attributes are updated + assert q10_api.status.battery == 100 + assert q10_api.status.status == YXDeviceState.CHARGING_STATE + assert q10_api.status.fan_level == YXFanLevel.NORMAL