From 2b910fb64ed000885999fa56befbbcb6d3fa5df2 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Wed, 1 May 2024 07:57:02 +0100 Subject: [PATCH 1/6] Make get_module return typed module --- kasa/device.py | 6 +++++- kasa/iot/iotdevice.py | 7 +++++++ kasa/module.py | 13 ++++++++++++- kasa/smart/smartdevice.py | 7 ++++--- kasa/tests/smart/features/test_brightness.py | 3 ++- kasa/tests/smart/modules/test_fan.py | 10 +++++----- 6 files changed, 35 insertions(+), 11 deletions(-) diff --git a/kasa/device.py b/kasa/device.py index 8a81030f8..d3f185dd1 100644 --- a/kasa/device.py +++ b/kasa/device.py @@ -15,7 +15,7 @@ from .exceptions import KasaException from .feature import Feature from .iotprotocol import IotProtocol -from .module import Module +from .module import Module, ModuleName, ModuleT from .protocol import BaseProtocol from .xortransport import XorTransport @@ -116,6 +116,10 @@ async def disconnect(self): def modules(self) -> Mapping[str, Module]: """Return the device modules.""" + @abstractmethod + def get_module(self, module_name: ModuleName[ModuleT]) -> ModuleT | None: + """Return the module from the device modules or None if not present.""" + @property @abstractmethod def is_on(self) -> bool: diff --git a/kasa/iot/iotdevice.py b/kasa/iot/iotdevice.py index 81b5eddac..ab2c76f1b 100755 --- a/kasa/iot/iotdevice.py +++ b/kasa/iot/iotdevice.py @@ -26,6 +26,7 @@ from ..emeterstatus import EmeterStatus from ..exceptions import KasaException from ..feature import Feature +from ..module import ModuleName, ModuleT from ..protocol import BaseProtocol from .iotmodule import IotModule from .modules import Emeter, Time @@ -201,6 +202,12 @@ def modules(self) -> dict[str, IotModule]: """Return the device modules.""" return self._modules + def get_module(self, module_name: ModuleName[ModuleT]) -> ModuleT | None: + """Return the module from the device modules or None if not present.""" + if module_name in self.modules: + return cast(ModuleT, self.modules[module_name]) + return None + def add_module(self, name: str, module: IotModule): """Register a module.""" if name in self.modules: diff --git a/kasa/module.py b/kasa/module.py index 8422eaf94..d831adbfd 100644 --- a/kasa/module.py +++ b/kasa/module.py @@ -4,7 +4,7 @@ import logging from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Generic, TypeVar from .exceptions import KasaException from .feature import Feature @@ -14,6 +14,17 @@ _LOGGER = logging.getLogger(__name__) +ModuleT = TypeVar("ModuleT", bound="Module") + + +class ModuleName(str, Generic[ModuleT]): + """Custom generic type for module names. + + At runtime this is a generic subclass of str. + """ + + __slots__ = () + class Module(ABC): """Base class implemention for all modules. diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index 04c2607be..0dd2d1007 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -16,6 +16,7 @@ from ..exceptions import AuthenticationError, DeviceError, KasaException, SmartErrorCode from ..fan import Fan from ..feature import Feature +from ..module import ModuleName, ModuleT from ..smartprotocol import SmartProtocol from .modules import ( Brightness, @@ -308,14 +309,14 @@ async def _initialize_features(self): for feat in module._module_features.values(): self._add_feature(feat) - def get_module(self, module_name) -> SmartModule | None: + def get_module(self, module_name: ModuleName[ModuleT]) -> ModuleT | None: """Return the module from the device modules or None if not present.""" if module_name in self.modules: - return self.modules[module_name] + return cast(ModuleT, self.modules[module_name]) elif self._exposes_child_modules: for child in self._children.values(): if module_name in child.modules: - return child.modules[module_name] + return cast(ModuleT, child.modules[module_name]) return None @property diff --git a/kasa/tests/smart/features/test_brightness.py b/kasa/tests/smart/features/test_brightness.py index 79df0abf9..8bf7b0d8f 100644 --- a/kasa/tests/smart/features/test_brightness.py +++ b/kasa/tests/smart/features/test_brightness.py @@ -1,6 +1,7 @@ import pytest from kasa.iot import IotDevice +from kasa.module import ModuleName from kasa.smart import SmartDevice from kasa.tests.conftest import dimmable, parametrize @@ -10,7 +11,7 @@ @brightness async def test_brightness_component(dev: SmartDevice): """Test brightness feature.""" - brightness = dev.get_module("Brightness") + brightness = dev.get_module(ModuleName("Brightness")) assert brightness assert isinstance(dev, SmartDevice) assert "brightness" in dev._components diff --git a/kasa/tests/smart/modules/test_fan.py b/kasa/tests/smart/modules/test_fan.py index 429a5d18f..26ff34313 100644 --- a/kasa/tests/smart/modules/test_fan.py +++ b/kasa/tests/smart/modules/test_fan.py @@ -1,8 +1,7 @@ -from typing import cast - import pytest from pytest_mock import MockerFixture +from kasa.module import ModuleName from kasa.smart import SmartDevice from kasa.smart.modules import FanModule from kasa.tests.device_fixtures import parametrize @@ -13,7 +12,7 @@ @fan async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture): """Test fan speed feature.""" - fan = cast(FanModule, dev.get_module("FanModule")) + fan = dev.get_module(ModuleName[FanModule]("FanModule")) assert fan level_feature = fan._module_features["fan_speed_level"] @@ -38,7 +37,7 @@ async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture): @fan async def test_sleep_mode(dev: SmartDevice, mocker: MockerFixture): """Test sleep mode feature.""" - fan = cast(FanModule, dev.get_module("FanModule")) + fan = dev.get_module(ModuleName[FanModule]("FanModule")) assert fan sleep_feature = fan._module_features["fan_sleep_mode"] assert isinstance(sleep_feature.value, bool) @@ -57,7 +56,8 @@ async def test_sleep_mode(dev: SmartDevice, mocker: MockerFixture): async def test_fan_interface(dev: SmartDevice, mocker: MockerFixture): """Test fan speed on device interface.""" assert isinstance(dev, SmartDevice) - fan = cast(FanModule, dev.get_module("FanModule")) + fan = dev.get_module(ModuleName[FanModule]("FanModule")) + assert fan device = fan._device assert device.is_fan From f7bd23b58ea530f846b3360c4e199b05617efe8c Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Wed, 1 May 2024 14:42:34 +0100 Subject: [PATCH 2/6] Update to use overloads and allow module class to be passed instead of string --- kasa/device.py | 14 ++++++++++++-- kasa/iot/iotdevice.py | 13 +++++++++++-- kasa/module.py | 12 +++++++++++- kasa/smart/smartdevice.py | 14 ++++++++++++-- kasa/tests/smart/features/test_brightness.py | 5 ++--- kasa/tests/smart/modules/test_fan.py | 6 +++--- 6 files changed, 51 insertions(+), 13 deletions(-) diff --git a/kasa/device.py b/kasa/device.py index d3f185dd1..78891d78e 100644 --- a/kasa/device.py +++ b/kasa/device.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime -from typing import Any, Mapping, Sequence +from typing import Any, Mapping, Sequence, overload from .credentials import Credentials from .device_type import DeviceType @@ -116,8 +116,18 @@ async def disconnect(self): def modules(self) -> Mapping[str, Module]: """Return the device modules.""" + @overload @abstractmethod - def get_module(self, module_name: ModuleName[ModuleT]) -> ModuleT | None: + def get_module(self, module_name: ModuleName[ModuleT]) -> ModuleT | None: ... + + @overload + @abstractmethod + def get_module(self, module_name: str) -> Module | None: ... + + @abstractmethod + def get_module( + self, module_name: ModuleName[ModuleT] | str + ) -> ModuleT | Module | None: """Return the module from the device modules or None if not present.""" @property diff --git a/kasa/iot/iotdevice.py b/kasa/iot/iotdevice.py index ab2c76f1b..dd1350287 100755 --- a/kasa/iot/iotdevice.py +++ b/kasa/iot/iotdevice.py @@ -19,7 +19,7 @@ import inspect import logging from datetime import datetime, timedelta -from typing import Any, Mapping, Sequence, cast +from typing import Any, Mapping, Sequence, cast, overload from ..device import Device, WifiNetwork from ..deviceconfig import DeviceConfig @@ -202,7 +202,16 @@ def modules(self) -> dict[str, IotModule]: """Return the device modules.""" return self._modules - def get_module(self, module_name: ModuleName[ModuleT]) -> ModuleT | None: + @overload + def get_module(self, module_name: ModuleName[ModuleT]) -> ModuleT | None: # type: ignore[overload-overlap] + ... + + @overload + def get_module(self, module_name: str) -> IotModule | None: ... + + def get_module( + self, module_name: ModuleName[ModuleT] | str + ) -> ModuleT | IotModule | None: """Return the module from the device modules or None if not present.""" if module_name in self.modules: return cast(ModuleT, self.modules[module_name]) diff --git a/kasa/module.py b/kasa/module.py index d831adbfd..708d2ae82 100644 --- a/kasa/module.py +++ b/kasa/module.py @@ -4,7 +4,11 @@ import logging from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Generic, TypeVar +from typing import ( + TYPE_CHECKING, + Generic, + TypeVar, +) from .exceptions import KasaException from .feature import Feature @@ -23,6 +27,12 @@ class ModuleName(str, Generic[ModuleT]): At runtime this is a generic subclass of str. """ + def __new__(cls, value: type[ModuleT] | str): + """Create new ModuleName instance.""" + value = value if isinstance(value, str) else value.__name__ + obj = str.__new__(cls, value) + return obj + __slots__ = () diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index 0dd2d1007..94e45439e 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -5,7 +5,7 @@ import base64 import logging from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Any, Mapping, Sequence, cast +from typing import TYPE_CHECKING, Any, Mapping, Sequence, cast, overload from ..aestransport import AesTransport from ..bulb import HSV, Bulb, BulbPreset, ColorTempRange @@ -35,6 +35,7 @@ if TYPE_CHECKING: from .smartmodule import SmartModule + # List of modules that wall switches with children, i.e. ks240 report on # the child but only work on the parent. See longer note below in _initialize_modules. # This list should be updated when creating new modules that could have the @@ -309,7 +310,16 @@ async def _initialize_features(self): for feat in module._module_features.values(): self._add_feature(feat) - def get_module(self, module_name: ModuleName[ModuleT]) -> ModuleT | None: + @overload + def get_module(self, module_name: ModuleName[ModuleT]) -> ModuleT | None: # type: ignore[overload-overlap] + ... + + @overload + def get_module(self, module_name: str) -> SmartModule | None: ... + + def get_module( + self, module_name: ModuleName[ModuleT] | str + ) -> ModuleT | SmartModule | None: """Return the module from the device modules or None if not present.""" if module_name in self.modules: return cast(ModuleT, self.modules[module_name]) diff --git a/kasa/tests/smart/features/test_brightness.py b/kasa/tests/smart/features/test_brightness.py index 8bf7b0d8f..02a396aae 100644 --- a/kasa/tests/smart/features/test_brightness.py +++ b/kasa/tests/smart/features/test_brightness.py @@ -1,7 +1,6 @@ import pytest from kasa.iot import IotDevice -from kasa.module import ModuleName from kasa.smart import SmartDevice from kasa.tests.conftest import dimmable, parametrize @@ -11,7 +10,7 @@ @brightness async def test_brightness_component(dev: SmartDevice): """Test brightness feature.""" - brightness = dev.get_module(ModuleName("Brightness")) + brightness = dev.get_module("Brightness") assert brightness assert isinstance(dev, SmartDevice) assert "brightness" in dev._components @@ -34,7 +33,7 @@ async def test_brightness_component(dev: SmartDevice): @dimmable -async def test_brightness_dimmable(dev: SmartDevice): +async def test_brightness_dimmable(dev: IotDevice): """Test brightness feature.""" assert isinstance(dev, IotDevice) assert "brightness" in dev.sys_info or bool(dev.sys_info["is_dimmable"]) diff --git a/kasa/tests/smart/modules/test_fan.py b/kasa/tests/smart/modules/test_fan.py index 26ff34313..45836a4ae 100644 --- a/kasa/tests/smart/modules/test_fan.py +++ b/kasa/tests/smart/modules/test_fan.py @@ -12,7 +12,7 @@ @fan async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture): """Test fan speed feature.""" - fan = dev.get_module(ModuleName[FanModule]("FanModule")) + fan = dev.get_module(ModuleName(FanModule)) assert fan level_feature = fan._module_features["fan_speed_level"] @@ -37,7 +37,7 @@ async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture): @fan async def test_sleep_mode(dev: SmartDevice, mocker: MockerFixture): """Test sleep mode feature.""" - fan = dev.get_module(ModuleName[FanModule]("FanModule")) + fan = dev.get_module(ModuleName(FanModule)) assert fan sleep_feature = fan._module_features["fan_sleep_mode"] assert isinstance(sleep_feature.value, bool) @@ -56,7 +56,7 @@ async def test_sleep_mode(dev: SmartDevice, mocker: MockerFixture): async def test_fan_interface(dev: SmartDevice, mocker: MockerFixture): """Test fan speed on device interface.""" assert isinstance(dev, SmartDevice) - fan = dev.get_module(ModuleName[FanModule]("FanModule")) + fan = dev.get_module(ModuleName(FanModule)) assert fan device = fan._device assert device.is_fan From 52690faa48c171b0e9c00980f290fc7117d13210 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Wed, 1 May 2024 14:46:37 +0100 Subject: [PATCH 3/6] Remove no longer necessary cast --- kasa/iot/iotdevice.py | 2 +- kasa/smart/smartdevice.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/kasa/iot/iotdevice.py b/kasa/iot/iotdevice.py index dd1350287..bfa592e90 100755 --- a/kasa/iot/iotdevice.py +++ b/kasa/iot/iotdevice.py @@ -214,7 +214,7 @@ def get_module( ) -> ModuleT | IotModule | None: """Return the module from the device modules or None if not present.""" if module_name in self.modules: - return cast(ModuleT, self.modules[module_name]) + return self.modules[module_name] return None def add_module(self, name: str, module: IotModule): diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index 94e45439e..9f719b208 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -322,11 +322,11 @@ def get_module( ) -> ModuleT | SmartModule | None: """Return the module from the device modules or None if not present.""" if module_name in self.modules: - return cast(ModuleT, self.modules[module_name]) + return self.modules[module_name] elif self._exposes_child_modules: for child in self._children.values(): if module_name in child.modules: - return cast(ModuleT, child.modules[module_name]) + return child.modules[module_name] return None @property From d171681197c5741775dd45e2d0f693ea9ba1e27d Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Fri, 3 May 2024 09:32:57 +0100 Subject: [PATCH 4/6] Remove ModuleName --- kasa/device.py | 10 ++++------ kasa/iot/iotdevice.py | 15 ++++++++++----- kasa/module.py | 16 ---------------- kasa/smart/smartdevice.py | 15 ++++++++++----- kasa/tests/smart/modules/test_fan.py | 7 +++---- 5 files changed, 27 insertions(+), 36 deletions(-) diff --git a/kasa/device.py b/kasa/device.py index 78891d78e..7d4c3022e 100644 --- a/kasa/device.py +++ b/kasa/device.py @@ -15,7 +15,7 @@ from .exceptions import KasaException from .feature import Feature from .iotprotocol import IotProtocol -from .module import Module, ModuleName, ModuleT +from .module import Module, ModuleT from .protocol import BaseProtocol from .xortransport import XorTransport @@ -118,16 +118,14 @@ def modules(self) -> Mapping[str, Module]: @overload @abstractmethod - def get_module(self, module_name: ModuleName[ModuleT]) -> ModuleT | None: ... + def get_module(self, module_type: type[ModuleT]) -> ModuleT | None: ... @overload @abstractmethod - def get_module(self, module_name: str) -> Module | None: ... + def get_module(self, module_type: str) -> Module | None: ... @abstractmethod - def get_module( - self, module_name: ModuleName[ModuleT] | str - ) -> ModuleT | Module | None: + def get_module(self, module_type: type[ModuleT] | str) -> ModuleT | Module | None: """Return the module from the device modules or None if not present.""" @property diff --git a/kasa/iot/iotdevice.py b/kasa/iot/iotdevice.py index bfa592e90..e69de80cd 100755 --- a/kasa/iot/iotdevice.py +++ b/kasa/iot/iotdevice.py @@ -26,7 +26,7 @@ from ..emeterstatus import EmeterStatus from ..exceptions import KasaException from ..feature import Feature -from ..module import ModuleName, ModuleT +from ..module import ModuleT from ..protocol import BaseProtocol from .iotmodule import IotModule from .modules import Emeter, Time @@ -203,16 +203,21 @@ def modules(self) -> dict[str, IotModule]: return self._modules @overload - def get_module(self, module_name: ModuleName[ModuleT]) -> ModuleT | None: # type: ignore[overload-overlap] - ... + def get_module(self, module_type: type[ModuleT]) -> ModuleT | None: ... @overload - def get_module(self, module_name: str) -> IotModule | None: ... + def get_module(self, module_type: str) -> IotModule | None: ... def get_module( - self, module_name: ModuleName[ModuleT] | str + self, module_type: type[ModuleT] | str ) -> ModuleT | IotModule | None: """Return the module from the device modules or None if not present.""" + if isinstance(module_type, str): + module_name = module_type.lower() + elif issubclass(module_type, IotModule): + module_name = module_type.__name__.lower() + else: + return None if module_name in self.modules: return self.modules[module_name] return None diff --git a/kasa/module.py b/kasa/module.py index 708d2ae82..5b6354a9c 100644 --- a/kasa/module.py +++ b/kasa/module.py @@ -6,7 +6,6 @@ from abc import ABC, abstractmethod from typing import ( TYPE_CHECKING, - Generic, TypeVar, ) @@ -21,21 +20,6 @@ ModuleT = TypeVar("ModuleT", bound="Module") -class ModuleName(str, Generic[ModuleT]): - """Custom generic type for module names. - - At runtime this is a generic subclass of str. - """ - - def __new__(cls, value: type[ModuleT] | str): - """Create new ModuleName instance.""" - value = value if isinstance(value, str) else value.__name__ - obj = str.__new__(cls, value) - return obj - - __slots__ = () - - class Module(ABC): """Base class implemention for all modules. diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index 9f719b208..0b07cabcc 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -16,7 +16,7 @@ from ..exceptions import AuthenticationError, DeviceError, KasaException, SmartErrorCode from ..fan import Fan from ..feature import Feature -from ..module import ModuleName, ModuleT +from ..module import ModuleT from ..smartprotocol import SmartProtocol from .modules import ( Brightness, @@ -311,16 +311,21 @@ async def _initialize_features(self): self._add_feature(feat) @overload - def get_module(self, module_name: ModuleName[ModuleT]) -> ModuleT | None: # type: ignore[overload-overlap] - ... + def get_module(self, module_type: type[ModuleT]) -> ModuleT | None: ... @overload - def get_module(self, module_name: str) -> SmartModule | None: ... + def get_module(self, module_type: str) -> SmartModule | None: ... def get_module( - self, module_name: ModuleName[ModuleT] | str + self, module_type: type[ModuleT] | str ) -> ModuleT | SmartModule | None: """Return the module from the device modules or None if not present.""" + if isinstance(module_type, str): + module_name = module_type + elif issubclass(module_type, SmartModule): + module_name = module_type.__name__ + else: + return None if module_name in self.modules: return self.modules[module_name] elif self._exposes_child_modules: diff --git a/kasa/tests/smart/modules/test_fan.py b/kasa/tests/smart/modules/test_fan.py index 45836a4ae..372459510 100644 --- a/kasa/tests/smart/modules/test_fan.py +++ b/kasa/tests/smart/modules/test_fan.py @@ -1,7 +1,6 @@ import pytest from pytest_mock import MockerFixture -from kasa.module import ModuleName from kasa.smart import SmartDevice from kasa.smart.modules import FanModule from kasa.tests.device_fixtures import parametrize @@ -12,7 +11,7 @@ @fan async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture): """Test fan speed feature.""" - fan = dev.get_module(ModuleName(FanModule)) + fan = dev.get_module(FanModule) assert fan level_feature = fan._module_features["fan_speed_level"] @@ -37,7 +36,7 @@ async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture): @fan async def test_sleep_mode(dev: SmartDevice, mocker: MockerFixture): """Test sleep mode feature.""" - fan = dev.get_module(ModuleName(FanModule)) + fan = dev.get_module(FanModule) assert fan sleep_feature = fan._module_features["fan_sleep_mode"] assert isinstance(sleep_feature.value, bool) @@ -56,7 +55,7 @@ async def test_sleep_mode(dev: SmartDevice, mocker: MockerFixture): async def test_fan_interface(dev: SmartDevice, mocker: MockerFixture): """Test fan speed on device interface.""" assert isinstance(dev, SmartDevice) - fan = dev.get_module(ModuleName(FanModule)) + fan = dev.get_module(FanModule) assert fan device = fan._device assert device.is_fan From 17621b460e9a86fbc394d5856b4b5f4c0fdaa621 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Fri, 3 May 2024 09:43:25 +0100 Subject: [PATCH 5/6] Fix SmartModule import --- kasa/smart/smartdevice.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index 7f0fbeb80..98c5f7efe 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -5,7 +5,7 @@ import base64 import logging from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Any, Mapping, Sequence, cast, overload +from typing import Any, Mapping, Sequence, cast, overload from ..aestransport import AesTransport from ..bulb import HSV, Bulb, BulbPreset, ColorTempRange @@ -29,12 +29,10 @@ Firmware, TimeModule, ) +from .smartmodule import SmartModule _LOGGER = logging.getLogger(__name__) -if TYPE_CHECKING: - from .smartmodule import SmartModule - # List of modules that wall switches with children, i.e. ks240 report on # the child but only work on the parent. See longer note below in _initialize_modules. From bb1492bed229fe18204a0c316c545b7cd25f18c6 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Fri, 3 May 2024 15:24:21 +0100 Subject: [PATCH 6/6] Add tests --- kasa/tests/test_iotdevice.py | 29 ++++++++++++++++++++++++++++- kasa/tests/test_smartdevice.py | 22 +++++++++++++++++++++- 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/kasa/tests/test_iotdevice.py b/kasa/tests/test_iotdevice.py index 4c5d5126a..b4d56291e 100644 --- a/kasa/tests/test_iotdevice.py +++ b/kasa/tests/test_iotdevice.py @@ -19,7 +19,7 @@ from kasa import KasaException from kasa.iot import IotDevice -from .conftest import handle_turn_on, turn_on +from .conftest import get_device_for_fixture_protocol, handle_turn_on, turn_on from .device_fixtures import device_iot, has_emeter_iot, no_emeter_iot from .fakeprotocol_iot import FakeIotProtocol @@ -258,3 +258,30 @@ async def test_modules_not_supported(dev: IotDevice): await dev.update() for module in dev.modules.values(): assert module.is_supported is not None + + +async def test_get_modules(): + """Test get_modules for child and parent modules.""" + dummy_device = await get_device_for_fixture_protocol( + "HS100(US)_2.0_1.5.6.json", "IOT" + ) + from kasa.iot.modules import Cloud + from kasa.smart.modules import CloudModule + + # Modules on device + module = dummy_device.get_module("Cloud") + assert module + assert module._device == dummy_device + assert isinstance(module, Cloud) + + module = dummy_device.get_module(Cloud) + assert module + assert module._device == dummy_device + assert isinstance(module, Cloud) + + # Invalid modules + module = dummy_device.get_module("DummyModule") + assert module is None + + module = dummy_device.get_module(CloudModule) + assert module is None diff --git a/kasa/tests/test_smartdevice.py b/kasa/tests/test_smartdevice.py index 476a37ae5..bb2f81bf0 100644 --- a/kasa/tests/test_smartdevice.py +++ b/kasa/tests/test_smartdevice.py @@ -122,23 +122,43 @@ async def test_update_module_queries(dev: SmartDevice, mocker: MockerFixture): spies[device].assert_not_called() -async def test_get_modules(mocker): +async def test_get_modules(): """Test get_modules for child and parent modules.""" dummy_device = await get_device_for_fixture_protocol( "KS240(US)_1.0_1.0.5.json", "SMART" ) + from kasa.iot.modules import AmbientLight + from kasa.smart.modules import CloudModule, FanModule + + # Modules on device module = dummy_device.get_module("CloudModule") assert module assert module._device == dummy_device + assert isinstance(module, CloudModule) + module = dummy_device.get_module(CloudModule) + assert module + assert module._device == dummy_device + assert isinstance(module, CloudModule) + + # Modules on child module = dummy_device.get_module("FanModule") assert module assert module._device != dummy_device assert module._device._parent == dummy_device + module = dummy_device.get_module(FanModule) + assert module + assert module._device != dummy_device + assert module._device._parent == dummy_device + + # Invalid modules module = dummy_device.get_module("DummyModule") assert module is None + module = dummy_device.get_module(AmbientLight) + assert module is None + @bulb_smart async def test_smartdevice_brightness(dev: SmartDevice):