From 36d866f7ee88dffe6644e5fd6643965029a8d611 Mon Sep 17 00:00:00 2001 From: Teemu Rytilahti Date: Sun, 24 May 2020 20:47:14 +0200 Subject: [PATCH 1/3] Add retries to query(), defaults to 3 + add tests --- kasa/__init__.py | 3 +- kasa/protocol.py | 63 +++++++++------ kasa/smartdevice.py | 7 +- kasa/tests/conftest.py | 12 +++ kasa/tests/test_protocol.py | 150 +++++++++++++++++++++--------------- 5 files changed, 141 insertions(+), 94 deletions(-) diff --git a/kasa/__init__.py b/kasa/__init__.py index b6c42059d..e77aa7dde 100755 --- a/kasa/__init__.py +++ b/kasa/__init__.py @@ -13,9 +13,10 @@ """ from importlib_metadata import version # type: ignore from kasa.discover import Discover +from kasa.exceptions import SmartDeviceException from kasa.protocol import TPLinkSmartHomeProtocol from kasa.smartbulb import SmartBulb -from kasa.smartdevice import DeviceType, EmeterStatus, SmartDevice, SmartDeviceException +from kasa.smartdevice import DeviceType, EmeterStatus, SmartDevice from kasa.smartdimmer import SmartDimmer from kasa.smartplug import SmartPlug from kasa.smartstrip import SmartStrip diff --git a/kasa/protocol.py b/kasa/protocol.py index 443a428e4..74e73a44e 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -16,6 +16,8 @@ from pprint import pformat as pf from typing import Dict, Union +from .exceptions import SmartDeviceException + _LOGGER = logging.getLogger(__name__) @@ -27,12 +29,13 @@ class TPLinkSmartHomeProtocol: DEFAULT_TIMEOUT = 5 @staticmethod - async def query(host: str, request: Union[str, Dict]) -> Dict: + async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> Dict: """Request information from a TP-Link SmartHome Device. :param str host: host name or ip address of the device :param request: command to send to the device (can be either dict or json string) + :param retry_count: how many retries to do in case of failure :return: response dict """ if isinstance(request, dict): @@ -40,29 +43,41 @@ async def query(host: str, request: Union[str, Dict]) -> Dict: timeout = TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT writer = None - try: - task = asyncio.open_connection(host, TPLinkSmartHomeProtocol.DEFAULT_PORT) - reader, writer = await asyncio.wait_for(task, timeout=timeout) - _LOGGER.debug("> (%i) %s", len(request), request) - writer.write(TPLinkSmartHomeProtocol.encrypt(request)) - await writer.drain() - - buffer = bytes() - # Some devices send responses with a length header of 0 and - # terminate with a zero size chunk. Others send the length and - # will hang if we attempt to read more data. - length = -1 - while True: - chunk = await reader.read(4096) - if length == -1: - length = struct.unpack(">I", chunk[0:4])[0] - buffer += chunk - if (length > 0 and len(buffer) >= length + 4) or not chunk: - break - finally: - if writer: - writer.close() - await writer.wait_closed() + for retry in range(retry_count + 1): + try: + task = asyncio.open_connection( + host, TPLinkSmartHomeProtocol.DEFAULT_PORT + ) + reader, writer = await asyncio.wait_for(task, timeout=timeout) + _LOGGER.debug("> (%i) %s", len(request), request) + writer.write(TPLinkSmartHomeProtocol.encrypt(request)) + await writer.drain() + + buffer = bytes() + # Some devices send responses with a length header of 0 and + # terminate with a zero size chunk. Others send the length and + # will hang if we attempt to read more data. + length = -1 + while True: + chunk = await reader.read(4096) + if length == -1: + length = struct.unpack(">I", chunk[0:4])[0] + buffer += chunk + if (length > 0 and len(buffer) >= length + 4) or not chunk: + break + except Exception as ex: + if retry == retry_count: + _LOGGER.debug("Giving up after %s retries", retry) + raise SmartDeviceException( + "Unable to query the device: %s" % ex + ) from ex + + _LOGGER.debug("Unable to query the device, retrying: %s", ex) + + finally: + if writer: + writer.close() + await writer.wait_closed() response = TPLinkSmartHomeProtocol.decrypt(buffer[4:]) json_payload = json.loads(response) diff --git a/kasa/smartdevice.py b/kasa/smartdevice.py index 49dc6c4a7..cd2e8f5f9 100755 --- a/kasa/smartdevice.py +++ b/kasa/smartdevice.py @@ -19,7 +19,8 @@ from enum import Enum from typing import Any, Dict, List, Optional -from kasa.protocol import TPLinkSmartHomeProtocol +from .exceptions import SmartDeviceException +from .protocol import TPLinkSmartHomeProtocol _LOGGER = logging.getLogger(__name__) @@ -47,10 +48,6 @@ class WifiNetwork: rssi: Optional[int] = None -class SmartDeviceException(Exception): - """Base exception for device errors.""" - - class EmeterStatus(dict): """Container for converting different representations of emeter data. diff --git a/kasa/tests/conftest.py b/kasa/tests/conftest.py index 30e798bec..f2b4c178f 100644 --- a/kasa/tests/conftest.py +++ b/kasa/tests/conftest.py @@ -3,6 +3,7 @@ import json import os from os.path import basename +from unittest.mock import MagicMock import pytest # type: ignore # see https://github.com/pytest-dev/pytest/issues/3342 @@ -151,3 +152,14 @@ def pytest_collection_modifyitems(config, items): return else: print("Running against ip %s" % config.getoption("--ip")) + + +# allow mocks to be awaited +# https://stackoverflow.com/questions/51394411/python-object-magicmock-cant-be-used-in-await-expression/51399767#51399767 + + +async def async_magic(): + pass + + +MagicMock.__await__ = lambda x: async_magic().__await__() diff --git a/kasa/tests/test_protocol.py b/kasa/tests/test_protocol.py index 313fd69dc..9c74e5eb4 100644 --- a/kasa/tests/test_protocol.py +++ b/kasa/tests/test_protocol.py @@ -1,73 +1,95 @@ import json -from unittest import TestCase +import pytest + +from ..exceptions import SmartDeviceException from ..protocol import TPLinkSmartHomeProtocol -class TestTPLinkSmartHomeProtocol(TestCase): - def test_encrypt(self): - d = json.dumps({"foo": 1, "bar": 2}) - encrypted = TPLinkSmartHomeProtocol.encrypt(d) - # encrypt adds a 4 byte header - encrypted = encrypted[4:] - self.assertEqual(d, TPLinkSmartHomeProtocol.decrypt(encrypted)) - - def test_encrypt_unicode(self): - d = "{'snowman': '\u2603'}" - - e = bytes( - [ - 208, - 247, - 132, - 234, - 133, - 242, - 159, - 254, - 144, - 183, - 141, - 173, - 138, - 104, - 240, - 115, - 84, - 41, - ] - ) +@pytest.mark.parametrize("retry_count", [1, 3, 5]) +async def test_protocol_retries(mocker, retry_count): + def aio_mock_writer(_, __): + reader = mocker.patch("asyncio.StreamReader") + writer = mocker.patch("asyncio.StreamWriter") - encrypted = TPLinkSmartHomeProtocol.encrypt(d) - # encrypt adds a 4 byte header - encrypted = encrypted[4:] - - self.assertEqual(e, encrypted) - - def test_decrypt_unicode(self): - e = bytes( - [ - 208, - 247, - 132, - 234, - 133, - 242, - 159, - 254, - 144, - 183, - 141, - 173, - 138, - 104, - 240, - 115, - 84, - 41, - ] + mocker.patch( + "asyncio.StreamWriter.write", side_effect=Exception("dummy exception") ) - d = "{'snowman': '\u2603'}" + return reader, writer + + conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) + with pytest.raises(SmartDeviceException): + await TPLinkSmartHomeProtocol.query("127.0.0.1", {}, retry_count=retry_count) + + assert conn.call_count == retry_count + 1 + + +def test_encrypt(self): + d = json.dumps({"foo": 1, "bar": 2}) + encrypted = TPLinkSmartHomeProtocol.encrypt(d) + # encrypt adds a 4 byte header + encrypted = encrypted[4:] + self.assertEqual(d, TPLinkSmartHomeProtocol.decrypt(encrypted)) + + +def test_encrypt_unicode(self): + d = "{'snowman': '\u2603'}" + + e = bytes( + [ + 208, + 247, + 132, + 234, + 133, + 242, + 159, + 254, + 144, + 183, + 141, + 173, + 138, + 104, + 240, + 115, + 84, + 41, + ] + ) + + encrypted = TPLinkSmartHomeProtocol.encrypt(d) + # encrypt adds a 4 byte header + encrypted = encrypted[4:] + + self.assertEqual(e, encrypted) + + +def test_decrypt_unicode(self): + e = bytes( + [ + 208, + 247, + 132, + 234, + 133, + 242, + 159, + 254, + 144, + 183, + 141, + 173, + 138, + 104, + 240, + 115, + 84, + 41, + ] + ) + + d = "{'snowman': '\u2603'}" - self.assertEqual(d, TPLinkSmartHomeProtocol.decrypt(e)) + self.assertEqual(d, TPLinkSmartHomeProtocol.decrypt(e)) From 5568f2049db7fddaf34b2b891e486190b0d3f47f Mon Sep 17 00:00:00 2001 From: Teemu Rytilahti Date: Sun, 24 May 2020 20:51:49 +0200 Subject: [PATCH 2/3] Catch also json decoding errors for retries --- kasa/protocol.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/kasa/protocol.py b/kasa/protocol.py index 74e73a44e..6ee6f72d6 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -65,8 +65,15 @@ async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> D buffer += chunk if (length > 0 and len(buffer) >= length + 4) or not chunk: break + + response = TPLinkSmartHomeProtocol.decrypt(buffer[4:]) + json_payload = json.loads(response) + _LOGGER.debug("< (%i) %s", len(response), pf(json_payload)) + + return json_payload + except Exception as ex: - if retry == retry_count: + if retry >= retry_count: _LOGGER.debug("Giving up after %s retries", retry) raise SmartDeviceException( "Unable to query the device: %s" % ex @@ -79,11 +86,8 @@ async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> D writer.close() await writer.wait_closed() - response = TPLinkSmartHomeProtocol.decrypt(buffer[4:]) - json_payload = json.loads(response) - _LOGGER.debug("< (%i) %s", len(response), pf(json_payload)) - - return json_payload + # make mypy happy, this should never be reached.. + raise SmartDeviceException("Query reached somehow to unreachable") @staticmethod def encrypt(request: str) -> bytes: From 3aac122fb0b24551e39c6e9974d4bebb7876125d Mon Sep 17 00:00:00 2001 From: Teemu Rytilahti Date: Sun, 24 May 2020 22:23:09 +0200 Subject: [PATCH 3/3] add missing exceptions file, fix old protocol tests --- kasa/exceptions.py | 5 +++++ kasa/tests/test_protocol.py | 12 ++++++------ 2 files changed, 11 insertions(+), 6 deletions(-) create mode 100644 kasa/exceptions.py diff --git a/kasa/exceptions.py b/kasa/exceptions.py new file mode 100644 index 000000000..90d36c9a0 --- /dev/null +++ b/kasa/exceptions.py @@ -0,0 +1,5 @@ +"""python-kasa exceptions.""" + + +class SmartDeviceException(Exception): + """Base exception for device errors.""" diff --git a/kasa/tests/test_protocol.py b/kasa/tests/test_protocol.py index 9c74e5eb4..0a8291e1c 100644 --- a/kasa/tests/test_protocol.py +++ b/kasa/tests/test_protocol.py @@ -25,15 +25,15 @@ def aio_mock_writer(_, __): assert conn.call_count == retry_count + 1 -def test_encrypt(self): +def test_encrypt(): d = json.dumps({"foo": 1, "bar": 2}) encrypted = TPLinkSmartHomeProtocol.encrypt(d) # encrypt adds a 4 byte header encrypted = encrypted[4:] - self.assertEqual(d, TPLinkSmartHomeProtocol.decrypt(encrypted)) + assert d == TPLinkSmartHomeProtocol.decrypt(encrypted) -def test_encrypt_unicode(self): +def test_encrypt_unicode(): d = "{'snowman': '\u2603'}" e = bytes( @@ -63,10 +63,10 @@ def test_encrypt_unicode(self): # encrypt adds a 4 byte header encrypted = encrypted[4:] - self.assertEqual(e, encrypted) + assert e == encrypted -def test_decrypt_unicode(self): +def test_decrypt_unicode(): e = bytes( [ 208, @@ -92,4 +92,4 @@ def test_decrypt_unicode(self): d = "{'snowman': '\u2603'}" - self.assertEqual(d, TPLinkSmartHomeProtocol.decrypt(e)) + assert d == TPLinkSmartHomeProtocol.decrypt(e)