diff --git a/kasa/smart/modules/childdevicemodule.py b/kasa/smart/modules/childdevicemodule.py index 62e024d0c..9f4710b2d 100644 --- a/kasa/smart/modules/childdevicemodule.py +++ b/kasa/smart/modules/childdevicemodule.py @@ -1,5 +1,4 @@ """Implementation for child devices.""" -from typing import Dict from ..smartmodule import SmartModule @@ -8,12 +7,4 @@ class ChildDeviceModule(SmartModule): """Implementation for child devices.""" REQUIRED_COMPONENT = "child_device" - - def query(self) -> Dict: - """Query to execute during the update cycle.""" - # TODO: There is no need to fetch the component list every time, - # so this should be optimized only for the init. - return { - "get_child_device_list": None, - "get_child_device_component_list": None, - } + QUERY_GETTER_NAME = "get_child_device_list" diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index 8b0236c37..3cbd12f97 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -41,10 +41,18 @@ def __init__( self.modules: Dict[str, "SmartModule"] = {} self._parent: Optional["SmartDevice"] = None self._children: Mapping[str, "SmartDevice"] = {} + self._last_update = {} async def _initialize_children(self): """Initialize children for power strips.""" - children = self.internal_state["child_info"]["child_device_list"] + child_info_query = { + "get_child_device_component_list": None, + "get_child_device_list": None, + } + resp = await self.protocol.query(child_info_query) + self.internal_state.update(resp) + + children = self.internal_state["get_child_device_list"]["child_device_list"] children_components = { child["device_id"]: { comp["id"]: int(comp["ver_code"]) for comp in child["component_list"] @@ -88,13 +96,30 @@ def _try_get_response(self, responses: dict, request: str, default=None) -> dict ) async def _negotiate(self): - resp = await self.protocol.query("component_nego") + """Perform initialization. + + We fetch the device info and the available components as early as possible. + If the device reports supporting child devices, they are also initialized. + """ + initial_query = {"component_nego": None, "get_device_info": None} + resp = await self.protocol.query(initial_query) + + # Save the initial state to allow modules access the device info already + # during the initialization, which is necessary as some information like the + # supported color temperature range is contained within the response. + self._last_update.update(resp) + self._info = self._try_get_response(resp, "get_device_info") + + # Create our internal presentation of available components self._components_raw = resp["component_nego"] self._components = { comp["id"]: int(comp["ver_code"]) for comp in self._components_raw["component_list"] } + if "child_device" in self._components and not self.children: + await self._initialize_children() + async def update(self, update_children: bool = True): """Update the device.""" if self.credentials is None and self.credentials_hash is None: @@ -110,20 +135,10 @@ async def update(self, update_children: bool = True): for module in self.modules.values(): req.update(module.query()) - resp = await self.protocol.query(req) + self._last_update = resp = await self.protocol.query(req) self._info = self._try_get_response(resp, "get_device_info") - - self._last_update = { - "components": self._components_raw, - **resp, - "child_info": self._try_get_response(resp, "get_child_device_list", {}), - } - - if child_info := self._last_update.get("child_info"): - if not self.children: - await self._initialize_children() - + if child_info := self._try_get_response(resp, "get_child_device_list", {}): # TODO: we don't currently perform queries on children based on modules, # but just update the information that is returned in the main query. for info in child_info["child_device_list"]: diff --git a/kasa/tests/test_childdevice.py b/kasa/tests/test_childdevice.py index 07baf598b..97d3fd376 100644 --- a/kasa/tests/test_childdevice.py +++ b/kasa/tests/test_childdevice.py @@ -24,7 +24,7 @@ def test_childdevice_init(dev, dummy_protocol, mocker): @strip_smart async def test_childdevice_update(dev, dummy_protocol, mocker): """Test that parent update updates children.""" - child_info = dev._last_update["child_info"] + child_info = dev.internal_state["get_child_device_list"] child_list = child_info["child_device_list"] assert len(dev.children) == child_info["sum"] diff --git a/kasa/tests/test_smartdevice.py b/kasa/tests/test_smartdevice.py index a9871fa29..d7b1cca9d 100644 --- a/kasa/tests/test_smartdevice.py +++ b/kasa/tests/test_smartdevice.py @@ -1,8 +1,9 @@ """Tests for SMART devices.""" import logging -from unittest.mock import patch +from typing import Any, Dict -import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 +import pytest +from pytest_mock import MockerFixture from kasa import KasaException from kasa.exceptions import SmartErrorCode @@ -25,13 +26,79 @@ async def test_try_get_response(dev: SmartDevice, caplog): @device_smart -async def test_update_no_device_info(dev: SmartDevice): +async def test_update_no_device_info(dev: SmartDevice, mocker: MockerFixture): mock_response: dict = { "get_device_usage": {}, "get_device_time": {}, } msg = f"get_device_info not found in {mock_response} for device 127.0.0.123" - with patch.object(dev.protocol, "query", return_value=mock_response), pytest.raises( - KasaException, match=msg - ): + with mocker.patch.object( + dev.protocol, "query", return_value=mock_response + ), pytest.raises(KasaException, match=msg): await dev.update() + + +@device_smart +async def test_initial_update(dev: SmartDevice, mocker: MockerFixture): + """Test the initial update cycle.""" + # As the fixture data is already initialized, we reset the state for testing + dev._components_raw = None + dev._features = {} + + negotiate = mocker.spy(dev, "_negotiate") + initialize_modules = mocker.spy(dev, "_initialize_modules") + initialize_features = mocker.spy(dev, "_initialize_features") + + # Perform two updates and verify that initialization is only done once + await dev.update() + await dev.update() + + negotiate.assert_called_once() + assert dev._components_raw is not None + initialize_modules.assert_called_once() + assert dev.modules + initialize_features.assert_called_once() + assert dev.features + + +@device_smart +async def test_negotiate(dev: SmartDevice, mocker: MockerFixture): + """Test that the initial negotiation performs expected steps.""" + # As the fixture data is already initialized, we reset the state for testing + dev._components_raw = None + dev._children = {} + + query = mocker.spy(dev.protocol, "query") + initialize_children = mocker.spy(dev, "_initialize_children") + await dev._negotiate() + + # Check that we got the initial negotiation call + query.assert_any_call({"component_nego": None, "get_device_info": None}) + assert dev._components_raw + + # Check the children are created, if device supports them + if "child_device" in dev._components: + initialize_children.assert_called_once() + query.assert_any_call( + { + "get_child_device_component_list": None, + "get_child_device_list": None, + } + ) + assert len(dev.children) == dev.internal_state["get_child_device_list"]["sum"] + + +@device_smart +async def test_update_module_queries(dev: SmartDevice, mocker: MockerFixture): + """Test that the regular update uses queries from all supported modules.""" + query = mocker.spy(dev.protocol, "query") + + # We need to have some modules initialized by now + assert dev.modules + + await dev.update() + full_query: Dict[str, Any] = {} + for mod in dev.modules.values(): + full_query |= mod.query() + + query.assert_called_with(full_query)