diff --git a/kasa/cli.py b/kasa/cli.py index 167179e36..de9eb95cc 100755 --- a/kasa/cli.py +++ b/kasa/cli.py @@ -223,7 +223,7 @@ async def state(ctx, dev: SmartDevice): """Print out device state and versions.""" await dev.update() click.echo(click.style(f"== {dev.alias} - {dev.model} ==", bold=True)) - click.echo(f"\tHost: {dev.host}") + click.echo(f"\tHost: {dev.protocol.host}") click.echo( click.style( "\tDevice state: {}\n".format("ON" if dev.is_on else "OFF"), diff --git a/kasa/discover.py b/kasa/discover.py index 7aaf85245..efdcedaa1 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -42,7 +42,7 @@ def __init__( self.timeout = timeout self.interface = interface self.on_discovered = on_discovered - self.protocol = TPLinkSmartHomeProtocol() + self.protocol = TPLinkSmartHomeProtocol(target) self.target = (target, Discover.DISCOVERY_PORT) self.discovered_devices = {} self.discovered_devices_raw = {} @@ -201,13 +201,13 @@ async def discover( async def discover_single(host: str) -> SmartDevice: """Discover a single device by the given IP address. - :param host: Hostname of device to query + :param host: fname of device to query :rtype: SmartDevice :return: Object for querying/controlling found device. """ - protocol = TPLinkSmartHomeProtocol() + protocol = TPLinkSmartHomeProtocol(host) - info = await protocol.query(host, Discover.DISCOVERY_QUERY) + info = await protocol.query(Discover.DISCOVERY_QUERY) device_class = Discover._get_device_class(info) if device_class is not None: diff --git a/kasa/protocol.py b/kasa/protocol.py index 6ee6f72d6..864fae677 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -28,14 +28,21 @@ class TPLinkSmartHomeProtocol: DEFAULT_PORT = 9999 DEFAULT_TIMEOUT = 5 - @staticmethod - async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> Dict: - """Request information from a TP-Link SmartHome Device. + def __init__(self, host: str, retry_count: int = 3): + """Initialize a new instance of protocol. :param str host: host name or ip address of the device + :param int retry_count: how many retries to do in case of failure + """ + self.host = host + self.port = TPLinkSmartHomeProtocol.DEFAULT_PORT + self.retry_count = retry_count + + async def query(self, request: Union[str, Dict]) -> Dict: + """Request information from a TP-Link SmartHome Device. + :param request: command to send to the device (can be either dict or json string) - :param retry_count: how many retries to do in case of failure :return: response dict """ if isinstance(request, dict): @@ -43,11 +50,9 @@ async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> D timeout = TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT writer = None - for retry in range(retry_count + 1): + for retry in range(self.retry_count + 1): try: - task = asyncio.open_connection( - host, TPLinkSmartHomeProtocol.DEFAULT_PORT - ) + task = asyncio.open_connection(self.host, self.port) reader, writer = await asyncio.wait_for(task, timeout=timeout) _LOGGER.debug("> (%i) %s", len(request), request) writer.write(TPLinkSmartHomeProtocol.encrypt(request)) @@ -73,7 +78,7 @@ async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> D return json_payload except Exception as ex: - if retry >= retry_count: + if retry >= self.retry_count: _LOGGER.debug("Giving up after %s retries", retry) raise SmartDeviceException( "Unable to query the device: %s" % ex diff --git a/kasa/smartdevice.py b/kasa/smartdevice.py index 19589bbad..12e1c683a 100755 --- a/kasa/smartdevice.py +++ b/kasa/smartdevice.py @@ -218,11 +218,9 @@ def __init__(self, host: str) -> None: :param str host: host name or ip address on which the device listens """ - self.host = host - - self.protocol = TPLinkSmartHomeProtocol() + self.protocol = TPLinkSmartHomeProtocol(host) self.emeter_type = "emeter" - _LOGGER.debug("Initializing %s of type %s", self.host, type(self)) + _LOGGER.debug("Initializing %s of type %s", self.protocol.host, type(self)) self._device_type = DeviceType.Unknown # TODO: typing Any is just as using Optional[Dict] would require separate checks in # accessors. the @updated_required decorator does not ensure mypy that these @@ -255,7 +253,7 @@ async def _query_helper( request = self._create_request(target, cmd, arg, child_ids) try: - response = await self.protocol.query(host=self.host, request=request) + response = await self.protocol.query(request=request) except Exception as ex: raise SmartDeviceException(f"Communication error on {target}:{cmd}") from ex @@ -300,7 +298,7 @@ async def update(self): # Check for emeter if we were never updated, or if the device has emeter if self._last_update is None or self.has_emeter: req.update(self._create_emeter_request()) - self._last_update = await self.protocol.query(self.host, req) + self._last_update = await self.protocol.query(req) # TODO: keep accessible for tests self._sys_info = self._last_update["system"]["get_sysinfo"] @@ -741,5 +739,5 @@ def is_color(self) -> bool: def __repr__(self): if self._last_update is None: - return f"<{self._device_type} at {self.host} - update() needed>" - return f"<{self._device_type} model {self.model} at {self.host} ({self.alias}), is_on: {self.is_on} - dev specific: {self.state_information}>" + return f"<{self._device_type} at {self.protocol.host} - update() needed>" + return f"<{self._device_type} model {self.model} at {self.protocol.host} ({self.alias}), is_on: {self.is_on} - dev specific: {self.state_information}>" diff --git a/kasa/smartstrip.py b/kasa/smartstrip.py index 222c73e45..be79f99b5 100755 --- a/kasa/smartstrip.py +++ b/kasa/smartstrip.py @@ -94,7 +94,9 @@ async def update(self): _LOGGER.debug("Initializing %s child sockets", len(children)) for child in children: self.children.append( - SmartStripPlug(self.host, parent=self, child_id=child["id"]) + SmartStripPlug( + self.protocol.host, parent=self, child_id=child["id"] + ) ) async def turn_on(self, **kwargs): diff --git a/kasa/tests/conftest.py b/kasa/tests/conftest.py index 69f1f3b72..0510691c1 100644 --- a/kasa/tests/conftest.py +++ b/kasa/tests/conftest.py @@ -148,7 +148,7 @@ def get_device_for_file(file): sysinfo = json.load(f) model = basename(file) p = device_for_file(model)(host="123.123.123.123") - p.protocol = FakeTransportProtocol(sysinfo) + p.protocol = FakeTransportProtocol("123.123.123.123", sysinfo) asyncio.run(p.update()) return p diff --git a/kasa/tests/newfakes.py b/kasa/tests/newfakes.py index 55c3e00cb..a022becc9 100644 --- a/kasa/tests/newfakes.py +++ b/kasa/tests/newfakes.py @@ -253,8 +253,9 @@ def success(res): class FakeTransportProtocol(TPLinkSmartHomeProtocol): - def __init__(self, info): + def __init__(self, host, info): self.discovery_data = info + self.host = host proto = FakeTransportProtocol.baseproto for target in info: @@ -415,7 +416,7 @@ def light_state(self, x, *args): }, } - async def query(self, host, request, port=9999): + async def query(self, request): proto = self.proto # collect child ids from context diff --git a/kasa/tests/test_protocol.py b/kasa/tests/test_protocol.py index 51c01d49d..535139a45 100644 --- a/kasa/tests/test_protocol.py +++ b/kasa/tests/test_protocol.py @@ -21,7 +21,8 @@ def aio_mock_writer(_, __): conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) with pytest.raises(SmartDeviceException): - await TPLinkSmartHomeProtocol.query("127.0.0.1", {}, retry_count=retry_count) + protocol = TPLinkSmartHomeProtocol("127.0.0.1", retry_count=retry_count) + await protocol.query({}) assert conn.call_count == retry_count + 1