Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 41 additions & 2 deletions kasa/smartprotocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ async def query(self, request: str | dict, retry_count: int = 3) -> dict:
async def _query(self, request: str | dict, retry_count: int = 3) -> dict:
for retry in range(retry_count + 1):
try:
return await self._execute_query(request, retry)
return await self._execute_query(
request, retry_count=retry, iterate_list_pages=True
)
except _ConnectionError as sdex:
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
Expand Down Expand Up @@ -145,6 +147,9 @@ async def _execute_multiple_query(self, requests: dict, retry_count: int) -> dic
method = response["method"]
self._handle_response_error_code(response, method, raise_on_error=False)
result = response.get("result", None)
await self._handle_response_lists(
result, method, retry_count=retry_count
)
multi_result[method] = result
# Multi requests don't continue after errors so requery any missing
for method, params in requests.items():
Expand All @@ -156,7 +161,9 @@ async def _execute_multiple_query(self, requests: dict, retry_count: int) -> dic
multi_result[method] = resp.get("result")
return multi_result

async def _execute_query(self, request: str | dict, retry_count: int) -> dict:
async def _execute_query(
self, request: str | dict, *, retry_count: int, iterate_list_pages: bool = True
) -> dict:
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)

if isinstance(request, dict):
Expand Down Expand Up @@ -189,8 +196,40 @@ async def _execute_query(self, request: str | dict, retry_count: int) -> dict:

# Single set_ requests do not return a result
result = response_data.get("result")
if iterate_list_pages and result:
await self._handle_response_lists(
result, smart_method, retry_count=retry_count
)
return {smart_method: result}

async def _handle_response_lists(
    self, response_result: dict[str, Any], method: str, retry_count: int
) -> None:
    """Fetch any remaining pages of a paginated list response in place.

    Some SMART methods (e.g. get_child_device_list) return their payload in
    batches: the result carries ``start_index``, ``sum`` (total item count)
    and a single list-valued entry.  Keep re-querying with an increasing
    ``start_index`` until the accumulated list holds ``sum`` items, extending
    ``response_result`` in place.

    :param response_result: the ``result`` dict of the initial response.
    :param method: the SMART method name to re-query for further pages.
    :param retry_count: retry budget forwarded to the follow-up queries.
    """
    # Error responses and non-paginated results need no follow-up queries.
    if (
        isinstance(response_result, SmartErrorCode)
        or "start_index" not in response_result
        or (list_sum := response_result.get("sum")) is None
    ):
        return

    # The batched payload is the (single) list-valued entry in the result.
    # A lazy generator avoids building a throwaway list just to take its head.
    response_list_name = next(
        key for key, value in response_result.items() if isinstance(value, list)
    )
    while (list_length := len(response_result[response_list_name])) < list_sum:
        # iterate_list_pages=False stops the nested call from recursing
        # into this pagination loop.
        response = await self._execute_query(
            {method: {"start_index": list_length}},
            retry_count=retry_count,
            iterate_list_pages=False,
        )
        next_batch = response[method]
        response_result[response_list_name].extend(next_batch[response_list_name])

def _handle_response_error_code(self, resp_dict: dict, method, raise_on_error=True):
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
if error_code == SmartErrorCode.SUCCESS:
Expand Down
32 changes: 27 additions & 5 deletions kasa/tests/fakeprotocol_smart.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,14 @@ async def query(self, request, retry_count: int = 3):


class FakeSmartTransport(BaseTransport):
def __init__(self, info, fixture_name):
def __init__(
self,
info,
fixture_name,
*,
list_return_size=10,
component_nego_not_included=False,
):
super().__init__(
config=DeviceConfig(
"127.0.0.123",
Expand All @@ -33,10 +40,12 @@ def __init__(self, info, fixture_name):
)
self.fixture_name = fixture_name
self.info = copy.deepcopy(info)
self.components = {
comp["id"]: comp["ver_code"]
for comp in self.info["component_nego"]["component_list"]
}
if not component_nego_not_included:
self.components = {
comp["id"]: comp["ver_code"]
for comp in self.info["component_nego"]["component_list"]
}
self.list_return_size = list_return_size

@property
def default_port(self):
Expand Down Expand Up @@ -177,7 +186,20 @@ def _send_request(self, request_dict: dict):
elif method == "component_nego" or method[:4] == "get_":
if method in info:
result = copy.deepcopy(info[method])
if "start_index" in result and "sum" in result:
list_key = next(
iter([key for key in result if isinstance(result[key], list)])
)
start_index = (
start_index
if (params and (start_index := params.get("start_index")))
else 0
)
result[list_key] = result[list_key][
start_index : start_index + self.list_return_size
]
return {"result": result, "error_code": 0}

if (
# FIXTURE_MISSING is for service calls not in place when
# SMART fixtures started to be generated
Expand Down
64 changes: 63 additions & 1 deletion kasa/tests/test_smartprotocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
KasaException,
SmartErrorCode,
)
from ..smartprotocol import _ChildProtocolWrapper
from ..smartprotocol import SmartProtocol, _ChildProtocolWrapper
from .fakeprotocol_smart import FakeSmartTransport

DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}
DUMMY_MULTIPLE_QUERY = {
Expand Down Expand Up @@ -180,3 +181,64 @@ async def test_childdevicewrapper_multiplerequest_error(dummy_protocol, mocker):
mocker.patch.object(wrapped_protocol._transport, "send", return_value=mock_response)
with pytest.raises(KasaException):
await wrapped_protocol.query(DUMMY_QUERY)


@pytest.mark.parametrize("list_sum", [5, 10, 30])
@pytest.mark.parametrize("batch_size", [1, 2, 3, 50])
async def test_smart_protocol_lists_single_request(mocker, list_sum, batch_size):
    """A paginated single request is re-queried until the whole list arrives."""
    devices = [{"foo": index} for index in range(list_sum)]
    response = {
        "get_child_device_list": {
            "child_device_list": devices,
            "start_index": 0,
            "sum": list_sum,
        }
    }

    transport = FakeSmartTransport(
        response,
        "foobar",
        list_return_size=batch_size,
        component_nego_not_included=True,
    )
    protocol = SmartProtocol(transport=transport)
    spy = mocker.spy(protocol, "_execute_query")

    result = await protocol.query({"get_child_device_list": None})

    # One query per batch: ceil(list_sum / batch_size).
    assert spy.call_count == (list_sum + batch_size - 1) // batch_size
    assert result == response


@pytest.mark.parametrize("list_sum", [5, 10, 30])
@pytest.mark.parametrize("batch_size", [1, 2, 3, 50])
async def test_smart_protocol_lists_multiple_request(mocker, list_sum, batch_size):
    """Paginated lists inside a multi-request are also fetched in full."""
    items = [{"foo": index} for index in range(list_sum)]
    response = {
        "get_child_device_list": {
            "child_device_list": items,
            "start_index": 0,
            "sum": list_sum,
        },
        "get_child_device_component_list": {
            "child_component_list": items,
            "start_index": 0,
            "sum": list_sum,
        },
    }

    transport = FakeSmartTransport(
        response,
        "foobar",
        list_return_size=batch_size,
        component_nego_not_included=True,
    )
    protocol = SmartProtocol(transport=transport)
    spy = mocker.spy(protocol, "_execute_query")

    result = await protocol.query(
        {"get_child_device_list": None, "get_child_device_component_list": None}
    )

    # The initial multi-request counts once; each of the two methods then
    # needs ceil(list_sum / batch_size) - 1 follow-up page queries.
    pages_per_method = (list_sum + batch_size - 1) // batch_size
    assert spy.call_count == 1 + 2 * (pages_per_method - 1)
    assert result == response