diff --git a/kasa/cli.py b/kasa/cli.py index d1b40a9e8..dce38def1 100755 --- a/kasa/cli.py +++ b/kasa/cli.py @@ -1258,5 +1258,51 @@ async def feature(dev: Device, child: str, name: str, value): return response +@cli.group(invoke_without_command=True) +@pass_dev +@click.pass_context +async def firmware(ctx: click.Context, dev: Device): + """Firmware update.""" + if ctx.invoked_subcommand is None: + return await ctx.invoke(firmware_info) + + +@firmware.command(name="info") +@pass_dev +@click.pass_context +async def firmware_info(ctx: click.Context, dev: Device): + """Return firmware information.""" + if not (firmware := dev.modules.get(Module.Firmware)): + echo("This device does not support firmware info.") + return + + res = await firmware.check_for_updates() + if res.update_available: + echo("[green bold]Update available![/green bold]") + echo(f"Current firmware: {res.current_version}") + echo(f"Version {res.available_version} released at {res.release_date}") + echo("Release notes") + echo("=============") + echo(res.release_notes) + echo("=============") + else: + echo("[red bold]No updates available.[/red bold]") + + +@firmware.command(name="update") +@pass_dev +@click.pass_context +async def firmware_update(ctx: click.Context, dev: Device): + """Perform firmware update.""" + await ctx.invoke(firmware_info) + click.confirm("Are you sure you want to upgrade the firmware?", abort=True) + + async def progress(x): + echo(f"Progress: {x}") + + echo("Going to update %s", dev) + await dev.modules[Module.Firmware].update_firmware(progress_cb=progress) # type: ignore + + if __name__ == "__main__": cli() diff --git a/kasa/interfaces/__init__.py b/kasa/interfaces/__init__.py index d8d089c5c..5328289b3 100644 --- a/kasa/interfaces/__init__.py +++ b/kasa/interfaces/__init__.py @@ -1,12 +1,14 @@ """Package for interfaces.""" from .fan import Fan +from .firmware import Firmware from .led import Led from .light import Light, LightPreset from .lighteffect import LightEffect __all__ = [ "Fan", + "Firmware", "Led", "Light", "LightEffect", diff --git a/kasa/interfaces/firmware.py b/kasa/interfaces/firmware.py new file mode 100644 index 000000000..705b1e35d --- /dev/null +++ b/kasa/interfaces/firmware.py @@ -0,0 +1,51 @@ +"""Interface for firmware updates.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from datetime import date +from typing import Callable, Coroutine + +from ..module import Module + +UpdateResult = bool + + +class FirmwareDownloadState(ABC): + """Download state.""" + + status: int + progress: int + reboot_time: int + upgrade_time: int + auto_upgrade: bool + + +@dataclass +class FirmwareUpdateInfo: + """Update info status object.""" + + update_available: bool | None = None + current_version: str | None = None + available_version: str | None = None + release_date: date | None = None + release_notes: str | None = None + + +class Firmware(Module, ABC): + """Interface to access firmware information and perform updates.""" + + @abstractmethod + async def update_firmware( + self, *, progress_cb: Callable[[FirmwareDownloadState], Coroutine] | None = None + ) -> UpdateResult: + """Perform firmware update. + + This "blocks" until the update process has finished. + You can set *progress_cb* to get progress updates. + """ + + @abstractmethod + async def check_for_updates(self) -> FirmwareUpdateInfo: + """Return firmware update information.""" diff --git a/kasa/iot/modules/cloud.py b/kasa/iot/modules/cloud.py index 5022a68e7..4effc0c0a 100644 --- a/kasa/iot/modules/cloud.py +++ b/kasa/iot/modules/cloud.py @@ -1,10 +1,28 @@ """Cloud module implementation.""" -from pydantic.v1 import BaseModel +from __future__ import annotations + +import logging +from datetime import date +from typing import Callable, Coroutine, Optional + +from pydantic.v1 import BaseModel, Field, validator from ...feature import Feature +from ...interfaces.firmware import ( + Firmware, + UpdateResult, +) +from ...interfaces.firmware import ( + FirmwareDownloadState as FirmwareDownloadStateInterface, +) +from ...interfaces.firmware import ( + FirmwareUpdateInfo as FirmwareUpdateInfoInterface, +) from ..iotmodule import IotModule +_LOGGER = logging.getLogger(__name__) + class CloudInfo(BaseModel): """Container for cloud settings.""" @@ -21,7 +39,31 @@ class CloudInfo(BaseModel): username: str -class Cloud(IotModule): +class FirmwareUpdate(BaseModel): + """Update info status object.""" + + status: int = Field(alias="fwType") + version: Optional[str] = Field(alias="fwVer", default=None) # noqa: UP007 + release_date: Optional[date] = Field(alias="fwReleaseDate", default=None) # noqa: UP007 + release_notes: Optional[str] = Field(alias="fwReleaseLog", default=None) # noqa: UP007 + url: Optional[str] = Field(alias="fwUrl", default=None) # noqa: UP007 + + @validator("release_date", pre=True) + def _release_date_optional(cls, v): + if not v: + return None + + return v + + @property + def update_available(self): + """Return True if update available.""" + if self.status != 0: + return True + return False + + +class Cloud(IotModule, Firmware): """Module implementing support for cloud services.""" def __init__(self, device, module): @@ -46,27 +88,86 @@ def is_connected(self) -> bool: def query(self): """Request cloud connectivity info.""" - return self.query_for_command("get_info") + req = self.query_for_command("get_info") + + # TODO: this is problematic, as it will fail the whole query on some + # devices if they are not connected to the internet + + # The following causes a recursion error as self.is_connected + # accesses self.data which calls query. Also get_available_firmwares is async + # if self._module in self._device._last_update and self.is_connected: + # req = merge(req, self.get_available_firmwares()) + + return req @property def info(self) -> CloudInfo: """Return information about the cloud connectivity.""" return CloudInfo.parse_obj(self.data["get_info"]) - def get_available_firmwares(self): + async def get_available_firmwares(self): """Return list of available firmwares.""" - return self.query_for_command("get_intl_fw_list") + return await self.call("get_intl_fw_list") + + async def get_firmware_update(self) -> FirmwareUpdate: + """Return firmware update information.""" + try: + available_fws = (await self.get_available_firmwares()).get("fw_list", []) + if not available_fws: + return FirmwareUpdate(fwType=0) + if len(available_fws) > 1: + _LOGGER.warning( + "Got more than one update, using the first one: %s", available_fws + ) + return FirmwareUpdate.parse_obj(next(iter(available_fws))) + except Exception as ex: + _LOGGER.warning("Unable to check for firmware update: %s", ex) + return FirmwareUpdate(fwType=0) - def set_server(self, url: str): + async def set_server(self, url: str): """Set the update server URL.""" - return self.query_for_command("set_server_url", {"server": url}) + return await self.call("set_server_url", {"server": url}) - def connect(self, username: str, password: str): + async def connect(self, username: str, password: str): """Login to the cloud using given information.""" - return self.query_for_command( - "bind", {"username": username, "password": password} - ) + return await self.call("bind", {"username": username, "password": password}) - def disconnect(self): + async def disconnect(self): """Disconnect from the cloud.""" - return self.query_for_command("unbind") + return await self.call("unbind") + + async def update_firmware( + self, + *, + progress_cb: Callable[[FirmwareDownloadStateInterface], Coroutine] + | None = None, + ) -> UpdateResult: + """Perform firmware update.""" + raise NotImplementedError + i = 0 + import asyncio + + while i < 100: + await asyncio.sleep(1) + if progress_cb is not None: + await progress_cb(i) + i += 10 + + return UpdateResult("") + + async def check_for_updates(self) -> FirmwareUpdateInfoInterface: + """Return firmware update information.""" + # TODO: naming of the common firmware API methods + raise NotImplementedError + + async def get_update_state(self) -> FirmwareUpdateInfoInterface: + """Return firmware update information.""" + fw = await self.get_firmware_update() + + return FirmwareUpdateInfoInterface( + update_available=fw.update_available, + current_version=self._device.hw_info.get("sw_ver"), + available_version=fw.version, + release_date=fw.release_date, + release_notes=fw.release_notes, + ) diff --git a/kasa/module.py b/kasa/module.py index 9b541ce04..61258f697 100644 --- a/kasa/module.py +++ b/kasa/module.py @@ -36,6 +36,7 @@ class Module(ABC): LightEffect: Final[ModuleName[interfaces.LightEffect]] = ModuleName("LightEffect") Led: Final[ModuleName[interfaces.Led]] = ModuleName("Led") Light: Final[ModuleName[interfaces.Light]] = ModuleName("Light") + Firmware: Final[ModuleName[interfaces.Firmware]] = ModuleName("Firmware") # IOT only Modules IotAmbientLight: Final[ModuleName[iot.AmbientLight]] = ModuleName("ambient") @@ -63,7 +64,6 @@ class Module(ABC): DeviceModule: Final[ModuleName[smart.DeviceModule]] = ModuleName("DeviceModule") Energy: Final[ModuleName[smart.Energy]] = ModuleName("Energy") Fan: Final[ModuleName[smart.Fan]] = ModuleName("Fan") - Firmware: Final[ModuleName[smart.Firmware]] = ModuleName("Firmware") FrostProtection: Final[ModuleName[smart.FrostProtection]] = ModuleName( "FrostProtection" ) diff --git a/kasa/smart/modules/firmware.py b/kasa/smart/modules/firmware.py index 430515e4b..cdad5f6a5 100644 --- a/kasa/smart/modules/firmware.py +++ b/kasa/smart/modules/firmware.py @@ -14,6 +14,12 @@ from ...exceptions import SmartErrorCode from ...feature import Feature +from ...interfaces import Firmware as FirmwareInterface +from ...interfaces.firmware import ( + FirmwareDownloadState as FirmwareDownloadStateInterface, +) +from ...interfaces.firmware import FirmwareUpdateInfo as FirmwareUpdateInfoInterface +from ...interfaces.firmware import UpdateResult from ..smartmodule import SmartModule if TYPE_CHECKING: @@ -36,7 +42,7 @@ class DownloadState(BaseModel): auto_upgrade: bool -class UpdateInfo(BaseModel): +class FirmwareUpdateInfo(BaseModel): """Update info status object.""" status: int = Field(alias="type") @@ -62,7 +68,7 @@ def update_available(self): return False -class Firmware(SmartModule): +class Firmware(SmartModule, FirmwareInterface): """Implementation of firmware module.""" REQUIRED_COMPONENT = "firmware" @@ -136,9 +142,9 @@ def firmware_update_info(self): fw = self.data.get("get_latest_fw") or self.data if not self._device.is_cloud_connected or isinstance(fw, SmartErrorCode): # Error in response, probably disconnected from the cloud. - return UpdateInfo(type=0, need_to_upgrade=False) + return FirmwareUpdateInfo(type=0, need_to_upgrade=False) - return UpdateInfo.parse_obj(fw) + return FirmwareUpdateInfo.parse_obj(fw) @property def update_available(self) -> bool | None: @@ -214,3 +220,24 @@ async def set_auto_update_enabled(self, enabled: bool): """Change autoupdate setting.""" data = {**self.data["get_auto_update_info"], "enable": enabled} await self.call("set_auto_update_info", data) + + async def update_firmware( + self, + *, + progress_cb: Callable[[FirmwareDownloadStateInterface], Coroutine] + | None = None, + ) -> UpdateResult: + """Update the firmware.""" + return await self.update(progress_cb) + + async def check_for_updates(self) -> FirmwareUpdateInfoInterface: + """Return firmware update information.""" + # TODO: naming of the common firmware API methods + info = self.firmware_update_info + return FirmwareUpdateInfoInterface( + current_version=self.current_firmware, + update_available=info.update_available, + available_version=info.version, + release_date=info.release_date, + release_notes=info.release_notes, + ) diff --git a/kasa/tests/smart/modules/test_firmware.py b/kasa/tests/smart/modules/test_firmware.py index b592041f4..aa71099fd 100644 --- a/kasa/tests/smart/modules/test_firmware.py +++ b/kasa/tests/smart/modules/test_firmware.py @@ -8,7 +8,7 @@ from kasa import Module from kasa.smart import SmartDevice -from kasa.smart.modules.firmware import DownloadState +from kasa.smart.modules.firmware import DownloadState, Firmware from kasa.tests.device_fixtures import parametrize firmware = parametrize( @@ -33,7 +33,8 @@ async def test_firmware_features( """Test light effect.""" fw = dev.modules.get(Module.Firmware) assert fw - + if not isinstance(fw, Firmware): # TODO needed while common interface still TBD + return if not dev.is_cloud_connected: pytest.skip("Device is not cloud connected, skipping test") @@ -53,6 +54,8 @@ async def test_update_available_without_cloud(dev: SmartDevice): """Test that update_available returns None when disconnected.""" fw = dev.modules.get(Module.Firmware) assert fw + if not isinstance(fw, Firmware): # TODO needed while common interface still TBD + return if dev.is_cloud_connected: assert isinstance(fw.update_available, bool) @@ -69,6 +72,8 @@ async def test_firmware_update( fw = dev.modules.get(Module.Firmware) assert fw + if not isinstance(fw, Firmware): # TODO needed while common interface still TBD + return upgrade_time = 5 extras = {"reboot_time": 5, "upgrade_time": upgrade_time, "auto_upgrade": False}