diff --git a/kasa/aestransport.py b/kasa/aestransport.py index bc1eacff7..bbcc511f1 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -3,7 +3,7 @@ Based on the work of https://github.com/petretiandrea/plugp100 under compatible GNU GPL3 license. """ - +import asyncio import base64 import hashlib import logging @@ -39,6 +39,7 @@ ONE_DAY_SECONDS = 86400 SESSION_EXPIRE_BUFFER_SECONDS = 60 * 20 +BACKOFF_SECONDS_AFTER_LOGIN_ERROR = 1 def _sha1(payload: bytes) -> str: @@ -184,8 +185,24 @@ async def send_secure_passthrough(self, request: str) -> Dict[str, Any]: assert self._encryption_session is not None raw_response: str = resp_dict["result"]["response"] - response = self._encryption_session.decrypt(raw_response.encode()) - return json_loads(response) # type: ignore[return-value] + + try: + response = self._encryption_session.decrypt(raw_response.encode()) + ret_val = json_loads(response) + except Exception as ex: + try: + ret_val = json_loads(raw_response) + _LOGGER.debug( + "Received unencrypted response over secure passthrough from %s", + self._host, + ) + except Exception: + raise SmartDeviceException( + f"Unable to decrypt response from {self._host}, " + + f"error: {ex}, response: {raw_response}", + ex, + ) from ex + return ret_val # type: ignore[return-value] async def perform_login(self): """Login to the device.""" @@ -199,6 +216,7 @@ async def perform_login(self): self._default_credentials = get_default_credentials( DEFAULT_CREDENTIALS["TAPO"] ) + await asyncio.sleep(BACKOFF_SECONDS_AFTER_LOGIN_ERROR) await self.perform_handshake() await self.try_login(self._get_login_params(self._default_credentials)) _LOGGER.debug( diff --git a/kasa/smart/smartdevice.py b/kasa/smart/smartdevice.py index ca9ed63be..0929c418d 100644 --- a/kasa/smart/smartdevice.py +++ b/kasa/smart/smartdevice.py @@ -69,7 +69,7 @@ async def update(self, update_children: bool = True): resp = await self.protocol.query("component_nego") self._components_raw = resp["component_nego"] self._components = { - comp["id"]: comp["ver_code"] + comp["id"]: int(comp["ver_code"]) for comp in self._components_raw["component_list"] } await self._initialize_modules() @@ -86,9 +86,14 @@ async def update(self, update_children: bool = True): "get_current_power": None, } + if self._components["device"] >= 2: + extra_reqs = { + **extra_reqs, + "get_device_usage": None, + } + req = { "get_device_info": None, - "get_device_usage": None, "get_device_time": None, **extra_reqs, } @@ -96,8 +101,9 @@ async def update(self, update_children: bool = True): resp = await self.protocol.query(req) self._info = resp["get_device_info"] - self._usage = resp["get_device_usage"] self._time = resp["get_device_time"] + # Device usage is not available on older firmware versions + self._usage = resp.get("get_device_usage", {}) # Emeter is not always available, but we set them still for now. self._energy = resp.get("get_energy_usage", {}) self._emeter = resp.get("get_current_power", {}) diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index 74f2275d2..f61bac206 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -82,6 +82,7 @@ async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise ex + await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT) continue except TimeoutException as ex: await self._transport.reset() diff --git a/kasa/tests/test_aestransport.py b/kasa/tests/test_aestransport.py index a692ba9be..51f1e3d90 100644 --- a/kasa/tests/test_aestransport.py +++ b/kasa/tests/test_aestransport.py @@ -1,5 +1,6 @@ import base64 import json +import logging import random import string import time @@ -180,6 +181,67 @@ async def test_send(mocker, status_code, error_code, inner_error_code, expectati assert "result" in res +async def test_unencrypted_response(mocker, caplog): + host = "127.0.0.1" + mock_aes_device = MockAesDevice(host, 200, 0, 0, do_not_encrypt_response=True) + mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post) + + transport = AesTransport( + config=DeviceConfig(host, credentials=Credentials("foo", "bar")) + ) + transport._state = TransportState.ESTABLISHED + transport._session_expire_at = time.time() + 86400 + transport._encryption_session = mock_aes_device.encryption_session + transport._token_url = transport._app_url.with_query( + f"token={mock_aes_device.token}" + ) + + request = { + "method": "get_device_info", + "params": None, + "request_time_milis": round(time.time() * 1000), + "requestID": 1, + "terminal_uuid": "foobar", + } + caplog.set_level(logging.DEBUG) + res = await transport.send(json_dumps(request)) + assert "result" in res + assert ( + "Received unencrypted response over secure passthrough from 127.0.0.1" + in caplog.text + ) + + +async def test_unencrypted_response_invalid_json(mocker, caplog): + host = "127.0.0.1" + mock_aes_device = MockAesDevice( + host, 200, 0, 0, do_not_encrypt_response=True, send_response=b"Foobar" + ) + mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post) + + transport = AesTransport( + config=DeviceConfig(host, credentials=Credentials("foo", "bar")) + ) + transport._state = TransportState.ESTABLISHED + transport._session_expire_at = time.time() + 86400 + transport._encryption_session = mock_aes_device.encryption_session + transport._token_url = transport._app_url.with_query( + f"token={mock_aes_device.token}" + ) + + request = { + "method": "get_device_info", + "params": None, + "request_time_milis": round(time.time() * 1000), + "requestID": 1, + "terminal_uuid": "foobar", + } + caplog.set_level(logging.DEBUG) + msg = f"Unable to decrypt response from {host}, error: Incorrect padding, response: Foobar" + with pytest.raises(SmartDeviceException, match=msg): + await transport.send(json_dumps(request)) + + ERRORS = [e for e in SmartErrorCode if e != 0] @@ -233,15 +295,28 @@ async def __aexit__(self, exc_t, exc_v, exc_tb): pass async def read(self): - return json_dumps(self._json).encode() + if isinstance(self._json, dict): + return json_dumps(self._json).encode() + return self._json encryption_session = AesEncyptionSession(KEY_IV[:16], KEY_IV[16:]) - def __init__(self, host, status_code=200, error_code=0, inner_error_code=0): + def __init__( + self, + host, + status_code=200, + error_code=0, + inner_error_code=0, + *, + do_not_encrypt_response=False, + send_response=None, + ): self.host = host self.status_code = status_code self.error_code = error_code self._inner_error_code = inner_error_code + self.do_not_encrypt_response = do_not_encrypt_response + self.send_response = send_response self.http_client = HttpClient(DeviceConfig(self.host)) self.inner_call_count = 0 self.token = "".join(random.choices(string.ascii_uppercase, k=32)) # noqa: S311 @@ -289,13 +364,15 @@ async def _return_secure_passthrough_response(self, url: URL, json: Dict[str, An decrypted_request_dict = json_loads(decrypted_request) decrypted_response = await self._post(url, decrypted_request_dict) async with decrypted_response: - response_data = await decrypted_response.read() - decrypted_response_dict = json_loads(response_data.decode()) - encrypted_response = self.encryption_session.encrypt( - json_dumps(decrypted_response_dict).encode() + decrypted_response_data = await decrypted_response.read() + encrypted_response = self.encryption_session.encrypt(decrypted_response_data) + response = ( + decrypted_response_data + if self.do_not_encrypt_response + else encrypted_response ) result = { - "result": {"response": encrypted_response.decode()}, + "result": {"response": response.decode()}, "error_code": self.error_code, } return self._mock_response(self.status_code, result) @@ -310,5 +387,6 @@ async def _return_login_response(self, url: URL, json: Dict[str, Any]): async def _return_send_response(self, url: URL, json: Dict[str, Any]): result = {"result": {"method": None}, "error_code": self.inner_error_code} + response = self.send_response if self.send_response else result self.inner_call_count += 1 - return self._mock_response(self.status_code, result) + return self._mock_response(self.status_code, response)