From 6e4ef6697544964c34b07322c3dfb48d7ce384e7 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Tue, 23 Apr 2024 18:07:35 +0100 Subject: [PATCH 1/2] Handle paging of partial responses of lists like child_device_info --- kasa/smartprotocol.py | 43 +++++++++++++++++++++-- kasa/tests/fakeprotocol_smart.py | 25 +++++++++++--- kasa/tests/test_smartprotocol.py | 58 +++++++++++++++++++++++++++++++- 3 files changed, 118 insertions(+), 8 deletions(-) diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index 9a1482b18..bd372e1a6 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -67,7 +67,9 @@ async def query(self, request: str | dict, retry_count: int = 3) -> dict: async def _query(self, request: str | dict, retry_count: int = 3) -> dict: for retry in range(retry_count + 1): try: - return await self._execute_query(request, retry) + return await self._execute_query( + request, retry_count=retry, handle_lists=True + ) except _ConnectionError as sdex: if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) @@ -145,6 +147,9 @@ async def _execute_multiple_query(self, requests: dict, retry_count: int) -> dic method = response["method"] self._handle_response_error_code(response, method, raise_on_error=False) result = response.get("result", None) + await self._handle_response_lists( + result, method, retry_count=retry_count + ) multi_result[method] = result # Multi requests don't continue after errors so requery any missing for method, params in requests.items(): @@ -156,7 +161,9 @@ async def _execute_multiple_query(self, requests: dict, retry_count: int) -> dic multi_result[method] = resp.get("result") return multi_result - async def _execute_query(self, request: str | dict, retry_count: int) -> dict: + async def _execute_query( + self, request: str | dict, *, retry_count: int, handle_lists: bool + ) -> dict: debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) if isinstance(request, dict): @@ -189,8 +196,40 @@ async def _execute_query(self, request: str | dict, retry_count: int) -> dict: # Single set_ requests do not return a result result = response_data.get("result") + if handle_lists and result: + await self._handle_response_lists( + result, smart_method, retry_count=retry_count + ) return {smart_method: result} + async def _handle_response_lists( + self, response_result: dict[str, Any], method, retry_count + ): + if ( + not isinstance(response_result, SmartErrorCode) + and "start_index" in response_result + and (list_sum := response_result.get("sum")) + ): + response_list_name = next( + iter( + [ + key + for key in response_result + if isinstance(response_result[key], list) + ] + ) + ) + while (list_length := len(response_result[response_list_name])) < list_sum: + response = await self._execute_query( + {method: {"start_index": list_length}}, + retry_count=retry_count, + handle_lists=False, + ) + next_batch = response[method] + response_result[response_list_name].extend( + next_batch[response_list_name] + ) + def _handle_response_error_code(self, resp_dict: dict, method, raise_on_error=True): error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type] if error_code == SmartErrorCode.SUCCESS: diff --git a/kasa/tests/fakeprotocol_smart.py b/kasa/tests/fakeprotocol_smart.py index dd9b1f169..d3ee39fb3 100644 --- a/kasa/tests/fakeprotocol_smart.py +++ b/kasa/tests/fakeprotocol_smart.py @@ -21,7 +21,7 @@ async def query(self, request, retry_count: int = 3): class FakeSmartTransport(BaseTransport): - def __init__(self, info, fixture_name): + def __init__(self, info, fixture_name, *, list_return_size=10, no_components=False): super().__init__( config=DeviceConfig( "127.0.0.123", @@ -33,10 +33,12 @@ def __init__(self, info, fixture_name): ) self.fixture_name = fixture_name self.info = copy.deepcopy(info) - self.components = { - comp["id"]: comp["ver_code"] - for comp in self.info["component_nego"]["component_list"] - } + if no_components is False: + self.components = { + comp["id"]: comp["ver_code"] + for comp in self.info["component_nego"]["component_list"] + } + self.list_return_size = list_return_size @property def default_port(self): @@ -158,7 +160,20 @@ def _send_request(self, request_dict: dict): elif method == "component_nego" or method[:4] == "get_": if method in info: result = copy.deepcopy(info[method]) + if "start_index" in result and "sum" in result: + list_key = next( + iter([key for key in result if isinstance(result[key], list)]) + ) + start_index = ( + start_index + if (params and (start_index := params.get("start_index"))) + else 0 + ) + result[list_key] = result[list_key][ + start_index : start_index + self.list_return_size + ] return {"result": result, "error_code": 0} + if ( # FIXTURE_MISSING is for service calls not in place when # SMART fixtures started to be generated diff --git a/kasa/tests/test_smartprotocol.py b/kasa/tests/test_smartprotocol.py index b970eaa5a..c61848a11 100644 --- a/kasa/tests/test_smartprotocol.py +++ b/kasa/tests/test_smartprotocol.py @@ -7,7 +7,8 @@ KasaException, SmartErrorCode, ) -from ..smartprotocol import _ChildProtocolWrapper +from ..smartprotocol import SmartProtocol, _ChildProtocolWrapper +from .fakeprotocol_smart import FakeSmartTransport DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}} DUMMY_MULTIPLE_QUERY = { @@ -180,3 +181,58 @@ async def test_childdevicewrapper_multiplerequest_error(dummy_protocol, mocker): mocker.patch.object(wrapped_protocol._transport, "send", return_value=mock_response) with pytest.raises(KasaException): await wrapped_protocol.query(DUMMY_QUERY) + + +@pytest.mark.parametrize("list_sum", [5, 10, 30]) +@pytest.mark.parametrize("batch_size", [1, 2, 3, 50]) +async def test_smart_protocol_lists_single_request(mocker, list_sum, batch_size): + child_device_list = [{"foo": i} for i in range(list_sum)] + response = { + "get_child_device_list": { + "child_device_list": child_device_list, + "start_index": 0, + "sum": list_sum, + } + } + request = {"get_child_device_list": None} + + ft = FakeSmartTransport( + response, "foobar", list_return_size=batch_size, no_components=True + ) + protocol = SmartProtocol(transport=ft) + query_spy = mocker.spy(protocol, "_execute_query") + resp = await protocol.query(request) + expected_count = int(list_sum / batch_size) + (1 if list_sum % batch_size else 0) + assert query_spy.call_count == expected_count + assert resp == response + + +@pytest.mark.parametrize("list_sum", [5, 10, 30]) +@pytest.mark.parametrize("batch_size", [1, 2, 3, 50]) +async def test_smart_protocol_lists_multiple_request(mocker, list_sum, batch_size): + child_list = [{"foo": i} for i in range(list_sum)] + response = { + "get_child_device_list": { + "child_device_list": child_list, + "start_index": 0, + "sum": list_sum, + }, + "get_child_device_component_list": { + "child_component_list": child_list, + "start_index": 0, + "sum": list_sum, + }, + } + request = {"get_child_device_list": None, "get_child_device_component_list": None} + + ft = FakeSmartTransport( + response, "foobar", list_return_size=batch_size, no_components=True + ) + protocol = SmartProtocol(transport=ft) + query_spy = mocker.spy(protocol, "_execute_query") + resp = await protocol.query(request) + expected_count = 1 + 2 * ( + int(list_sum / batch_size) + (0 if list_sum % batch_size else -1) + ) + assert query_spy.call_count == expected_count + assert resp == response From 7b491e288619d5a2983efec45e32d2ed1d054513 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Wed, 24 Apr 2024 18:47:24 +0100 Subject: [PATCH 2/2] Update post-review --- kasa/smartprotocol.py | 48 ++++++++++++++++---------------- kasa/tests/fakeprotocol_smart.py | 11 ++++++-- kasa/tests/test_smartprotocol.py | 10 +++++-- 3 files changed, 41 insertions(+), 28 deletions(-) diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index bd372e1a6..cbfd16b0f 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -68,7 +68,7 @@ async def _query(self, request: str | dict, retry_count: int = 3) -> dict: for retry in range(retry_count + 1): try: return await self._execute_query( - request, retry_count=retry, handle_lists=True + request, retry_count=retry, iterate_list_pages=True ) except _ConnectionError as sdex: if retry >= retry_count: @@ -162,7 +162,7 @@ async def _execute_multiple_query(self, requests: dict, retry_count: int) -> dic return multi_result async def _execute_query( - self, request: str | dict, *, retry_count: int, handle_lists: bool + self, request: str | dict, *, retry_count: int, iterate_list_pages: bool = True ) -> dict: debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) @@ -196,7 +196,7 @@ async def _execute_query( # Single set_ requests do not return a result result = response_data.get("result") - if handle_lists and result: + if iterate_list_pages and result: await self._handle_response_lists( result, smart_method, retry_count=retry_count ) @@ -206,29 +206,29 @@ async def _handle_response_lists( self, response_result: dict[str, Any], method, retry_count ): if ( - not isinstance(response_result, SmartErrorCode) - and "start_index" in response_result - and (list_sum := response_result.get("sum")) + isinstance(response_result, SmartErrorCode) + or "start_index" not in response_result + or (list_sum := response_result.get("sum")) is None ): - response_list_name = next( - iter( - [ - key - for key in response_result - if isinstance(response_result[key], list) - ] - ) + return + + response_list_name = next( + iter( + [ + key + for key in response_result + if isinstance(response_result[key], list) + ] ) - while (list_length := len(response_result[response_list_name])) < list_sum: - response = await self._execute_query( - {method: {"start_index": list_length}}, - retry_count=retry_count, - handle_lists=False, - ) - next_batch = response[method] - response_result[response_list_name].extend( - next_batch[response_list_name] - ) + ) + while (list_length := len(response_result[response_list_name])) < list_sum: + response = await self._execute_query( + {method: {"start_index": list_length}}, + retry_count=retry_count, + iterate_list_pages=False, + ) + next_batch = response[method] + response_result[response_list_name].extend(next_batch[response_list_name]) def _handle_response_error_code(self, resp_dict: dict, method, raise_on_error=True): error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type] diff --git a/kasa/tests/fakeprotocol_smart.py b/kasa/tests/fakeprotocol_smart.py index d3ee39fb3..052b884d0 100644 --- a/kasa/tests/fakeprotocol_smart.py +++ b/kasa/tests/fakeprotocol_smart.py @@ -21,7 +21,14 @@ async def query(self, request, retry_count: int = 3): class FakeSmartTransport(BaseTransport): - def __init__(self, info, fixture_name, *, list_return_size=10, no_components=False): + def __init__( + self, + info, + fixture_name, + *, + list_return_size=10, + component_nego_not_included=False, + ): super().__init__( config=DeviceConfig( "127.0.0.123", @@ -33,7 +40,7 @@ def __init__(self, info, fixture_name, *, list_return_size=10, no_components=Fal ) self.fixture_name = fixture_name self.info = copy.deepcopy(info) - if no_components is False: + if not component_nego_not_included: self.components = { comp["id"]: comp["ver_code"] for comp in self.info["component_nego"]["component_list"] diff --git a/kasa/tests/test_smartprotocol.py b/kasa/tests/test_smartprotocol.py index c61848a11..ca62ba02d 100644 --- a/kasa/tests/test_smartprotocol.py +++ b/kasa/tests/test_smartprotocol.py @@ -197,7 +197,10 @@ async def test_smart_protocol_lists_single_request(mocker, list_sum, batch_size) request = {"get_child_device_list": None} ft = FakeSmartTransport( - response, "foobar", list_return_size=batch_size, no_components=True + response, + "foobar", + list_return_size=batch_size, + component_nego_not_included=True, ) protocol = SmartProtocol(transport=ft) query_spy = mocker.spy(protocol, "_execute_query") @@ -226,7 +229,10 @@ async def test_smart_protocol_lists_multiple_request(mocker, list_sum, batch_siz request = {"get_child_device_list": None, "get_child_device_component_list": None} ft = FakeSmartTransport( - response, "foobar", list_return_size=batch_size, no_components=True + response, + "foobar", + list_return_size=batch_size, + component_nego_not_included=True, ) protocol = SmartProtocol(transport=ft) query_spy = mocker.spy(protocol, "_execute_query")