diff options
Diffstat (limited to 'src/polaris')
-rw-r--r-- | src/polaris/__init__.py | 14 | ||||
-rw-r--r-- | src/polaris/kettle.py | 271 | ||||
-rw-r--r-- | src/polaris/protocol.py | 1015 |
3 files changed, 1093 insertions, 207 deletions
diff --git a/src/polaris/__init__.py b/src/polaris/__init__.py index aa077ce..f1a7c1d 100644 --- a/src/polaris/__init__.py +++ b/src/polaris/__init__.py @@ -1,4 +1,12 @@ -# SPDX-License-Identifier: BSD-3-Clause +# Polaris PWK 1725CGLD "smart" kettle python library +# -------------------------------------------------- +# Copyright (C) Evgeny Zinoviev, 2022 +# License: BSD-3c -from .kettle import Kettle -from .protocol import Message, FrameType, PowerType
\ No newline at end of file +from .kettle import Kettle, DeviceListener +from .protocol import ( + PowerType, + IncomingMessageListener, + ConnectionStatusListener, + ConnectionStatus +)
\ No newline at end of file diff --git a/src/polaris/kettle.py b/src/polaris/kettle.py index 37f6813..1b995fd 100644 --- a/src/polaris/kettle.py +++ b/src/polaris/kettle.py @@ -1,97 +1,238 @@ -# SPDX-License-Identifier: BSD-3-Clause +# Polaris PWK 1725CGLD smart kettle python library +# ------------------------------------------------ +# Copyright (C) Evgeny Zinoviev, 2022 +# License: BSD-3c from __future__ import annotations +import threading import logging import zeroconf -import cryptography.hazmat.primitives._serialization -from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey -from cryptography.hazmat.primitives import hashes - -from functools import partial -from abc import ABC -from ipaddress import ip_address -from typing import Optional +from abc import abstractmethod +from ipaddress import ip_address, IPv4Address, IPv6Address +from typing import Optional, List, Union from .protocol import ( - Connection, + UDPConnection, ModeMessage, - HandshakeMessage, TargetTemperatureMessage, - Message, - PowerType + PowerType, + ConnectionStatus, + ConnectionStatusListener, + WrappedMessage ) -_logger = logging.getLogger(__name__) - -# Polaris PWK 1725CGLD IoT kettle -class Kettle(zeroconf.ServiceListener, ABC): - macaddr: str - device_token: str - sb: Optional[zeroconf.ServiceBrowser] - found_device: Optional[zeroconf.ServiceInfo] - conn: Optional[Connection] +class DeviceDiscover(threading.Thread, zeroconf.ServiceListener): + si: Optional[zeroconf.ServiceInfo] + _mac: str + _sb: Optional[zeroconf.ServiceBrowser] + _zc: Optional[zeroconf.Zeroconf] + _listeners: List[DeviceListener] + _valid_addresses: List[Union[IPv4Address, IPv6Address]] + _only_ipv4: bool - def __init__(self, mac: str, device_token: str): + def __init__(self, mac: str, + listener: Optional[DeviceListener] = None, + only_ipv4=True): super().__init__() - self.zeroconf = zeroconf.Zeroconf() - self.sb = None - self.macaddr = mac - self.device_token = device_token - self.found_device = None - self.conn = None + self.si = None + self._mac = mac + self._zc = None + self._sb = None + self._only_ipv4 = only_ipv4 + self._valid_addresses = [] + self._listeners = [] + if isinstance(listener, DeviceListener): + self._listeners.append(listener) + self._logger = logging.getLogger(f'{__name__}.{self.__class__.__name__}') + + def add_listener(self, listener: DeviceListener): + if listener not in self._listeners: + self._listeners.append(listener) + else: + self._logger.warning(f'add_listener: listener {listener} already in the listeners list') + + def set_info(self, info: zeroconf.ServiceInfo): + valid_addresses = self._get_valid_addresses(info) + if not valid_addresses: + raise ValueError('no valid addresses') + self._valid_addresses = valid_addresses + self.si = info + for f in self._listeners: + try: + f.device_updated() + except Exception as exc: + self._logger.error(f'set_info: error while calling device_updated on {f}') + self._logger.exception(exc) - def find(self) -> zeroconf.ServiceInfo: - self.sb = zeroconf.ServiceBrowser(self.zeroconf, "_syncleo._udp.local.", self) - self.sb.join() + def add_service(self, zc: zeroconf.Zeroconf, type_: str, name: str) -> None: + self._add_update_service('add_service', zc, type_, name) - return self.found_device + def update_service(self, zc: zeroconf.Zeroconf, type_: str, name: str) -> None: + self._add_update_service('update_service', zc, type_, name) - # zeroconf.ServiceListener implementation - def add_service(self, - zc: zeroconf.Zeroconf, - type_: str, - name: str) -> None: - if name.startswith(f'{self.macaddr}.'): - info = zc.get_service_info(type_, name) + def _add_update_service(self, method: str, zc: zeroconf.Zeroconf, type_: str, name: str) -> None: + info = zc.get_service_info(type_, name) + if name.startswith(f'{self._mac}.'): + self._logger.info(f'{method}: type={type_} name={name}') + try: + self.set_info(info) + except ValueError as exc: + self._logger.error(f'{method}: rejected: {str(exc)}') + else: + self._logger.debug(f'{method}: mac not matched: {info}') + + def remove_service(self, zc: zeroconf.Zeroconf, type_: str, name: str) -> None: + if name.startswith(f'{self._mac}.'): + self._logger.info(f'remove_service: type={type_} name={name}') + # TODO what to do here?! + + def run(self): + self._logger.info('starting zeroconf service browser') + ip_version = zeroconf.IPVersion.V4Only if self._only_ipv4 else zeroconf.IPVersion.All + self._zc = zeroconf.Zeroconf(ip_version=ip_version) + self._sb = zeroconf.ServiceBrowser(self._zc, "_syncleo._udp.local.", self) + self._sb.join() + + def stop(self): + if self._sb: try: - self.sb.cancel() + self._sb.cancel() except RuntimeError: pass - self.zeroconf.close() - self.found_device = info + self._sb = None + self._zc.close() + self._zc = None + + def _get_valid_addresses(self, si: zeroconf.ServiceInfo) -> List[Union[IPv4Address, IPv6Address]]: + valid = [] + for addr in map(ip_address, si.addresses): + if self._only_ipv4 and not isinstance(addr, IPv4Address): + continue + if isinstance(addr, IPv4Address) and str(addr).startswith('169.254.'): + continue + valid.append(addr) + return valid - assert self.device_curve == 29, f'curve type {self.device_curve} is not implemented' - - def start_server(self, callback: callable): - addresses = list(map(ip_address, self.found_device.addresses)) - self.conn = Connection(addr=addresses[0], - port=int(self.found_device.port), - device_pubkey=self.device_pubkey, - device_token=bytes.fromhex(self.device_token)) + @property + def pubkey(self) -> bytes: + return bytes.fromhex(self.si.properties[b'public'].decode()) - # shake the kettle's hand - self._pass_message(HandshakeMessage(), callback) - self.conn.start() + @property + def curve(self) -> int: + return int(self.si.properties[b'curve'].decode()) - def stop_server(self): - self.conn.interrupted = True + @property + def addr(self) -> Union[IPv4Address, IPv6Address]: + return self._valid_addresses[0] @property - def device_pubkey(self) -> bytes: - return bytes.fromhex(self.found_device.properties[b'public'].decode()) + def port(self) -> int: + return int(self.si.port) @property - def device_curve(self) -> int: - return int(self.found_device.properties[b'curve'].decode()) + def protocol(self) -> int: + return int(self.si.properties[b'protocol'].decode()) + + +class DeviceListener: + @abstractmethod + def device_updated(self): + pass + + +class Kettle(DeviceListener, ConnectionStatusListener): + mac: str + device: Optional[DeviceDiscover] + device_token: str + conn: Optional[UDPConnection] + conn_status: Optional[ConnectionStatus] + _logger: logging.Logger + _find_evt: threading.Event + + def __init__(self, mac: str, device_token: str): + super().__init__() + self.mac = mac + self.device = None + self.device_token = device_token + self.conn = None + self.conn_status = None + self._find_evt = threading.Event() + self._logger = logging.getLogger(f'{__name__}.{self.__class__.__name__}') + + def device_updated(self): + self._find_evt.set() + self._logger.info(f'device updated, service info: {self.device.si}') + + def connection_status_updated(self, status: ConnectionStatus): + self.conn_status = status + + def discover(self, wait=True, timeout=None, listener=None) -> Optional[zeroconf.ServiceInfo]: + do_start = False + if not self.device: + self.device = DeviceDiscover(self.mac, listener=self, only_ipv4=True) + do_start = True + self._logger.debug('discover: started device discovery') + else: + self._logger.warning('discover: already started') + + if listener is not None: + self.device.add_listener(listener) + + if do_start: + self.device.start() + + if wait: + self._find_evt.clear() + try: + self._find_evt.wait(timeout=timeout) + except KeyboardInterrupt: + self.device.stop() + return None + return self.device.si + + def start_server_if_needed(self, + incoming_message_listener=None, + connection_status_listener=None): + if self.conn: + self._logger.warning('start_server_if_needed: server is already started!') + self.conn.set_address(self.device.addr, self.device.port) + self.conn.set_device_pubkey(self.device.pubkey) + return + + assert self.device.curve == 29, f'curve type {self.device.curve} is not implemented' + assert self.device.protocol == 2, f'protocol {self.device.protocol} is not supported' + + self.conn = UDPConnection(addr=self.device.addr, + port=self.device.port, + device_pubkey=self.device.pubkey, + device_token=bytes.fromhex(self.device_token)) + if incoming_message_listener: + self.conn.add_incoming_message_listener(incoming_message_listener) + + self.conn.add_connection_status_listener(self) + if connection_status_listener: + self.conn.add_connection_status_listener(connection_status_listener) + + self.conn.start() + + def stop_all(self): + # when we stop server, we should also stop device discovering service + if self.conn: + self.conn.interrupted = True + self.conn = None + self.device.stop() + self.device = None + + def is_connected(self) -> bool: + return self.conn is not None and self.conn_status == ConnectionStatus.CONNECTED def set_power(self, power_type: PowerType, callback: callable): - self._pass_message(ModeMessage(power_type), callback) + message = ModeMessage(power_type) + self.conn.enqueue_message(WrappedMessage(message, handler=callback, ack=True)) def set_target_temperature(self, temp: int, callback: callable): - self._pass_message(TargetTemperatureMessage(temp), callback) - - def _pass_message(self, message: Message, callback: callable): - self.conn.send_message(message, partial(callback, self)) + message = TargetTemperatureMessage(temp) + self.conn.enqueue_message(WrappedMessage(message, handler=callback, ack=True)) diff --git a/src/polaris/protocol.py b/src/polaris/protocol.py index cc4e36a..5d7390f 100644 --- a/src/polaris/protocol.py +++ b/src/polaris/protocol.py @@ -1,39 +1,61 @@ -# SPDX-License-Identifier: BSD-3-Clause +# Polaris PWK 1725CGLD "smart" kettle python library +# -------------------------------------------------- +# Copyright (C) Evgeny Zinoviev, 2022 +# License: BSD-3c from __future__ import annotations + import logging import socket import random import struct import threading -import queue +import time -from enum import Enum -from typing import Union, Optional, Any +from abc import abstractmethod, ABC +from enum import Enum, auto +from typing import Union, Optional, Dict, Tuple, List from ipaddress import IPv4Address, IPv6Address -import cryptography.hazmat.primitives._serialization +import cryptography.hazmat.primitives._serialization as srlz from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey from cryptography.hazmat.primitives import ciphers, padding, hashes from cryptography.hazmat.primitives.ciphers import algorithms, modes - +ReprDict = Dict[str, Union[str, int, float, bool]] _logger = logging.getLogger(__name__) +PING_FREQUENCY = 3 +RESEND_ATTEMPTS = 5 +READ_TIMEOUT = 1 +ERROR_TIMEOUT = 15 +MESSAGE_QUEUE_REMOVE_DELAY = 13 # after what time to delete (and pass False to handlers, if needed) messages with phase=DONE from queue +DISCONNECT_TIMEOUT = 15 + -# drop-in replacement for Java API +def safe_callback_call(f: callable, + *args, + logger: logging.Logger = None, + error_message: str = None): + try: + return f(*args) + except Exception as exc: + logger.error(f'{error_message}, see exception below:') + logger.exception(exc) + return None + + +# drop-in replacement for java.lang.System.arraycopy # TODO: rewrite def arraycopy(src, src_pos, dest, dest_pos, length): for i in range(length): dest[i + dest_pos] = src[i + src_pos] -class FrameType(Enum): - ACK = 0 - CMD = 1 - AUX = 2 - NAK = 3 +# "convert" unsigned byte to signed +def u8_to_s8(b: int) -> int: + return struct.unpack('b', bytes([b]))[0] class PowerType(Enum): @@ -44,23 +66,37 @@ class PowerType(Enum): # update: if I set it to '2', it just resets to '0' +# low-level protocol structures +# ----------------------------- + +class FrameType(Enum): + ACK = 0 + CMD = 1 + AUX = 2 + NAK = 3 + + class FrameHead: - seq: int # u8 + seq: Optional[int] # u8 type: FrameType # u8 - length: int # u16 + length: int # u16. This is the length of FrameItem's payload @staticmethod def from_bytes(buf: bytes) -> FrameHead: seq, ft, length = struct.unpack('<BBH', buf) return FrameHead(seq, FrameType(ft), length) - def __init__(self, seq: int, frame_type: FrameType, length: Optional[int] = None): + def __init__(self, + seq: Optional[int], + frame_type: FrameType, + length: Optional[int] = None): self.seq = seq self.type = frame_type self.length = length or 0 def pack(self) -> bytes: assert self.length != 0, "FrameHead.length has not been set" + assert self.seq is not None, "FrameHead.seq has not been set" return struct.pack('<BBH', self.seq, self.type.value, self.length) @@ -84,30 +120,47 @@ class FrameItem: return bytes(ba) +# high-level wrappers around FrameItem +# ------------------------------------ + +class MessagePhase(Enum): + WAITING = 0 + SENT = 1 + DONE = 2 + + class Message: frame: Optional[FrameItem] + id: int + + _global_id = 0 def __init__(self): self.frame = None + # global internal message id, only useful for debugging purposes + self.id = self.next_id() + + def __repr__(self): + return f'<{self.__class__.__name__} id={self.id} seq={self.frame.head.seq}>' + @staticmethod - def from_encrypted(buf: bytes, - inkey: bytes, - outkey: bytes) -> Message: - # _logger.debug('[from_encrypted] buf='+buf.hex()) - # print(f'buf len={len(buf)}') + def next_id(): + _id = Message._global_id + Message._global_id += 1 + return _id + + @staticmethod + def from_encrypted(buf: bytes, inkey: bytes, outkey: bytes) -> Message: + _logger.debug(f'Message:from_encrypted: buf={buf.hex()}') + assert len(buf) >= 4, 'invalid size' head = FrameHead.from_bytes(buf[:4]) assert len(buf) == head.length + 4, f'invalid buf size ({len(buf)} != {head.length})' - payload = buf[4:] - - # byte b = paramFrameHead.seq; b = head.seq - # TODO check if protocol is 2, otherwise raise an exception - j = b & 0xF k = b >> 4 & 0xF @@ -123,15 +176,9 @@ class Message: decryptor = cipher.decryptor() decrypted_data = decryptor.update(payload) + decryptor.finalize() - # print(f'head.length={head.length} len(decr)={len(decrypted_data)}') - # if len(decrypted_data) > head.length: unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder() decrypted_data = unpadder.update(decrypted_data) - # try: decrypted_data += unpadder.finalize() - # except ValueError as exc: - # _logger.exception(exc) - # pass assert len(decrypted_data) != 0, 'decrypted data is null' assert head.seq == decrypted_data[0], f'decrypted seq mismatch {head.seq} != {decrypted_data[0]}' @@ -145,29 +192,54 @@ class Message: return NakMessage(head.seq) elif head.type == FrameType.AUX: + # TODO implement AUX raise NotImplementedError('FrameType AUX is not yet implemented') elif head.type == FrameType.CMD: - cmd = decrypted_data[0] + type = decrypted_data[1] data = decrypted_data[2:] - return CmdMessage(head.seq, cmd, data) + + cl = UnknownMessage + + subclasses = [cl for cl in CmdIncomingMessage.__subclasses__() if cl is not SimpleBooleanMessage] + subclasses.extend(SimpleBooleanMessage.__subclasses__()) + + for _cl in subclasses: + # `UnknownMessage` is a special class that holds a packed command that we don't recognize. + # It will be used anyway if we don't find a match, so skip it here + if _cl == UnknownMessage: + continue + + if _cl.TYPE == type: + cl = _cl + break + + m = cl.from_packed_data(data, seq=head.seq) + if isinstance(m, UnknownMessage): + m.set_type(type) + return m else: raise NotImplementedError(f'Unexpected frame type: {head.type}') - @property - def data(self) -> bytes: + def pack_data(self) -> bytes: return b'' - def encrypt(self, - outkey: bytes, - inkey: bytes, - token: bytes, - pubkey: bytes): + @property + def seq(self) -> Union[int, None]: + try: + return self.frame.head.seq + except: + return None + + @seq.setter + def seq(self, seq: int): + self.frame.head.seq = seq + def encrypt(self, outkey: bytes, inkey: bytes, token: bytes, pubkey: bytes): assert self.frame is not None - data = self.data + data = self._get_data_to_encrypt() assert data is not None b = self.frame.head.seq @@ -199,7 +271,7 @@ class Message: arraycopy(data, 0, newdata, 1, len(data)) newdata = bytes(newdata) - _logger.debug('payload to be sent: ' + newdata.hex()) + _logger.debug('frame payload to be encrypted: ' + newdata.hex()) padder = padding.PKCS7(algorithms.AES.block_size).padder() ciphertext = bytearray() @@ -208,74 +280,143 @@ class Message: self.frame.setpayload(ciphertext) - def set_seq(self, seq: int): - self.frame.head.seq = seq - - def __repr__(self): - return f'<{self.__class__.__name__} seq={self.frame.head.seq}>' + def _get_data_to_encrypt(self) -> bytes: + return self.pack_data() -class AckMessage(Message): - def __init__(self, seq: int = 0): +class AckMessage(Message, ABC): + def __init__(self, seq: Optional[int] = None): super().__init__() - self.frame = FrameItem(FrameHead(seq, FrameType.ACK, 0)) + self.frame = FrameItem(FrameHead(seq, FrameType.ACK, None)) -class NakMessage(Message): - def __init__(self, seq: int = 0): +class NakMessage(Message, ABC): + def __init__(self, seq: Optional[int] = None): super().__init__() - self.frame = FrameItem(FrameHead(seq, FrameType.NAK, 0)) + self.frame = FrameItem(FrameHead(seq, FrameType.NAK, None)) class CmdMessage(Message): - _type: Optional[int] - _data: bytes + type: Optional[int] + data: bytes - def __init__(self, seq=0, - type: Optional[int] = None, - data: bytes = b''): - super().__init__() - self._data = data - if type is not None: - self.frame = FrameItem(FrameHead(seq, FrameType.CMD)) - self._type = type - else: - self._type = None + TYPE = None - @property - def data(self) -> bytes: + def _get_data_to_encrypt(self) -> bytes: buf = bytearray() - buf.append(self._type) - buf.extend(self._data) + buf.append(self.get_type()) + buf.extend(self.pack_data()) return bytes(buf) + def __init__(self, seq: Optional[int] = None): + super().__init__() + self.frame = FrameItem(FrameHead(seq, FrameType.CMD)) + self.data = b'' + + def _repr_fields(self) -> ReprDict: + return { + 'cmd': self.get_type() + } + def __repr__(self): params = [ __name__+'.'+self.__class__.__name__, - f'seq={self.frame.head.seq}', - # f'type={self.frame.head.type}', - f'cmd={self._type}' + f'id={self.id}', + f'seq={self.seq}' ] - if self._data: - params.append(f'data={self._data.hex()}') + fields = self._repr_fields() + if fields: + for k, v in fields.items(): + params.append(f'{k}={v}') + elif self.data: + params.append(f'data={self.data.hex()}') return '<'+' '.join(params)+'>' + def get_type(self) -> int: + return self.__class__.TYPE + + +class CmdIncomingMessage(CmdMessage): + @staticmethod + @abstractmethod + def from_packed_data(cls, data: bytes, seq: Optional[int] = None): + pass + + @abstractmethod + def _repr_fields(self) -> ReprDict: + pass + + +class CmdOutgoingMessage(CmdMessage): + @abstractmethod + def pack_data(self) -> bytes: + return b'' + + +class ModeMessage(CmdOutgoingMessage, CmdIncomingMessage): + TYPE = 1 + + pt: PowerType + + def __init__(self, power_type: PowerType, seq: Optional[int] = None): + super().__init__(seq) + self.pt = power_type + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> ModeMessage: + assert len(data) == 1, 'data size expected to be 1' + mode, = struct.unpack('B', data) + return ModeMessage(PowerType(mode), seq=seq) + + def pack_data(self) -> bytes: + return self.pt.value.to_bytes(1, byteorder='little') -class ModeMessage(CmdMessage): - def __init__(self, power_type: PowerType): - super().__init__(type=1, - data=(power_type.value).to_bytes(1, byteorder='little')) + def _repr_fields(self) -> ReprDict: + return {'mode': self.pt.name} -class TargetTemperatureMessage(CmdMessage): - def __init__(self, temp: int): - super().__init__(type=2, - data=bytes(bytearray([temp, 0]))) +class TargetTemperatureMessage(CmdOutgoingMessage, CmdIncomingMessage): + temperature: int + TYPE = 2 + def __init__(self, temp: int, seq: Optional[int] = None): + super().__init__(seq) + self.temperature = temp + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> TargetTemperatureMessage: + assert len(data) == 2, 'data size expected to be 2' + nat, frac = struct.unpack('BB', data) + temp = int(nat + (frac / 100)) + return TargetTemperatureMessage(temp, seq=seq) + + def pack_data(self) -> bytes: + return bytes([self.temperature, 0]) + + def _repr_fields(self) -> ReprDict: + return {'temperature': self.temperature} + + +class PingMessage(CmdIncomingMessage, CmdOutgoingMessage): + TYPE = 255 + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> PingMessage: + assert len(data) == 0, 'no data expected' + return PingMessage(seq=seq) + + def pack_data(self) -> bytes: + return b'' + + def _repr_fields(self) -> ReprDict: + return {} + + +# This is the first protocol message. Sent by a client. +# Kettle usually ACKs this, but sometimes i don't get any ACK and the very next message is HandshakeResponseMessage. class HandshakeMessage(CmdMessage): - def __init__(self): - super().__init__(type=0) + TYPE = 0 def encrypt(self, outkey: bytes, @@ -297,20 +438,312 @@ class HandshakeMessage(CmdMessage): self.frame.setpayload(pld) -# TODO -# implement resending UDP messages if no answer has been received in a second -# try at least 5 times, then give up -class Connection(threading.Thread): - seq_no: int +# Kettle either sends this right after the handshake, of first it ACKs the handshake then sends this. +class HandshakeResponseMessage(CmdIncomingMessage): + TYPE = 0 + + protocol: int + fw_major: int + fw_minor: int + mode: int + token: bytes + + def __init__(self, + protocol: int, + fw_major: int, + fw_minor: int, + mode: int, + token: bytes, + seq: Optional[int] = None): + super().__init__(seq) + self.protocol = protocol + self.fw_major = fw_major + self.fw_minor = fw_minor + self.mode = mode + self.token = token + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> HandshakeResponseMessage: + protocol, fw_major, fw_minor, mode = struct.unpack('<HBBB', data[:5]) + return HandshakeResponseMessage(protocol, fw_major, fw_minor, mode, token=data[5:], seq=seq) + + def _repr_fields(self) -> ReprDict: + return { + 'protocol': self.protocol, + 'fw': f'{self.fw_major}.{self.fw_minor}', + 'mode': self.mode, + 'token': self.token.hex() + } + + +# Apparently, some hardware info. +# On the other hand, if you look at com.syncleiot.iottransport.commands.CmdHardware, its mqtt topic is "mcu_firmware". +# My device returns 1.1.1. The thing uses on ESP8266 MCU under the hood (or, more precisely, under a piece of cheap +# plastic), so maybe 1.1.1 is the MCU fw revision. +class DeviceHardwareMessage(CmdIncomingMessage): + TYPE = 143 # -113 + + hw: List[int] + + def __init__(self, hw: List[int], seq: Optional[int] = None): + super().__init__(seq) + self.hw = hw + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> DeviceHardwareMessage: + assert len(data) == 3, 'invalid data size, expected 3' + hw = list(struct.unpack('<BBB', data)) + return DeviceHardwareMessage(hw, seq=seq) + + def _repr_fields(self) -> ReprDict: + return {'device_hardware': '.'.join(map(str, self.hw))} + + +# This message is sent by kettle right after the HandshakeMessageResponse. +# The diagnostic data is supposed to be sent to vendor, which we, obviously, not going to do. +# So just ACK and skip it. +class DeviceDiagnosticMessage(CmdIncomingMessage): + TYPE = 145 # -111 + + diag_data: bytes + + def __init__(self, diag_data: bytes, seq: Optional[int] = None): + super().__init__(seq) + self.diag_data = diag_data + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> DeviceDiagnosticMessage: + return DeviceDiagnosticMessage(diag_data=data, seq=seq) + + def _repr_fields(self) -> ReprDict: + return {'diag_data': self.diag_data.hex()} + + +class SimpleBooleanMessage(ABC, CmdIncomingMessage): + value: bool + + def __init__(self, value: bool, seq: Optional[int] = None): + super().__init__(seq) + self.value = value + + @classmethod + def from_packed_data(cls, data: bytes, seq: Optional[int] = None): + assert len(data) == 1, 'invalid data size, expected 1' + enabled, = struct.unpack('<B', data) + return cls(value=enabled == 1, seq=seq) + + @abstractmethod + def _repr_fields(self) -> ReprDict: + pass + + +class AccessControlMessage(SimpleBooleanMessage): + TYPE = 133 # -123 + + def _repr_fields(self) -> ReprDict: + return {'acl_enabled': self.value} + + +class ErrorMessage(SimpleBooleanMessage): + TYPE = 7 + + def _repr_fields(self) -> ReprDict: + return {'error': self.value} + + +class ChildLockMessage(SimpleBooleanMessage): + TYPE = 30 + + def _repr_fields(self) -> ReprDict: + return {'child_lock': self.value} + + +class VolumeMessage(SimpleBooleanMessage): + TYPE = 9 + + def _repr_fields(self) -> ReprDict: + return {'volume': self.value} + + +class BacklightMessage(SimpleBooleanMessage): + TYPE = 28 + + def _repr_fields(self) -> ReprDict: + return {'backlight': self.value} + + +class CurrentTemperatureMessage(CmdIncomingMessage): + TYPE = 20 + + current_temperature: int + + def __init__(self, temp: int, seq: Optional[int] = None): + super().__init__(seq) + self.current_temperature = temp + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> CurrentTemperatureMessage: + assert len(data) == 2, 'data size expected to be 2' + nat, frac = struct.unpack('BB', data) + temp = int(nat + (frac / 100)) + return CurrentTemperatureMessage(temp, seq=seq) + + def pack_data(self) -> bytes: + return bytes([self.current_temperature, 0]) + + def _repr_fields(self) -> ReprDict: + return {'current_temperature': self.current_temperature} + + +class UnknownMessage(CmdIncomingMessage): + type: Optional[int] + data: bytes + + def __init__(self, data: bytes, **kwargs): + super().__init__(**kwargs) + self.type = None + self.data = data + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> UnknownMessage: + return UnknownMessage(data, seq=seq) + + def set_type(self, type: int): + self.type = type + + def get_type(self) -> int: + return self.type + + def _repr_fields(self) -> ReprDict: + return { + 'type': self.type, + 'data': self.data.hex() + } + + +class WrappedMessage: + _message: Message + _handler: Optional[callable] + _validator: Optional[callable] + _logger: Optional[logging.Logger] + _phase: MessagePhase + _phase_update_time: float + + def __init__(self, + message: Message, + handler: Optional[callable] = None, + validator: Optional[callable] = None, + ack=False): + self._message = message + self._handler = handler + self._validator = validator + self._logger = None + self._phase = MessagePhase.WAITING + self._phase_update_time = 0 + if not validator and ack: + self._validator = lambda m: isinstance(m, AckMessage) + + def setlogger(self, logger: logging.Logger): + self._logger = logger + + def validate(self, message: Message): + if not self._validator: + return True + return self._validator(message) + + def call(self, *args, error_message: str = None) -> None: + if not self._handler: + return + try: + self._handler(*args) + except Exception as exc: + logger = self._logger or logging.getLogger(self.__class__.__name__) + logger.error(f'{error_message}, see exception below:') + logger.exception(exc) + + @property + def phase(self) -> MessagePhase: + return self._phase + + @phase.setter + def phase(self, phase: MessagePhase): + self._phase = phase + self._phase_update_time = 0 if phase == MessagePhase.WAITING else time.time() + + @property + def phase_update_time(self) -> float: + return self._phase_update_time + + @property + def message(self) -> Message: + return self._message + + @property + def id(self) -> int: + return self._message.id + + @property + def seq(self) -> int: + return self._message.seq + + @seq.setter + def seq(self, seq: int): + self._message.seq = seq + + def __repr__(self): + return f'<{__name__}.{self.__class__.__name__} message={self._message.__repr__()}>' + + +# Connection stuff +# Well, strictly speaking, as it's UDP, there's no connection, but who cares. +# --------------------------------------------------------------------------- + +class IncomingMessageListener: + @abstractmethod + def incoming_message(self, message: Message) -> Optional[Message]: + pass + + +class ConnectionStatus(Enum): + NOT_CONNECTED = auto() + CONNECTING = auto() + CONNECTED = auto() + RECONNECTING = auto() + DISCONNECTED = auto() + + +class ConnectionStatusListener: + @abstractmethod + def connection_status_updated(self, status: ConnectionStatus): + pass + + +class UDPConnection(threading.Thread, ConnectionStatusListener): + inseq: int + outseq: int source_port: int device_addr: str device_port: int device_token: bytes + device_pubkey: bytes interrupted: bool - waiting_for_response: dict[int, callable] + response_handlers: Dict[int, WrappedMessage] + outgoing_queue: List[WrappedMessage] pubkey: Optional[bytes] encinkey: Optional[bytes] encoutkey: Optional[bytes] + inc_listeners: List[IncomingMessageListener] + conn_listeners: List[ConnectionStatusListener] + outgoing_time: float + outgoing_time_1st: float + incoming_time: float + status: ConnectionStatus + reconnect_tries: int + + _addr_lock: threading.Lock + _iml_lock: threading.Lock + _csl_lock: threading.Lock + _st_lock: threading.Lock def __init__(self, addr: Union[IPv4Address, IPv6Address], @@ -318,36 +751,105 @@ class Connection(threading.Thread): device_pubkey: bytes, device_token: bytes): super().__init__() - self.logger = logging.getLogger(__name__+'.'+self.__class__.__name__) + self._logger = logging.getLogger(f'{__name__}.{self.__class__.__name__} <{hex(id(self))}>') self.setName(self.__class__.__name__) - # self.daemon = True - self.seq_no = -1 + self.inseq = 0 + self.outseq = 0 self.source_port = random.randint(1024, 65535) self.device_addr = str(addr) self.device_port = port self.device_token = device_token - self.lock = threading.Lock() - self.outgoing_queue = queue.SimpleQueue() - self.waiting_for_response = {} + self.device_pubkey = device_pubkey + self.outgoing_queue = [] + self.response_handlers = {} self.interrupted = False + self.outgoing_time = 0 + self.outgoing_time_1st = 0 + self.incoming_time = 0 + self.inc_listeners = [] + self.conn_listeners = [self] + self.status = ConnectionStatus.NOT_CONNECTED + self.reconnect_tries = 0 + + self._iml_lock = threading.Lock() + self._csl_lock = threading.Lock() + self._addr_lock = threading.Lock() + self._st_lock = threading.Lock() self.pubkey = None self.encinkey = None self.encoutkey = None - self.prepare_keys(device_pubkey) - - def prepare_keys(self, device_pubkey: bytes): + def connection_status_updated(self, status: ConnectionStatus): + # self._logger.info(f'connection_status_updated: status = {status}') + with self._st_lock: + # self._logger.debug(f'connection_status_updated: lock acquired') + self.status = status + if status == ConnectionStatus.RECONNECTING: + self.reconnect_tries += 1 + if status in (ConnectionStatus.CONNECTED, ConnectionStatus.NOT_CONNECTED, ConnectionStatus.DISCONNECTED): + self.reconnect_tries = 0 + + def _cleanup(self): + # erase outgoing queue + for wm in self.outgoing_queue: + wm.call(False, + error_message=f'_cleanup: exception while calling cb(False) on message {wm.message}') + self.outgoing_queue = [] + self.response_handlers = {} + + # reset timestamps + self.incoming_time = 0 + self.outgoing_time = 0 + self.outgoing_time_1st = 0 + + self._logger.info('_cleanup: done') + + def set_address(self, addr: Union[IPv4Address, IPv6Address], port: int): + with self._addr_lock: + if self.device_addr != str(addr) or self.device_port != port: + self.device_addr = str(addr) + self.device_port = port + self._logger.info(f'updated device network address: {self.device_addr}:{self.device_port}') + + def set_device_pubkey(self, pubkey: bytes): + if self.device_pubkey.hex() != pubkey.hex(): + self._logger.info(f'device pubkey has changed (old={self.device_pubkey.hex()}, new={pubkey.hex()})') + self.device_pubkey = pubkey + self._notify_cs(ConnectionStatus.RECONNECTING) + + def get_address(self) -> Tuple[str, int]: + with self._addr_lock: + return self.device_addr, self.device_port + + def add_incoming_message_listener(self, listener: IncomingMessageListener): + with self._iml_lock: + if listener not in self.inc_listeners: + self.inc_listeners.append(listener) + + def add_connection_status_listener(self, listener: ConnectionStatusListener): + with self._csl_lock: + if listener not in self.conn_listeners: + self.conn_listeners.append(listener) + + def _notify_cs(self, status: ConnectionStatus): + # self._logger.debug(f'_notify_cs: status={status}') + with self._csl_lock: + for obj in self.conn_listeners: + # self._logger.debug(f'_notify_cs: notifying {obj}') + obj.connection_status_updated(status) + + def _prepare_keys(self): # generate key pair privkey = X25519PrivateKey.generate() - self.pubkey = bytes(reversed(privkey.public_key().public_bytes(encoding=cryptography.hazmat.primitives._serialization.Encoding.Raw, - format=cryptography.hazmat.primitives._serialization.PublicFormat.Raw))) + self.pubkey = bytes(reversed(privkey.public_key().public_bytes(encoding=srlz.Encoding.Raw, + format=srlz.PublicFormat.Raw))) # generate shared key device_pubkey = X25519PublicKey.from_public_bytes( - bytes(reversed(device_pubkey)) + bytes(reversed(self.device_pubkey)) ) shared_key = bytes(reversed( privkey.exchange(device_pubkey) @@ -362,51 +864,286 @@ class Connection(threading.Thread): self.encinkey = shared_sha256[:16] self.encoutkey = shared_sha256[16:] + self._logger.info('encryption keys have been created') + + def _handshake_callback(self, r: MessageResponse): + # if got error for our HandshakeMessage, reset everything and try again + if r is False: + # self._logger.debug('_handshake_callback: set status=RECONNETING') + self._notify_cs(ConnectionStatus.RECONNECTING) + else: + # self._logger.debug('_handshake_callback: set status=CONNECTED') + self._notify_cs(ConnectionStatus.CONNECTED) + def run(self): + self._logger.info('starting server loop') + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.bind(('0.0.0.0', self.source_port)) - sock.settimeout(1) + sock.settimeout(READ_TIMEOUT) while not self.interrupted: - if not self.outgoing_queue.empty(): - message = self.outgoing_queue.get() - message.encrypt(outkey=self.encoutkey, - inkey=self.encinkey, - token=self.device_token, - pubkey=self.pubkey) - buf = message.frame.pack() - self.logger.debug('send: '+buf.hex()) - self.logger.debug(f'sendto: {self.device_addr}:{self.device_port}') - sock.sendto(buf, (self.device_addr, self.device_port)) + with self._st_lock: + status = self.status + + if status in (ConnectionStatus.DISCONNECTED, ConnectionStatus.RECONNECTING): + self._cleanup() + if status == ConnectionStatus.DISCONNECTED: + break + + # no activity for some time means connection is broken + fail = False + fail_path = 0 + if self.incoming_time > 0 and time.time() - self.incoming_time >= DISCONNECT_TIMEOUT: + fail = True + fail_path = 1 + elif self.outgoing_time_1st > 0 and self.incoming_time == 0 and time.time() - self.outgoing_time_1st >= DISCONNECT_TIMEOUT: + fail = True + fail_path = 2 + + if fail: + self._logger.debug(f'run: setting status=RECONNECTING because of long inactivity, fail_path={fail_path}') + self._notify_cs(ConnectionStatus.RECONNECTING) + + # establishing a connection + if status in (ConnectionStatus.RECONNECTING, ConnectionStatus.NOT_CONNECTED): + if status == ConnectionStatus.RECONNECTING and self.reconnect_tries >= 3: + self._notify_cs(ConnectionStatus.DISCONNECTED) + continue + + self._reset_outseq() + self._prepare_keys() + + # shake the imaginary kettle's hand + wrapped = WrappedMessage(HandshakeMessage(), + handler=self._handshake_callback, + validator=lambda m: isinstance(m, (AckMessage, HandshakeResponseMessage))) + self.enqueue_message(wrapped, prepend=True) + self._notify_cs(ConnectionStatus.CONNECTING) + + # pick next (wrapped) message to send + wm = self._get_next_message() # wm means "wrapped message" + if wm: + if not isinstance(wm.message, (AckMessage, NakMessage)): + old_seq = wm.seq + wm.seq = self.outseq + self._set_response_handler(wm, old_seq=old_seq) + elif wm.seq is None: + # ack/nak is a response to some incoming message (and it must have the same seqno that incoming + # message had) + raise RuntimeError(f'run: seq must be set for {wm.__class__.__name__}') + + self._logger.debug(f'run: sending message: {wm.message}') + wm.message.encrypt(outkey=self.encoutkey, inkey=self.encinkey, + token=self.device_token, pubkey=self.pubkey) + buf = wm.message.frame.pack() + one_shot = isinstance(wm.message, (AckMessage, NakMessage, PingMessage)) + # self._logger.debug(f'run: raw data to be sent: {buf.hex()}') + + # sending the first time + if wm.phase == MessagePhase.WAITING: + sock.sendto(buf, self.get_address()) + # resending + elif wm.phase == MessagePhase.SENT: + left = RESEND_ATTEMPTS + while left > 0: + sock.sendto(buf, self.get_address()) + left -= 1 + if left > 0: + time.sleep(0.05) + + if one_shot or wm.phase == MessagePhase.SENT: + wm.phase = MessagePhase.DONE + else: + wm.phase = MessagePhase.SENT + + now = time.time() + self.outgoing_time = now + if not self.outgoing_time_1st: + self.outgoing_time_1st = now + + # receiving data try: data = sock.recv(4096) - self.handle_incoming(data) - except TimeoutError: + self._handle_incoming(data) + except (TimeoutError, socket.timeout): pass - def handle_incoming(self, buf: bytes): - self.logger.debug('handle_incoming: '+buf.hex()) - message = Message.from_encrypted(buf, inkey=self.encinkey, outkey=self.encoutkey) - if message.frame.head.seq in self.waiting_for_response: - self.logger.info(f'received awaited message: {message}') + self._logger.info('bye...') + + def _get_next_message(self) -> Optional[WrappedMessage]: + message = None + lpfx = '_get_next_message:' + remove_list = [] + for wm in self.outgoing_queue: + if wm.phase == MessagePhase.DONE: + if isinstance(wm.message, (AckMessage, NakMessage)) or time.time() - wm.phase_update_time >= MESSAGE_QUEUE_REMOVE_DELAY: + remove_list.append(wm) + continue + message = wm + break + + for wm in remove_list: + self._logger.debug(f'{lpfx} rm path: removing id={wm.id} seq={wm.seq}') + + # clear message handler + if wm.seq in self.response_handlers: + self.response_handlers[wm.seq].call( + False, error_message=f'{lpfx} rm path: error while calling callback for seq={wm.seq}') + del self.response_handlers[wm.seq] + + # remove from queue try: - f = self.waiting_for_response[message.frame.head.seq] - f(message) - except Exception as exc: - self.logger.exception(exc) - finally: - del self.waiting_for_response[message.frame.head.seq] + self.outgoing_queue.remove(wm) + except ValueError as exc: + self._logger.error(f'{lpfx} rm path: removing from outgoing_queue raised an exception: {str(exc)}') + + # ping pong + if self.outgoing_time_1st != 0 and self.status == ConnectionStatus.CONNECTED: + now = time.time() + out_delta = now - self.outgoing_time + in_delta = now - self.incoming_time + if not message and max(out_delta, in_delta) > PING_FREQUENCY: + self._logger.debug(f'{lpfx} no activity: in for {in_delta:.2f}s, out for {out_delta:.2f}s, time to ping the damn thing') + message = WrappedMessage(PingMessage(), ack=True) + + return message + + def _handle_incoming(self, buf: bytes): + self.incoming_time = time.time() + + incoming_message = Message.from_encrypted(buf, inkey=self.encinkey, outkey=self.encoutkey) + seq = incoming_message.seq + + lpfx = f'handle_incoming({incoming_message.id}):' + self._logger.debug(f'{lpfx} received: {incoming_message}') + + if isinstance(incoming_message, (AckMessage, NakMessage)): + seq_max = self.outseq + seq_name = 'outseq' + else: + seq_max = self.inseq + seq_name = 'inseq' + self.inseq = seq + + if seq < seq_max < 0xfd: + self._logger.warning(f'{lpfx} dropping: seq={seq}, {seq_name}={seq_name}') + return + + if seq not in self.response_handlers: + self._handle_incoming_cmd(incoming_message) + return + + callback_value = None # None means don't call a callback + handler = self.response_handlers[seq] + + if handler.validate(incoming_message): + self._logger.info(f'{lpfx} response OK') + handler.phase = MessagePhase.DONE + callback_value = incoming_message + self._incr_outseq() else: - self.logger.info(f'received message (not awaited): {message}') - - def send_message(self, message: Message, callback: callable): - seq = self.next_seqno() - message.set_seq(seq) - self.outgoing_queue.put(message) - self.waiting_for_response[seq] = callback - - def next_seqno(self) -> int: - with self.lock: - self.seq_no += 1 - self.logger.debug(f'next_seqno: set to {self.seq_no}') - return self.seq_no + self._logger.info(f'{lpfx} response is INVALID') + + # It seems that we've received an incoming CmdMessage or PingMessage with the same seqno that our outgoing + # message had. Bad, but what can I say, this is quick-and-dirty made UDP based protocol and this sort of + # shit just happens. + + # (To be fair, maybe my implementation is not perfect either. But hey, what did you expect from a + # reverse-engineered re-implementation of custom UDP-based protocol that some noname vendor uses for their + # cheap IoT devices? I think _that_ is _the_ definition of shit. At least my implementation is FOSS, which + # is more than you'll ever be able to say about them.) + + # All this crapload of code below might not be needed at all, 'cause the protocol uses separate frame seq + # numbers for IN and OUT frames and this situation is not highly likely, as Theresa May could argue. + # After a handshake, a kettle sends us 10 or so CmdMessages, and then either we continuously ping it every + # 3 seconds, or kettle pings us. This in any case widens the gap between inseq and outseq. + + # But! the seqno is only 1 byte in size and once it reaches 0xff, it circles back to zero. And that (plus, + # perhaps, some bad luck) gives a chance for a collision. + + if handler.phase == MessagePhase.DONE or isinstance(handler.message, HandshakeMessage): + # no more attempts left, returning error back to user + # as to handshake, it cannot fail. + callback_value = False + + # else: + # # try resending the message + # handler.phase_reset() + # max_seq = self.outseq + # wait_remap = {} + # for m in self.outgoing_queue: + # if m.seq in self.waiting_for_response: + # wait_remap[m.seq] = (m.seq+1) % 256 + # m.set_seq((m.seq+1) % 256) + # if m.seq > max_seq: + # max_seq = m.seq + # if max_seq > self.outseq: + # self.outseq = max_seq % 256 + # if wait_remap: + # waiting_new = {} + # for old_seq, new_seq in wait_remap.items(): + # waiting_new[new_seq] = self.waiting_for_response[old_seq] + # self.waiting_for_response = waiting_new + + if isinstance(incoming_message, (PingMessage, CmdIncomingMessage)): + # handle incoming message as usual, as we need to ack/nak it anyway + self._handle_incoming_cmd(incoming_message) + + if callback_value is not None: + handler.call(callback_value, + error_message=f'{lpfx} error while calling callback for msg id={handler.message.id} seq={seq}') + del self.response_handlers[seq] + + def _handle_incoming_cmd(self, incoming_message: Message): + if isinstance(incoming_message, (AckMessage, NakMessage)): + self._logger.debug(f'_handle_incoming_cmd({incoming_message.id}, seq={incoming_message.seq}): it\'s {incoming_message.__class__.__name__}, ignoring') + return + + replied = False + with self._iml_lock: + for f in self.inc_listeners: + retval = safe_callback_call(f.incoming_message, incoming_message, + logger=self._logger, + error_message=f'_handle_incoming_cmd({incoming_message.id}, seq={incoming_message.seq}): error while calling message listener') + if isinstance(retval, Message): + if isinstance(retval, (AckMessage, NakMessage)): + retval.seq = incoming_message.seq + self.enqueue_message(WrappedMessage(retval), prepend=True) + replied = True + break + else: + raise RuntimeError('are you sure your response is correct? only ack/nak are allowed') + + if not replied: + self.enqueue_message(WrappedMessage(AckMessage(incoming_message.seq)), prepend=True) + + def enqueue_message(self, wrapped: WrappedMessage, prepend=False): + self._logger.debug(f'enqueue_message: {wrapped.message}') + if not prepend: + self.outgoing_queue.append(wrapped) + else: + self.outgoing_queue.insert(0, wrapped) + + def _set_response_handler(self, wm: WrappedMessage, old_seq=None): + if old_seq in self.response_handlers: + del self.response_handlers[old_seq] + + seq = wm.seq + assert seq is not None, 'seq is not set' + + if seq in self.response_handlers: + self._logger.warning(f'_set_response_handler(seq={seq}): handler is already set, cancelling it') + self.response_handlers[seq].call(False, + error_message=f'_set_response_handler({seq}): error while calling old callback') + self.response_handlers[seq] = wm + + def _incr_outseq(self) -> None: + self.outseq = (self.outseq + 1) % 256 + + def _reset_outseq(self): + self.outseq = 0 + self._logger.debug(f'_reset_outseq: set 0') + + +MessageResponse = Union[Message, bool] |