From ed590fdd74a466832ee5bb788c3ea69da2f53a3c Mon Sep 17 00:00:00 2001 From: Evgeny Zinoviev Date: Sun, 25 Dec 2022 02:13:45 +0300 Subject: polaris -> syncleo --- src/polaris/__init__.py | 12 - src/polaris/kettle.py | 243 --------- src/polaris/protocol.py | 1169 -------------------------------------------- src/polaris_kettle_bot.py | 4 +- src/polaris_kettle_util.py | 2 +- src/syncleo/__init__.py | 12 + src/syncleo/kettle.py | 243 +++++++++ src/syncleo/protocol.py | 1169 ++++++++++++++++++++++++++++++++++++++++++++ test/test_polaris_stuff.py | 8 +- 9 files changed, 1431 insertions(+), 1431 deletions(-) delete mode 100644 src/polaris/__init__.py delete mode 100644 src/polaris/kettle.py delete mode 100644 src/polaris/protocol.py create mode 100644 src/syncleo/__init__.py create mode 100644 src/syncleo/kettle.py create mode 100644 src/syncleo/protocol.py diff --git a/src/polaris/__init__.py b/src/polaris/__init__.py deleted file mode 100644 index f1a7c1d..0000000 --- a/src/polaris/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Polaris PWK 1725CGLD "smart" kettle python library -# -------------------------------------------------- -# Copyright (C) Evgeny Zinoviev, 2022 -# License: BSD-3c - -from .kettle import Kettle, DeviceListener -from .protocol import ( - PowerType, - IncomingMessageListener, - ConnectionStatusListener, - ConnectionStatus -) \ No newline at end of file diff --git a/src/polaris/kettle.py b/src/polaris/kettle.py deleted file mode 100644 index d6e0dd6..0000000 --- a/src/polaris/kettle.py +++ /dev/null @@ -1,243 +0,0 @@ -# Polaris PWK 1725CGLD smart kettle python library -# ------------------------------------------------ -# Copyright (C) Evgeny Zinoviev, 2022 -# License: BSD-3c - -from __future__ import annotations - -import threading -import logging -import zeroconf - -from abc import abstractmethod -from ipaddress import ip_address, IPv4Address, IPv6Address -from typing import Optional, List, Union - -from .protocol import ( - UDPConnection, - ModeMessage, - TargetTemperatureMessage, - PowerType, - ConnectionStatus, - ConnectionStatusListener, - WrappedMessage -) - - -class DeviceDiscover(threading.Thread, zeroconf.ServiceListener): - si: Optional[zeroconf.ServiceInfo] - _mac: str - _sb: Optional[zeroconf.ServiceBrowser] - _zc: Optional[zeroconf.Zeroconf] - _listeners: List[DeviceListener] - _valid_addresses: List[Union[IPv4Address, IPv6Address]] - _only_ipv4: bool - - def __init__(self, mac: str, - listener: Optional[DeviceListener] = None, - only_ipv4=True): - super().__init__() - self.si = None - self._mac = mac - self._zc = None - self._sb = None - self._only_ipv4 = only_ipv4 - self._valid_addresses = [] - self._listeners = [] - if isinstance(listener, DeviceListener): - self._listeners.append(listener) - self._logger = logging.getLogger(f'{__name__}.{self.__class__.__name__}') - - def add_listener(self, listener: DeviceListener): - if listener not in self._listeners: - self._listeners.append(listener) - else: - self._logger.warning(f'add_listener: listener {listener} already in the listeners list') - - def set_info(self, info: zeroconf.ServiceInfo): - valid_addresses = self._get_valid_addresses(info) - if not valid_addresses: - raise ValueError('no valid addresses') - self._valid_addresses = valid_addresses - self.si = info - for f in self._listeners: - try: - f.device_updated() - except Exception as exc: - self._logger.error(f'set_info: error while calling device_updated on {f}') - self._logger.exception(exc) - - def add_service(self, zc: zeroconf.Zeroconf, type_: str, name: str) -> None: - self._add_update_service('add_service', zc, type_, name) - - def update_service(self, zc: zeroconf.Zeroconf, type_: str, name: str) -> None: - self._add_update_service('update_service', zc, type_, name) - - def _add_update_service(self, method: str, zc: zeroconf.Zeroconf, type_: str, name: str) -> None: - info = zc.get_service_info(type_, name) - if name.startswith(f'{self._mac}.'): - self._logger.info(f'{method}: type={type_} name={name}') - try: - self.set_info(info) - except ValueError as exc: - self._logger.error(f'{method}: rejected: {str(exc)}') - else: - self._logger.debug(f'{method}: mac not matched: {info}') - - def remove_service(self, zc: zeroconf.Zeroconf, type_: str, name: str) -> None: - if name.startswith(f'{self._mac}.'): - self._logger.info(f'remove_service: type={type_} name={name}') - # TODO what to do here?! - - def run(self): - self._logger.debug('starting zeroconf service browser') - ip_version = zeroconf.IPVersion.V4Only if self._only_ipv4 else zeroconf.IPVersion.All - self._zc = zeroconf.Zeroconf(ip_version=ip_version) - self._sb = zeroconf.ServiceBrowser(self._zc, "_syncleo._udp.local.", self) - self._sb.join() - - def stop(self): - if self._sb: - try: - self._sb.cancel() - except RuntimeError: - pass - self._sb = None - self._zc.close() - self._zc = None - - def _get_valid_addresses(self, si: zeroconf.ServiceInfo) -> List[Union[IPv4Address, IPv6Address]]: - valid = [] - for addr in map(ip_address, si.addresses): - if self._only_ipv4 and not isinstance(addr, IPv4Address): - continue - if isinstance(addr, IPv4Address) and str(addr).startswith('169.254.'): - continue - valid.append(addr) - return valid - - @property - def pubkey(self) -> bytes: - return bytes.fromhex(self.si.properties[b'public'].decode()) - - @property - def curve(self) -> int: - return int(self.si.properties[b'curve'].decode()) - - @property - def addr(self) -> Union[IPv4Address, IPv6Address]: - return self._valid_addresses[0] - - @property - def port(self) -> int: - return int(self.si.port) - - @property - def protocol(self) -> int: - return int(self.si.properties[b'protocol'].decode()) - - -class DeviceListener: - @abstractmethod - def device_updated(self): - pass - - -class Kettle(DeviceListener, ConnectionStatusListener): - mac: str - device: Optional[DeviceDiscover] - device_token: str - conn: Optional[UDPConnection] - conn_status: Optional[ConnectionStatus] - _read_timeout: Optional[int] - _logger: logging.Logger - _find_evt: threading.Event - - def __init__(self, mac: str, device_token: str, read_timeout: Optional[int] = None): - super().__init__() - self.mac = mac - self.device = None - self.device_token = device_token - self.conn = None - self.conn_status = None - self._read_timeout = read_timeout - self._find_evt = threading.Event() - self._logger = logging.getLogger(f'{__name__}.{self.__class__.__name__}') - - def device_updated(self): - self._find_evt.set() - self._logger.info(f'device updated, service info: {self.device.si}') - - def connection_status_updated(self, status: ConnectionStatus): - self.conn_status = status - - def discover(self, wait=True, timeout=None, listener=None) -> Optional[zeroconf.ServiceInfo]: - do_start = False - if not self.device: - self.device = DeviceDiscover(self.mac, listener=self, only_ipv4=True) - do_start = True - self._logger.debug('discover: started device discovery') - else: - self._logger.warning('discover: already started') - - if listener is not None: - self.device.add_listener(listener) - - if do_start: - self.device.start() - - if wait: - self._find_evt.clear() - try: - self._find_evt.wait(timeout=timeout) - except KeyboardInterrupt: - self.device.stop() - return None - return self.device.si - - def start_server_if_needed(self, - incoming_message_listener=None, - connection_status_listener=None): - if self.conn: - self._logger.warning('start_server_if_needed: server is already started!') - self.conn.set_address(self.device.addr, self.device.port) - self.conn.set_device_pubkey(self.device.pubkey) - return - - assert self.device.curve == 29, f'curve type {self.device.curve} is not implemented' - assert self.device.protocol == 2, f'protocol {self.device.protocol} is not supported' - - kw = {} - if self._read_timeout is not None: - kw['read_timeout'] = self._read_timeout - self.conn = UDPConnection(addr=self.device.addr, - port=self.device.port, - device_pubkey=self.device.pubkey, - device_token=bytes.fromhex(self.device_token), **kw) - if incoming_message_listener: - self.conn.add_incoming_message_listener(incoming_message_listener) - - self.conn.add_connection_status_listener(self) - if connection_status_listener: - self.conn.add_connection_status_listener(connection_status_listener) - - self.conn.start() - - def stop_all(self): - # when we stop server, we should also stop device discovering service - if self.conn: - self.conn.interrupted = True - self.conn = None - self.device.stop() - self.device = None - - def is_connected(self) -> bool: - return self.conn is not None and self.conn_status == ConnectionStatus.CONNECTED - - def set_power(self, power_type: PowerType, callback: callable): - message = ModeMessage(power_type) - self.conn.enqueue_message(WrappedMessage(message, handler=callback, ack=True)) - - def set_target_temperature(self, temp: int, callback: callable): - message = TargetTemperatureMessage(temp) - self.conn.enqueue_message(WrappedMessage(message, handler=callback, ack=True)) diff --git a/src/polaris/protocol.py b/src/polaris/protocol.py deleted file mode 100644 index 36a1a8f..0000000 --- a/src/polaris/protocol.py +++ /dev/null @@ -1,1169 +0,0 @@ -# Polaris PWK 1725CGLD "smart" kettle python library -# -------------------------------------------------- -# Copyright (C) Evgeny Zinoviev, 2022 -# License: BSD-3c - -from __future__ import annotations - -import logging -import socket -import random -import struct -import threading -import time - -from abc import abstractmethod, ABC -from enum import Enum, auto -from typing import Union, Optional, Dict, Tuple, List -from ipaddress import IPv4Address, IPv6Address - -import cryptography.hazmat.primitives._serialization as srlz - -from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey -from cryptography.hazmat.primitives import ciphers, padding, hashes -from cryptography.hazmat.primitives.ciphers import algorithms, modes - -ReprDict = Dict[str, Union[str, int, float, bool]] -_logger = logging.getLogger(__name__) - -PING_FREQUENCY = 3 -RESEND_ATTEMPTS = 5 -ERROR_TIMEOUT = 15 -MESSAGE_QUEUE_REMOVE_DELAY = 13 # after what time to delete (and pass False to handlers, if needed) messages with phase=DONE from queue -DISCONNECT_TIMEOUT = 15 - - -def safe_callback_call(f: callable, - *args, - logger: logging.Logger = None, - error_message: str = None): - try: - return f(*args) - except Exception as exc: - logger.error(f'{error_message}, see exception below:') - logger.exception(exc) - return None - - -# drop-in replacement for java.lang.System.arraycopy -# TODO: rewrite -def arraycopy(src, src_pos, dest, dest_pos, length): - for i in range(length): - dest[i + dest_pos] = src[i + src_pos] - - -# "convert" unsigned byte to signed -def u8_to_s8(b: int) -> int: - return struct.unpack('b', bytes([b]))[0] - - -class PowerType(Enum): - OFF = 0 # turn off - ON = 1 # turn on, set target temperature to 100 - CUSTOM = 3 # turn on, allows custom target temperature - # MYSTERY_MODE = 2 # don't know what 2 means, needs testing - # update: if I set it to '2', it just resets to '0' - - -# low-level protocol structures -# ----------------------------- - -class FrameType(Enum): - ACK = 0 - CMD = 1 - AUX = 2 - NAK = 3 - - -class FrameHead: - seq: Optional[int] # u8 - type: FrameType # u8 - length: int # u16. This is the length of FrameItem's payload - - @staticmethod - def from_bytes(buf: bytes) -> FrameHead: - seq, ft, length = struct.unpack(' bytes: - assert self.length != 0, "FrameHead.length has not been set" - assert self.seq is not None, "FrameHead.seq has not been set" - return struct.pack(' bytes: - ba = bytearray(self.head.pack()) - ba.extend(self.payload) - return bytes(ba) - - -# high-level wrappers around FrameItem -# ------------------------------------ - -class MessagePhase(Enum): - WAITING = 0 - SENT = 1 - DONE = 2 - - -class Message: - frame: Optional[FrameItem] - id: int - - _global_id = 0 - - def __init__(self): - self.frame = None - - # global internal message id, only useful for debugging purposes - self.id = self.next_id() - - def __repr__(self): - return f'<{self.__class__.__name__} id={self.id} seq={self.frame.head.seq}>' - - @staticmethod - def next_id(): - _id = Message._global_id - Message._global_id = (Message._global_id + 1) % 100000 - return _id - - @staticmethod - def from_encrypted(buf: bytes, inkey: bytes, outkey: bytes) -> Message: - _logger.debug(f'Message:from_encrypted: buf={buf.hex()}') - - assert len(buf) >= 4, 'invalid size' - head = FrameHead.from_bytes(buf[:4]) - - assert len(buf) == head.length + 4, f'invalid buf size ({len(buf)} != {head.length})' - payload = buf[4:] - b = head.seq - - j = b & 0xF - k = b >> 4 & 0xF - - key = bytearray(len(inkey)) - arraycopy(inkey, j, key, 0, len(inkey) - j) - arraycopy(inkey, 0, key, len(inkey) - j, j) - - iv = bytearray(len(outkey)) - arraycopy(outkey, k, iv, 0, len(outkey) - k) - arraycopy(outkey, 0, iv, len(outkey) - k, k) - - cipher = ciphers.Cipher(algorithms.AES(key), modes.CBC(iv)) - decryptor = cipher.decryptor() - decrypted_data = decryptor.update(payload) + decryptor.finalize() - - unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder() - decrypted_data = unpadder.update(decrypted_data) - decrypted_data += unpadder.finalize() - - assert len(decrypted_data) != 0, 'decrypted data is null' - assert head.seq == decrypted_data[0], f'decrypted seq mismatch {head.seq} != {decrypted_data[0]}' - - # _logger.debug('Message.from_encrypted: plaintext: '+decrypted_data.hex()) - - if head.type == FrameType.ACK: - return AckMessage(head.seq) - - elif head.type == FrameType.NAK: - return NakMessage(head.seq) - - elif head.type == FrameType.AUX: - # TODO implement AUX - raise NotImplementedError('FrameType AUX is not yet implemented') - - elif head.type == FrameType.CMD: - type = decrypted_data[1] - data = decrypted_data[2:] - - cl = UnknownMessage - - subclasses = [cl for cl in CmdIncomingMessage.__subclasses__() if cl is not SimpleBooleanMessage] - subclasses.extend(SimpleBooleanMessage.__subclasses__()) - - for _cl in subclasses: - # `UnknownMessage` is a special class that holds a packed command that we don't recognize. - # It will be used anyway if we don't find a match, so skip it here - if _cl == UnknownMessage: - continue - - if _cl.TYPE == type: - cl = _cl - break - - m = cl.from_packed_data(data, seq=head.seq) - if isinstance(m, UnknownMessage): - m.set_type(type) - return m - - else: - raise NotImplementedError(f'Unexpected frame type: {head.type}') - - def pack_data(self) -> bytes: - return b'' - - @property - def seq(self) -> Union[int, None]: - try: - return self.frame.head.seq - except: - return None - - @seq.setter - def seq(self, seq: int): - self.frame.head.seq = seq - - def encrypt(self, outkey: bytes, inkey: bytes, token: bytes, pubkey: bytes): - assert self.frame is not None - - data = self._get_data_to_encrypt() - assert data is not None - - b = self.frame.head.seq - i = b & 0xf - j = b >> 4 & 0xf - - outkey = bytearray(outkey) - - l = len(outkey) - key = bytearray(l) - - arraycopy(outkey, i, key, 0, l-i) - arraycopy(outkey, 0, key, l-i, i) - - inkey = bytearray(inkey) - - l = len(inkey) - iv = bytearray(l) - - arraycopy(inkey, j, iv, 0, l-j) - arraycopy(inkey, 0, iv, l-j, j) - - cipher = ciphers.Cipher(algorithms.AES(key), modes.CBC(iv)) - encryptor = cipher.encryptor() - - newdata = bytearray(len(data)+1) - newdata[0] = b - - arraycopy(data, 0, newdata, 1, len(data)) - - newdata = bytes(newdata) - _logger.debug('frame payload to be encrypted: ' + newdata.hex()) - - padder = padding.PKCS7(algorithms.AES.block_size).padder() - ciphertext = bytearray() - ciphertext.extend(encryptor.update(padder.update(newdata) + padder.finalize())) - ciphertext.extend(encryptor.finalize()) - - self.frame.setpayload(ciphertext) - - def _get_data_to_encrypt(self) -> bytes: - return self.pack_data() - - -class AckMessage(Message, ABC): - def __init__(self, seq: Optional[int] = None): - super().__init__() - self.frame = FrameItem(FrameHead(seq, FrameType.ACK, None)) - - -class NakMessage(Message, ABC): - def __init__(self, seq: Optional[int] = None): - super().__init__() - self.frame = FrameItem(FrameHead(seq, FrameType.NAK, None)) - - -class CmdMessage(Message): - type: Optional[int] - data: bytes - - TYPE = None - - def _get_data_to_encrypt(self) -> bytes: - buf = bytearray() - buf.append(self.get_type()) - buf.extend(self.pack_data()) - return bytes(buf) - - def __init__(self, seq: Optional[int] = None): - super().__init__() - self.frame = FrameItem(FrameHead(seq, FrameType.CMD)) - self.data = b'' - - def _repr_fields(self) -> ReprDict: - return { - 'cmd': self.get_type() - } - - def __repr__(self): - params = [ - __name__+'.'+self.__class__.__name__, - f'id={self.id}', - f'seq={self.seq}' - ] - fields = self._repr_fields() - if fields: - for k, v in fields.items(): - params.append(f'{k}={v}') - elif self.data: - params.append(f'data={self.data.hex()}') - return '<'+' '.join(params)+'>' - - def get_type(self) -> int: - return self.__class__.TYPE - - -class CmdIncomingMessage(CmdMessage): - @staticmethod - @abstractmethod - def from_packed_data(cls, data: bytes, seq: Optional[int] = None): - pass - - @abstractmethod - def _repr_fields(self) -> ReprDict: - pass - - -class CmdOutgoingMessage(CmdMessage): - @abstractmethod - def pack_data(self) -> bytes: - return b'' - - -class ModeMessage(CmdOutgoingMessage, CmdIncomingMessage): - TYPE = 1 - - pt: PowerType - - def __init__(self, power_type: PowerType, seq: Optional[int] = None): - super().__init__(seq) - self.pt = power_type - - @classmethod - def from_packed_data(cls, data: bytes, seq=0) -> ModeMessage: - assert len(data) == 1, 'data size expected to be 1' - mode, = struct.unpack('B', data) - return ModeMessage(PowerType(mode), seq=seq) - - def pack_data(self) -> bytes: - return self.pt.value.to_bytes(1, byteorder='little') - - def _repr_fields(self) -> ReprDict: - return {'mode': self.pt.name} - - -class TargetTemperatureMessage(CmdOutgoingMessage, CmdIncomingMessage): - temperature: int - - TYPE = 2 - - def __init__(self, temp: int, seq: Optional[int] = None): - super().__init__(seq) - self.temperature = temp - - @classmethod - def from_packed_data(cls, data: bytes, seq=0) -> TargetTemperatureMessage: - assert len(data) == 2, 'data size expected to be 2' - nat, frac = struct.unpack('BB', data) - temp = int(nat + (frac / 100)) - return TargetTemperatureMessage(temp, seq=seq) - - def pack_data(self) -> bytes: - return bytes([self.temperature, 0]) - - def _repr_fields(self) -> ReprDict: - return {'temperature': self.temperature} - - -class PingMessage(CmdIncomingMessage, CmdOutgoingMessage): - TYPE = 255 - - @classmethod - def from_packed_data(cls, data: bytes, seq=0) -> PingMessage: - assert len(data) == 0, 'no data expected' - return PingMessage(seq=seq) - - def pack_data(self) -> bytes: - return b'' - - def _repr_fields(self) -> ReprDict: - return {} - - -# This is the first protocol message. Sent by a client. -# Kettle usually ACKs this, but sometimes i don't get any ACK and the very next message is HandshakeResponseMessage. -class HandshakeMessage(CmdMessage): - TYPE = 0 - - def encrypt(self, - outkey: bytes, - inkey: bytes, - token: bytes, - pubkey: bytes): - cipher = ciphers.Cipher(algorithms.AES(outkey), modes.CBC(inkey)) - encryptor = cipher.encryptor() - - ciphertext = bytearray() - ciphertext.extend(encryptor.update(token)) - ciphertext.extend(encryptor.finalize()) - - pld = bytearray() - pld.append(0) - pld.extend(pubkey) - pld.extend(ciphertext) - - self.frame.setpayload(pld) - - -# Kettle either sends this right after the handshake, of first it ACKs the handshake then sends this. -class HandshakeResponseMessage(CmdIncomingMessage): - TYPE = 0 - - protocol: int - fw_major: int - fw_minor: int - mode: int - token: bytes - - def __init__(self, - protocol: int, - fw_major: int, - fw_minor: int, - mode: int, - token: bytes, - seq: Optional[int] = None): - super().__init__(seq) - self.protocol = protocol - self.fw_major = fw_major - self.fw_minor = fw_minor - self.mode = mode - self.token = token - - @classmethod - def from_packed_data(cls, data: bytes, seq=0) -> HandshakeResponseMessage: - protocol, fw_major, fw_minor, mode = struct.unpack(' ReprDict: - return { - 'protocol': self.protocol, - 'fw': f'{self.fw_major}.{self.fw_minor}', - 'mode': self.mode, - 'token': self.token.hex() - } - - -# Apparently, some hardware info. -# On the other hand, if you look at com.syncleiot.iottransport.commands.CmdHardware, its mqtt topic says "mcu_firmware". -# My device returns 1.1.1. The kettle uses on ESP8266 ESP-12F MCU under the hood (or, more precisely, under a piece of -# cheap plastic), so maybe 1.1.1 is some MCU ROM version. -class DeviceHardwareMessage(CmdIncomingMessage): - TYPE = 143 # -113 - - hw: List[int] - - def __init__(self, hw: List[int], seq: Optional[int] = None): - super().__init__(seq) - self.hw = hw - - @classmethod - def from_packed_data(cls, data: bytes, seq=0) -> DeviceHardwareMessage: - assert len(data) == 3, 'invalid data size, expected 3' - hw = list(struct.unpack(' ReprDict: - return {'device_hardware': '.'.join(map(str, self.hw))} - - -# This message is sent by kettle right after the HandshakeMessageResponse. -# The diagnostic data is supposed to be sent to vendor, which we, obviously, not going to do. -# So just ACK and skip it. -class DeviceDiagnosticMessage(CmdIncomingMessage): - TYPE = 145 # -111 - - diag_data: bytes - - def __init__(self, diag_data: bytes, seq: Optional[int] = None): - super().__init__(seq) - self.diag_data = diag_data - - @classmethod - def from_packed_data(cls, data: bytes, seq=0) -> DeviceDiagnosticMessage: - return DeviceDiagnosticMessage(diag_data=data, seq=seq) - - def _repr_fields(self) -> ReprDict: - return {'diag_data': self.diag_data.hex()} - - -class SimpleBooleanMessage(ABC, CmdIncomingMessage): - value: bool - - def __init__(self, value: bool, seq: Optional[int] = None): - super().__init__(seq) - self.value = value - - @classmethod - def from_packed_data(cls, data: bytes, seq: Optional[int] = None): - assert len(data) == 1, 'invalid data size, expected 1' - enabled, = struct.unpack(' ReprDict: - pass - - -class AccessControlMessage(SimpleBooleanMessage): - TYPE = 133 # -123 - - def _repr_fields(self) -> ReprDict: - return {'acl_enabled': self.value} - - -class ErrorMessage(SimpleBooleanMessage): - TYPE = 7 - - def _repr_fields(self) -> ReprDict: - return {'error': self.value} - - -class ChildLockMessage(SimpleBooleanMessage): - TYPE = 30 - - def _repr_fields(self) -> ReprDict: - return {'child_lock': self.value} - - -class VolumeMessage(SimpleBooleanMessage): - TYPE = 9 - - def _repr_fields(self) -> ReprDict: - return {'volume': self.value} - - -class BacklightMessage(SimpleBooleanMessage): - TYPE = 28 - - def _repr_fields(self) -> ReprDict: - return {'backlight': self.value} - - -class CurrentTemperatureMessage(CmdIncomingMessage): - TYPE = 20 - - current_temperature: int - - def __init__(self, temp: int, seq: Optional[int] = None): - super().__init__(seq) - self.current_temperature = temp - - @classmethod - def from_packed_data(cls, data: bytes, seq=0) -> CurrentTemperatureMessage: - assert len(data) == 2, 'data size expected to be 2' - nat, frac = struct.unpack('BB', data) - temp = int(nat + (frac / 100)) - return CurrentTemperatureMessage(temp, seq=seq) - - def pack_data(self) -> bytes: - return bytes([self.current_temperature, 0]) - - def _repr_fields(self) -> ReprDict: - return {'current_temperature': self.current_temperature} - - -class UnknownMessage(CmdIncomingMessage): - type: Optional[int] - data: bytes - - def __init__(self, data: bytes, **kwargs): - super().__init__(**kwargs) - self.type = None - self.data = data - - @classmethod - def from_packed_data(cls, data: bytes, seq=0) -> UnknownMessage: - return UnknownMessage(data, seq=seq) - - def set_type(self, type: int): - self.type = type - - def get_type(self) -> int: - return self.type - - def _repr_fields(self) -> ReprDict: - return { - 'type': self.type, - 'data': self.data.hex() - } - - -class WrappedMessage: - _message: Message - _handler: Optional[callable] - _validator: Optional[callable] - _logger: Optional[logging.Logger] - _phase: MessagePhase - _phase_update_time: float - - def __init__(self, - message: Message, - handler: Optional[callable] = None, - validator: Optional[callable] = None, - ack=False): - self._message = message - self._handler = handler - self._validator = validator - self._logger = None - self._phase = MessagePhase.WAITING - self._phase_update_time = 0 - if not validator and ack: - self._validator = lambda m: isinstance(m, AckMessage) - - def setlogger(self, logger: logging.Logger): - self._logger = logger - - def validate(self, message: Message): - if not self._validator: - return True - return self._validator(message) - - def call(self, *args, error_message: str = None) -> None: - if not self._handler: - return - try: - self._handler(*args) - except Exception as exc: - logger = self._logger or logging.getLogger(self.__class__.__name__) - logger.error(f'{error_message}, see exception below:') - logger.exception(exc) - - @property - def phase(self) -> MessagePhase: - return self._phase - - @phase.setter - def phase(self, phase: MessagePhase): - self._phase = phase - self._phase_update_time = 0 if phase == MessagePhase.WAITING else time.time() - - @property - def phase_update_time(self) -> float: - return self._phase_update_time - - @property - def message(self) -> Message: - return self._message - - @property - def id(self) -> int: - return self._message.id - - @property - def seq(self) -> int: - return self._message.seq - - @seq.setter - def seq(self, seq: int): - self._message.seq = seq - - def __repr__(self): - return f'<{__name__}.{self.__class__.__name__} message={self._message.__repr__()}>' - - -# Connection stuff -# Well, strictly speaking, as it's UDP, there's no connection, but who cares. -# --------------------------------------------------------------------------- - -class IncomingMessageListener: - @abstractmethod - def incoming_message(self, message: Message) -> Optional[Message]: - pass - - -class ConnectionStatus(Enum): - NOT_CONNECTED = auto() - CONNECTING = auto() - CONNECTED = auto() - RECONNECTING = auto() - DISCONNECTED = auto() - - -class ConnectionStatusListener: - @abstractmethod - def connection_status_updated(self, status: ConnectionStatus): - pass - - -class UDPConnection(threading.Thread, ConnectionStatusListener): - inseq: int - outseq: int - source_port: int - device_addr: str - device_port: int - device_token: bytes - device_pubkey: bytes - interrupted: bool - response_handlers: Dict[int, WrappedMessage] - outgoing_queue: List[WrappedMessage] - pubkey: Optional[bytes] - encinkey: Optional[bytes] - encoutkey: Optional[bytes] - inc_listeners: List[IncomingMessageListener] - conn_listeners: List[ConnectionStatusListener] - outgoing_time: float - outgoing_time_1st: float - incoming_time: float - status: ConnectionStatus - reconnect_tries: int - read_timeout: int - - _addr_lock: threading.Lock - _iml_lock: threading.Lock - _csl_lock: threading.Lock - _st_lock: threading.Lock - - def __init__(self, - addr: Union[IPv4Address, IPv6Address], - port: int, - device_pubkey: bytes, - device_token: bytes, - read_timeout: int = 1): - super().__init__() - self._logger = logging.getLogger(f'{__name__}.{self.__class__.__name__} <{hex(id(self))}>') - self.setName(self.__class__.__name__) - - self.inseq = 0 - self.outseq = 0 - self.source_port = random.randint(1024, 65535) - self.device_addr = str(addr) - self.device_port = port - self.device_token = device_token - self.device_pubkey = device_pubkey - self.outgoing_queue = [] - self.response_handlers = {} - self.interrupted = False - self.outgoing_time = 0 - self.outgoing_time_1st = 0 - self.incoming_time = 0 - self.inc_listeners = [] - self.conn_listeners = [self] - self.status = ConnectionStatus.NOT_CONNECTED - self.reconnect_tries = 0 - self.read_timeout = read_timeout - - self._iml_lock = threading.Lock() - self._csl_lock = threading.Lock() - self._addr_lock = threading.Lock() - self._st_lock = threading.Lock() - - self.pubkey = None - self.encinkey = None - self.encoutkey = None - - def connection_status_updated(self, status: ConnectionStatus): - # self._logger.info(f'connection_status_updated: status = {status}') - with self._st_lock: - # self._logger.debug(f'connection_status_updated: lock acquired') - self.status = status - if status == ConnectionStatus.RECONNECTING: - self.reconnect_tries += 1 - if status in (ConnectionStatus.CONNECTED, ConnectionStatus.NOT_CONNECTED, ConnectionStatus.DISCONNECTED): - self.reconnect_tries = 0 - - def _cleanup(self): - # erase outgoing queue - for wm in self.outgoing_queue: - wm.call(False, - error_message=f'_cleanup: exception while calling cb(False) on message {wm.message}') - self.outgoing_queue = [] - self.response_handlers = {} - - # reset timestamps - self.incoming_time = 0 - self.outgoing_time = 0 - self.outgoing_time_1st = 0 - - self._logger.debug('_cleanup: done') - - def set_address(self, addr: Union[IPv4Address, IPv6Address], port: int): - with self._addr_lock: - if self.device_addr != str(addr) or self.device_port != port: - self.device_addr = str(addr) - self.device_port = port - self._logger.info(f'updated device network address: {self.device_addr}:{self.device_port}') - - def set_device_pubkey(self, pubkey: bytes): - if self.device_pubkey.hex() != pubkey.hex(): - self._logger.info(f'device pubkey has changed (old={self.device_pubkey.hex()}, new={pubkey.hex()})') - self.device_pubkey = pubkey - self._notify_cs(ConnectionStatus.RECONNECTING) - - def get_address(self) -> Tuple[str, int]: - with self._addr_lock: - return self.device_addr, self.device_port - - def add_incoming_message_listener(self, listener: IncomingMessageListener): - with self._iml_lock: - if listener not in self.inc_listeners: - self.inc_listeners.append(listener) - - def add_connection_status_listener(self, listener: ConnectionStatusListener): - with self._csl_lock: - if listener not in self.conn_listeners: - self.conn_listeners.append(listener) - - def _notify_cs(self, status: ConnectionStatus): - # self._logger.debug(f'_notify_cs: status={status}') - with self._csl_lock: - for obj in self.conn_listeners: - # self._logger.debug(f'_notify_cs: notifying {obj}') - obj.connection_status_updated(status) - - def _prepare_keys(self): - # generate key pair - privkey = X25519PrivateKey.generate() - - self.pubkey = bytes(reversed(privkey.public_key().public_bytes(encoding=srlz.Encoding.Raw, - format=srlz.PublicFormat.Raw))) - - # generate shared key - device_pubkey = X25519PublicKey.from_public_bytes( - bytes(reversed(self.device_pubkey)) - ) - shared_key = bytes(reversed( - privkey.exchange(device_pubkey) - )) - - # in/out encryption keys - digest = hashes.Hash(hashes.SHA256()) - digest.update(shared_key) - - shared_sha256 = digest.finalize() - - self.encinkey = shared_sha256[:16] - self.encoutkey = shared_sha256[16:] - - self._logger.info('encryption keys have been created') - - def _handshake_callback(self, r: MessageResponse): - # if got error for our HandshakeMessage, reset everything and try again - if r is False: - # self._logger.debug('_handshake_callback: set status=RECONNETING') - self._notify_cs(ConnectionStatus.RECONNECTING) - else: - # self._logger.debug('_handshake_callback: set status=CONNECTED') - self._notify_cs(ConnectionStatus.CONNECTED) - - def run(self): - self._logger.info('starting server loop') - - sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - sock.bind(('0.0.0.0', self.source_port)) - sock.settimeout(self.read_timeout) - - while not self.interrupted: - with self._st_lock: - status = self.status - - if status in (ConnectionStatus.DISCONNECTED, ConnectionStatus.RECONNECTING): - self._cleanup() - if status == ConnectionStatus.DISCONNECTED: - break - - # no activity for some time means connection is broken - fail = False - fail_path = 0 - if self.incoming_time > 0 and time.time() - self.incoming_time >= DISCONNECT_TIMEOUT: - fail = True - fail_path = 1 - elif self.outgoing_time_1st > 0 and self.incoming_time == 0 and time.time() - self.outgoing_time_1st >= DISCONNECT_TIMEOUT: - fail = True - fail_path = 2 - - if fail: - self._logger.debug(f'run: setting status=RECONNECTING because of long inactivity, fail_path={fail_path}') - self._notify_cs(ConnectionStatus.RECONNECTING) - - # establishing a connection - if status in (ConnectionStatus.RECONNECTING, ConnectionStatus.NOT_CONNECTED): - if status == ConnectionStatus.RECONNECTING and self.reconnect_tries >= 3: - self._notify_cs(ConnectionStatus.DISCONNECTED) - continue - - self._reset_outseq() - self._prepare_keys() - - # shake the imaginary kettle's hand - wrapped = WrappedMessage(HandshakeMessage(), - handler=self._handshake_callback, - validator=lambda m: isinstance(m, (AckMessage, HandshakeResponseMessage))) - self.enqueue_message(wrapped, prepend=True) - self._notify_cs(ConnectionStatus.CONNECTING) - - # pick next (wrapped) message to send - wm = self._get_next_message() # wm means "wrapped message" - if wm: - one_shot = isinstance(wm.message, (AckMessage, NakMessage)) - - if not isinstance(wm.message, (AckMessage, NakMessage)): - old_seq = wm.seq - wm.seq = self.outseq - self._set_response_handler(wm, old_seq=old_seq) - elif wm.seq is None: - # ack/nak is a response to some incoming message (and it must have the same seqno that incoming - # message had) - raise RuntimeError(f'run: seq must be set for {wm.__class__.__name__}') - - self._logger.debug(f'run: sending message: {wm.message}, one_shot={one_shot}, phase={wm.phase}') - encrypted = False - try: - wm.message.encrypt(outkey=self.encoutkey, inkey=self.encinkey, - token=self.device_token, pubkey=self.pubkey) - encrypted = True - except ValueError as exc: - # handle "ValueError: Invalid padding bytes." - self._logger.error('run: failed to encrypt the message.') - self._logger.exception(exc) - - if encrypted: - buf = wm.message.frame.pack() - # self._logger.debug(f'run: raw data to be sent: {buf.hex()}') - - # sending the first time - if wm.phase == MessagePhase.WAITING: - sock.sendto(buf, self.get_address()) - # resending - elif wm.phase == MessagePhase.SENT: - left = RESEND_ATTEMPTS - while left > 0: - sock.sendto(buf, self.get_address()) - left -= 1 - if left > 0: - time.sleep(0.05) - - if one_shot or wm.phase == MessagePhase.SENT: - wm.phase = MessagePhase.DONE - else: - wm.phase = MessagePhase.SENT - - now = time.time() - self.outgoing_time = now - if not self.outgoing_time_1st: - self.outgoing_time_1st = now - - # receiving data - try: - data = sock.recv(4096) - self._handle_incoming(data) - except (TimeoutError, socket.timeout): - pass - - self._logger.info('bye...') - - def _get_next_message(self) -> Optional[WrappedMessage]: - message = None - lpfx = '_get_next_message:' - remove_list = [] - for wm in self.outgoing_queue: - if wm.phase == MessagePhase.DONE: - if isinstance(wm.message, (AckMessage, NakMessage, PingMessage)) or time.time() - wm.phase_update_time >= MESSAGE_QUEUE_REMOVE_DELAY: - remove_list.append(wm) - continue - message = wm - break - - for wm in remove_list: - self._logger.debug(f'{lpfx} rm path: removing id={wm.id} seq={wm.seq}') - - # clear message handler - if wm.seq in self.response_handlers: - self.response_handlers[wm.seq].call( - False, error_message=f'{lpfx} rm path: error while calling callback for seq={wm.seq}') - del self.response_handlers[wm.seq] - - # remove from queue - try: - self.outgoing_queue.remove(wm) - except ValueError as exc: - self._logger.error(f'{lpfx} rm path: removing from outgoing_queue raised an exception: {str(exc)}') - - # ping pong - if not message and self.outgoing_time_1st != 0 and self.status == ConnectionStatus.CONNECTED: - now = time.time() - out_delta = now - self.outgoing_time - in_delta = now - self.incoming_time - if max(out_delta, in_delta) > PING_FREQUENCY: - self._logger.debug(f'{lpfx} no activity: in for {in_delta:.2f}s, out for {out_delta:.2f}s, time to ping the damn thing') - message = WrappedMessage(PingMessage(), ack=True) - # add it to outgoing_queue in order to be aggressively resent in future (if needed) - self.outgoing_queue.insert(0, message) - - return message - - def _handle_incoming(self, buf: bytes): - try: - incoming_message = Message.from_encrypted(buf, inkey=self.encinkey, outkey=self.encoutkey) - except ValueError as exc: - # handle "ValueError: Invalid padding bytes." - self._logger.error('_handle_incoming: failed to decrypt incoming frame:') - self._logger.exception(exc) - return - - self.incoming_time = time.time() - seq = incoming_message.seq - - lpfx = f'handle_incoming({incoming_message.id}):' - self._logger.debug(f'{lpfx} received: {incoming_message}') - - if isinstance(incoming_message, (AckMessage, NakMessage)): - seq_max = self.outseq - seq_name = 'outseq' - else: - seq_max = self.inseq - seq_name = 'inseq' - self.inseq = seq - - if seq < seq_max < 0xfd: - self._logger.debug(f'{lpfx} dropping: seq={seq}, {seq_name}={seq_max}') - return - - if seq not in self.response_handlers: - self._handle_incoming_cmd(incoming_message) - return - - callback_value = None # None means don't call a callback - handler = self.response_handlers[seq] - - if handler.validate(incoming_message): - self._logger.debug(f'{lpfx} response OK') - handler.phase = MessagePhase.DONE - callback_value = incoming_message - self._incr_outseq() - else: - self._logger.warning(f'{lpfx} response is INVALID') - - # It seems that we've received an incoming CmdMessage or PingMessage with the same seqno that our outgoing - # message had. Bad, but what can I say, this is quick-and-dirty made UDP based protocol and this sort of - # shit just happens. - - # (To be fair, maybe my implementation is not perfect either. But hey, what did you expect from a - # reverse-engineered re-implementation of custom UDP-based protocol that some noname vendor uses for their - # cheap IoT devices? I think _that_ is _the_ definition of shit. At least my implementation is FOSS, which - # is more than you'll ever be able to say about them.) - - # All this crapload of code below might not be needed at all, 'cause the protocol uses separate frame seq - # numbers for IN and OUT frames and this situation is not highly likely, as Theresa May could argue. - # After a handshake, a kettle sends us 10 or so CmdMessages, and then either we continuously ping it every - # 3 seconds, or kettle pings us. This in any case widens the gap between inseq and outseq. - - # But! the seqno is only 1 byte in size and once it reaches 0xff, it circles back to zero. And that (plus, - # perhaps, some bad luck) gives a chance for a collision. - - if handler.phase == MessagePhase.DONE or isinstance(handler.message, HandshakeMessage): - # no more attempts left, returning error back to user - # as to handshake, it cannot fail. - callback_value = False - - # else: - # # try resending the message - # handler.phase_reset() - # max_seq = self.outseq - # wait_remap = {} - # for m in self.outgoing_queue: - # if m.seq in self.waiting_for_response: - # wait_remap[m.seq] = (m.seq+1) % 256 - # m.set_seq((m.seq+1) % 256) - # if m.seq > max_seq: - # max_seq = m.seq - # if max_seq > self.outseq: - # self.outseq = max_seq % 256 - # if wait_remap: - # waiting_new = {} - # for old_seq, new_seq in wait_remap.items(): - # waiting_new[new_seq] = self.waiting_for_response[old_seq] - # self.waiting_for_response = waiting_new - - if isinstance(incoming_message, (PingMessage, CmdIncomingMessage)): - # handle incoming message as usual, as we need to ack/nak it anyway - self._handle_incoming_cmd(incoming_message) - - if callback_value is not None: - handler.call(callback_value, - error_message=f'{lpfx} error while calling callback for msg id={handler.message.id} seq={seq}') - del self.response_handlers[seq] - - def _handle_incoming_cmd(self, incoming_message: Message): - if isinstance(incoming_message, (AckMessage, NakMessage)): - self._logger.debug(f'_handle_incoming_cmd({incoming_message.id}, seq={incoming_message.seq}): it\'s {incoming_message.__class__.__name__}, ignoring') - return - - replied = False - with self._iml_lock: - for f in self.inc_listeners: - retval = safe_callback_call(f.incoming_message, incoming_message, - logger=self._logger, - error_message=f'_handle_incoming_cmd({incoming_message.id}, seq={incoming_message.seq}): error while calling message listener') - if isinstance(retval, Message): - if isinstance(retval, (AckMessage, NakMessage)): - retval.seq = incoming_message.seq - self.enqueue_message(WrappedMessage(retval), prepend=True) - replied = True - break - else: - raise RuntimeError('are you sure your response is correct? only ack/nak are allowed') - - if not replied: - self.enqueue_message(WrappedMessage(AckMessage(incoming_message.seq)), prepend=True) - - def enqueue_message(self, wrapped: WrappedMessage, prepend=False): - self._logger.debug(f'enqueue_message: {wrapped.message}') - if not prepend: - self.outgoing_queue.append(wrapped) - else: - self.outgoing_queue.insert(0, wrapped) - - def _set_response_handler(self, wm: WrappedMessage, old_seq=None): - if old_seq in self.response_handlers: - del self.response_handlers[old_seq] - - seq = wm.seq - assert seq is not None, 'seq is not set' - - if seq in self.response_handlers: - self._logger.debug(f'_set_response_handler(seq={seq}): handler is already set, cancelling it') - self.response_handlers[seq].call(False, - error_message=f'_set_response_handler({seq}): error while calling old callback') - self.response_handlers[seq] = wm - - def _incr_outseq(self) -> None: - self.outseq = (self.outseq + 1) % 256 - - def _reset_outseq(self): - self.outseq = 0 - self._logger.debug(f'_reset_outseq: set 0') - - -MessageResponse = Union[Message, bool] diff --git a/src/polaris_kettle_bot.py b/src/polaris_kettle_bot.py index c031cb1..c48e592 100755 --- a/src/polaris_kettle_bot.py +++ b/src/polaris_kettle_bot.py @@ -13,7 +13,7 @@ from home.api.types import BotType from home.mqtt import MQTTBase from home.config import config from home.util import chunks -from polaris import ( +from syncleo import ( Kettle, PowerType, DeviceListener, @@ -21,7 +21,7 @@ from polaris import ( ConnectionStatusListener, ConnectionStatus ) -import polaris.protocol as kettle_proto +import syncleo.protocol as kettle_proto from typing import Optional, Tuple, List, Union from collections import namedtuple from functools import partial diff --git a/src/polaris_kettle_util.py b/src/polaris_kettle_util.py index 82b2588..61c1c7d 100755 --- a/src/polaris_kettle_util.py +++ b/src/polaris_kettle_util.py @@ -10,7 +10,7 @@ from argparse import ArgumentParser from queue import SimpleQueue from home.mqtt import MQTTBase from home.config import config -from polaris import ( +from syncleo import ( Kettle, PowerType, protocol as kettle_proto diff --git a/src/syncleo/__init__.py b/src/syncleo/__init__.py new file mode 100644 index 0000000..32563a5 --- /dev/null +++ b/src/syncleo/__init__.py @@ -0,0 +1,12 @@ +# Polaris PWK 1725CGLD "smart" kettle python library +# -------------------------------------------------- +# Copyright (C) Evgeny Zinoviev, 2022 +# License: BSD-3c + +from .kettle import Kettle, DeviceListener +from .protocol import ( + PowerType, + IncomingMessageListener, + ConnectionStatusListener, + ConnectionStatus +) diff --git a/src/syncleo/kettle.py b/src/syncleo/kettle.py new file mode 100644 index 0000000..d6e0dd6 --- /dev/null +++ b/src/syncleo/kettle.py @@ -0,0 +1,243 @@ +# Polaris PWK 1725CGLD smart kettle python library +# ------------------------------------------------ +# Copyright (C) Evgeny Zinoviev, 2022 +# License: BSD-3c + +from __future__ import annotations + +import threading +import logging +import zeroconf + +from abc import abstractmethod +from ipaddress import ip_address, IPv4Address, IPv6Address +from typing import Optional, List, Union + +from .protocol import ( + UDPConnection, + ModeMessage, + TargetTemperatureMessage, + PowerType, + ConnectionStatus, + ConnectionStatusListener, + WrappedMessage +) + + +class DeviceDiscover(threading.Thread, zeroconf.ServiceListener): + si: Optional[zeroconf.ServiceInfo] + _mac: str + _sb: Optional[zeroconf.ServiceBrowser] + _zc: Optional[zeroconf.Zeroconf] + _listeners: List[DeviceListener] + _valid_addresses: List[Union[IPv4Address, IPv6Address]] + _only_ipv4: bool + + def __init__(self, mac: str, + listener: Optional[DeviceListener] = None, + only_ipv4=True): + super().__init__() + self.si = None + self._mac = mac + self._zc = None + self._sb = None + self._only_ipv4 = only_ipv4 + self._valid_addresses = [] + self._listeners = [] + if isinstance(listener, DeviceListener): + self._listeners.append(listener) + self._logger = logging.getLogger(f'{__name__}.{self.__class__.__name__}') + + def add_listener(self, listener: DeviceListener): + if listener not in self._listeners: + self._listeners.append(listener) + else: + self._logger.warning(f'add_listener: listener {listener} already in the listeners list') + + def set_info(self, info: zeroconf.ServiceInfo): + valid_addresses = self._get_valid_addresses(info) + if not valid_addresses: + raise ValueError('no valid addresses') + self._valid_addresses = valid_addresses + self.si = info + for f in self._listeners: + try: + f.device_updated() + except Exception as exc: + self._logger.error(f'set_info: error while calling device_updated on {f}') + self._logger.exception(exc) + + def add_service(self, zc: zeroconf.Zeroconf, type_: str, name: str) -> None: + self._add_update_service('add_service', zc, type_, name) + + def update_service(self, zc: zeroconf.Zeroconf, type_: str, name: str) -> None: + self._add_update_service('update_service', zc, type_, name) + + def _add_update_service(self, method: str, zc: zeroconf.Zeroconf, type_: str, name: str) -> None: + info = zc.get_service_info(type_, name) + if name.startswith(f'{self._mac}.'): + self._logger.info(f'{method}: type={type_} name={name}') + try: + self.set_info(info) + except ValueError as exc: + self._logger.error(f'{method}: rejected: {str(exc)}') + else: + self._logger.debug(f'{method}: mac not matched: {info}') + + def remove_service(self, zc: zeroconf.Zeroconf, type_: str, name: str) -> None: + if name.startswith(f'{self._mac}.'): + self._logger.info(f'remove_service: type={type_} name={name}') + # TODO what to do here?! + + def run(self): + self._logger.debug('starting zeroconf service browser') + ip_version = zeroconf.IPVersion.V4Only if self._only_ipv4 else zeroconf.IPVersion.All + self._zc = zeroconf.Zeroconf(ip_version=ip_version) + self._sb = zeroconf.ServiceBrowser(self._zc, "_syncleo._udp.local.", self) + self._sb.join() + + def stop(self): + if self._sb: + try: + self._sb.cancel() + except RuntimeError: + pass + self._sb = None + self._zc.close() + self._zc = None + + def _get_valid_addresses(self, si: zeroconf.ServiceInfo) -> List[Union[IPv4Address, IPv6Address]]: + valid = [] + for addr in map(ip_address, si.addresses): + if self._only_ipv4 and not isinstance(addr, IPv4Address): + continue + if isinstance(addr, IPv4Address) and str(addr).startswith('169.254.'): + continue + valid.append(addr) + return valid + + @property + def pubkey(self) -> bytes: + return bytes.fromhex(self.si.properties[b'public'].decode()) + + @property + def curve(self) -> int: + return int(self.si.properties[b'curve'].decode()) + + @property + def addr(self) -> Union[IPv4Address, IPv6Address]: + return self._valid_addresses[0] + + @property + def port(self) -> int: + return int(self.si.port) + + @property + def protocol(self) -> int: + return int(self.si.properties[b'protocol'].decode()) + + +class DeviceListener: + @abstractmethod + def device_updated(self): + pass + + +class Kettle(DeviceListener, ConnectionStatusListener): + mac: str + device: Optional[DeviceDiscover] + device_token: str + conn: Optional[UDPConnection] + conn_status: Optional[ConnectionStatus] + _read_timeout: Optional[int] + _logger: logging.Logger + _find_evt: threading.Event + + def __init__(self, mac: str, device_token: str, read_timeout: Optional[int] = None): + super().__init__() + self.mac = mac + self.device = None + self.device_token = device_token + self.conn = None + self.conn_status = None + self._read_timeout = read_timeout + self._find_evt = threading.Event() + self._logger = logging.getLogger(f'{__name__}.{self.__class__.__name__}') + + def device_updated(self): + self._find_evt.set() + self._logger.info(f'device updated, service info: {self.device.si}') + + def connection_status_updated(self, status: ConnectionStatus): + self.conn_status = status + + def discover(self, wait=True, timeout=None, listener=None) -> Optional[zeroconf.ServiceInfo]: + do_start = False + if not self.device: + self.device = DeviceDiscover(self.mac, listener=self, only_ipv4=True) + do_start = True + self._logger.debug('discover: started device discovery') + else: + self._logger.warning('discover: already started') + + if listener is not None: + self.device.add_listener(listener) + + if do_start: + self.device.start() + + if wait: + self._find_evt.clear() + try: + self._find_evt.wait(timeout=timeout) + except KeyboardInterrupt: + self.device.stop() + return None + return self.device.si + + def start_server_if_needed(self, + incoming_message_listener=None, + connection_status_listener=None): + if self.conn: + self._logger.warning('start_server_if_needed: server is already started!') + self.conn.set_address(self.device.addr, self.device.port) + self.conn.set_device_pubkey(self.device.pubkey) + return + + assert self.device.curve == 29, f'curve type {self.device.curve} is not implemented' + assert self.device.protocol == 2, f'protocol {self.device.protocol} is not supported' + + kw = {} + if self._read_timeout is not None: + kw['read_timeout'] = self._read_timeout + self.conn = UDPConnection(addr=self.device.addr, + port=self.device.port, + device_pubkey=self.device.pubkey, + device_token=bytes.fromhex(self.device_token), **kw) + if incoming_message_listener: + self.conn.add_incoming_message_listener(incoming_message_listener) + + self.conn.add_connection_status_listener(self) + if connection_status_listener: + self.conn.add_connection_status_listener(connection_status_listener) + + self.conn.start() + + def stop_all(self): + # when we stop server, we should also stop device discovering service + if self.conn: + self.conn.interrupted = True + self.conn = None + self.device.stop() + self.device = None + + def is_connected(self) -> bool: + return self.conn is not None and self.conn_status == ConnectionStatus.CONNECTED + + def set_power(self, power_type: PowerType, callback: callable): + message = ModeMessage(power_type) + self.conn.enqueue_message(WrappedMessage(message, handler=callback, ack=True)) + + def set_target_temperature(self, temp: int, callback: callable): + message = TargetTemperatureMessage(temp) + self.conn.enqueue_message(WrappedMessage(message, handler=callback, ack=True)) diff --git a/src/syncleo/protocol.py b/src/syncleo/protocol.py new file mode 100644 index 0000000..36a1a8f --- /dev/null +++ b/src/syncleo/protocol.py @@ -0,0 +1,1169 @@ +# Polaris PWK 1725CGLD "smart" kettle python library +# -------------------------------------------------- +# Copyright (C) Evgeny Zinoviev, 2022 +# License: BSD-3c + +from __future__ import annotations + +import logging +import socket +import random +import struct +import threading +import time + +from abc import abstractmethod, ABC +from enum import Enum, auto +from typing import Union, Optional, Dict, Tuple, List +from ipaddress import IPv4Address, IPv6Address + +import cryptography.hazmat.primitives._serialization as srlz + +from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey +from cryptography.hazmat.primitives import ciphers, padding, hashes +from cryptography.hazmat.primitives.ciphers import algorithms, modes + +ReprDict = Dict[str, Union[str, int, float, bool]] +_logger = logging.getLogger(__name__) + +PING_FREQUENCY = 3 +RESEND_ATTEMPTS = 5 +ERROR_TIMEOUT = 15 +MESSAGE_QUEUE_REMOVE_DELAY = 13 # after what time to delete (and pass False to handlers, if needed) messages with phase=DONE from queue +DISCONNECT_TIMEOUT = 15 + + +def safe_callback_call(f: callable, + *args, + logger: logging.Logger = None, + error_message: str = None): + try: + return f(*args) + except Exception as exc: + logger.error(f'{error_message}, see exception below:') + logger.exception(exc) + return None + + +# drop-in replacement for java.lang.System.arraycopy +# TODO: rewrite +def arraycopy(src, src_pos, dest, dest_pos, length): + for i in range(length): + dest[i + dest_pos] = src[i + src_pos] + + +# "convert" unsigned byte to signed +def u8_to_s8(b: int) -> int: + return struct.unpack('b', bytes([b]))[0] + + +class PowerType(Enum): + OFF = 0 # turn off + ON = 1 # turn on, set target temperature to 100 + CUSTOM = 3 # turn on, allows custom target temperature + # MYSTERY_MODE = 2 # don't know what 2 means, needs testing + # update: if I set it to '2', it just resets to '0' + + +# low-level protocol structures +# ----------------------------- + +class FrameType(Enum): + ACK = 0 + CMD = 1 + AUX = 2 + NAK = 3 + + +class FrameHead: + seq: Optional[int] # u8 + type: FrameType # u8 + length: int # u16. This is the length of FrameItem's payload + + @staticmethod + def from_bytes(buf: bytes) -> FrameHead: + seq, ft, length = struct.unpack(' bytes: + assert self.length != 0, "FrameHead.length has not been set" + assert self.seq is not None, "FrameHead.seq has not been set" + return struct.pack(' bytes: + ba = bytearray(self.head.pack()) + ba.extend(self.payload) + return bytes(ba) + + +# high-level wrappers around FrameItem +# ------------------------------------ + +class MessagePhase(Enum): + WAITING = 0 + SENT = 1 + DONE = 2 + + +class Message: + frame: Optional[FrameItem] + id: int + + _global_id = 0 + + def __init__(self): + self.frame = None + + # global internal message id, only useful for debugging purposes + self.id = self.next_id() + + def __repr__(self): + return f'<{self.__class__.__name__} id={self.id} seq={self.frame.head.seq}>' + + @staticmethod + def next_id(): + _id = Message._global_id + Message._global_id = (Message._global_id + 1) % 100000 + return _id + + @staticmethod + def from_encrypted(buf: bytes, inkey: bytes, outkey: bytes) -> Message: + _logger.debug(f'Message:from_encrypted: buf={buf.hex()}') + + assert len(buf) >= 4, 'invalid size' + head = FrameHead.from_bytes(buf[:4]) + + assert len(buf) == head.length + 4, f'invalid buf size ({len(buf)} != {head.length})' + payload = buf[4:] + b = head.seq + + j = b & 0xF + k = b >> 4 & 0xF + + key = bytearray(len(inkey)) + arraycopy(inkey, j, key, 0, len(inkey) - j) + arraycopy(inkey, 0, key, len(inkey) - j, j) + + iv = bytearray(len(outkey)) + arraycopy(outkey, k, iv, 0, len(outkey) - k) + arraycopy(outkey, 0, iv, len(outkey) - k, k) + + cipher = ciphers.Cipher(algorithms.AES(key), modes.CBC(iv)) + decryptor = cipher.decryptor() + decrypted_data = decryptor.update(payload) + decryptor.finalize() + + unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder() + decrypted_data = unpadder.update(decrypted_data) + decrypted_data += unpadder.finalize() + + assert len(decrypted_data) != 0, 'decrypted data is null' + assert head.seq == decrypted_data[0], f'decrypted seq mismatch {head.seq} != {decrypted_data[0]}' + + # _logger.debug('Message.from_encrypted: plaintext: '+decrypted_data.hex()) + + if head.type == FrameType.ACK: + return AckMessage(head.seq) + + elif head.type == FrameType.NAK: + return NakMessage(head.seq) + + elif head.type == FrameType.AUX: + # TODO implement AUX + raise NotImplementedError('FrameType AUX is not yet implemented') + + elif head.type == FrameType.CMD: + type = decrypted_data[1] + data = decrypted_data[2:] + + cl = UnknownMessage + + subclasses = [cl for cl in CmdIncomingMessage.__subclasses__() if cl is not SimpleBooleanMessage] + subclasses.extend(SimpleBooleanMessage.__subclasses__()) + + for _cl in subclasses: + # `UnknownMessage` is a special class that holds a packed command that we don't recognize. + # It will be used anyway if we don't find a match, so skip it here + if _cl == UnknownMessage: + continue + + if _cl.TYPE == type: + cl = _cl + break + + m = cl.from_packed_data(data, seq=head.seq) + if isinstance(m, UnknownMessage): + m.set_type(type) + return m + + else: + raise NotImplementedError(f'Unexpected frame type: {head.type}') + + def pack_data(self) -> bytes: + return b'' + + @property + def seq(self) -> Union[int, None]: + try: + return self.frame.head.seq + except: + return None + + @seq.setter + def seq(self, seq: int): + self.frame.head.seq = seq + + def encrypt(self, outkey: bytes, inkey: bytes, token: bytes, pubkey: bytes): + assert self.frame is not None + + data = self._get_data_to_encrypt() + assert data is not None + + b = self.frame.head.seq + i = b & 0xf + j = b >> 4 & 0xf + + outkey = bytearray(outkey) + + l = len(outkey) + key = bytearray(l) + + arraycopy(outkey, i, key, 0, l-i) + arraycopy(outkey, 0, key, l-i, i) + + inkey = bytearray(inkey) + + l = len(inkey) + iv = bytearray(l) + + arraycopy(inkey, j, iv, 0, l-j) + arraycopy(inkey, 0, iv, l-j, j) + + cipher = ciphers.Cipher(algorithms.AES(key), modes.CBC(iv)) + encryptor = cipher.encryptor() + + newdata = bytearray(len(data)+1) + newdata[0] = b + + arraycopy(data, 0, newdata, 1, len(data)) + + newdata = bytes(newdata) + _logger.debug('frame payload to be encrypted: ' + newdata.hex()) + + padder = padding.PKCS7(algorithms.AES.block_size).padder() + ciphertext = bytearray() + ciphertext.extend(encryptor.update(padder.update(newdata) + padder.finalize())) + ciphertext.extend(encryptor.finalize()) + + self.frame.setpayload(ciphertext) + + def _get_data_to_encrypt(self) -> bytes: + return self.pack_data() + + +class AckMessage(Message, ABC): + def __init__(self, seq: Optional[int] = None): + super().__init__() + self.frame = FrameItem(FrameHead(seq, FrameType.ACK, None)) + + +class NakMessage(Message, ABC): + def __init__(self, seq: Optional[int] = None): + super().__init__() + self.frame = FrameItem(FrameHead(seq, FrameType.NAK, None)) + + +class CmdMessage(Message): + type: Optional[int] + data: bytes + + TYPE = None + + def _get_data_to_encrypt(self) -> bytes: + buf = bytearray() + buf.append(self.get_type()) + buf.extend(self.pack_data()) + return bytes(buf) + + def __init__(self, seq: Optional[int] = None): + super().__init__() + self.frame = FrameItem(FrameHead(seq, FrameType.CMD)) + self.data = b'' + + def _repr_fields(self) -> ReprDict: + return { + 'cmd': self.get_type() + } + + def __repr__(self): + params = [ + __name__+'.'+self.__class__.__name__, + f'id={self.id}', + f'seq={self.seq}' + ] + fields = self._repr_fields() + if fields: + for k, v in fields.items(): + params.append(f'{k}={v}') + elif self.data: + params.append(f'data={self.data.hex()}') + return '<'+' '.join(params)+'>' + + def get_type(self) -> int: + return self.__class__.TYPE + + +class CmdIncomingMessage(CmdMessage): + @staticmethod + @abstractmethod + def from_packed_data(cls, data: bytes, seq: Optional[int] = None): + pass + + @abstractmethod + def _repr_fields(self) -> ReprDict: + pass + + +class CmdOutgoingMessage(CmdMessage): + @abstractmethod + def pack_data(self) -> bytes: + return b'' + + +class ModeMessage(CmdOutgoingMessage, CmdIncomingMessage): + TYPE = 1 + + pt: PowerType + + def __init__(self, power_type: PowerType, seq: Optional[int] = None): + super().__init__(seq) + self.pt = power_type + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> ModeMessage: + assert len(data) == 1, 'data size expected to be 1' + mode, = struct.unpack('B', data) + return ModeMessage(PowerType(mode), seq=seq) + + def pack_data(self) -> bytes: + return self.pt.value.to_bytes(1, byteorder='little') + + def _repr_fields(self) -> ReprDict: + return {'mode': self.pt.name} + + +class TargetTemperatureMessage(CmdOutgoingMessage, CmdIncomingMessage): + temperature: int + + TYPE = 2 + + def __init__(self, temp: int, seq: Optional[int] = None): + super().__init__(seq) + self.temperature = temp + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> TargetTemperatureMessage: + assert len(data) == 2, 'data size expected to be 2' + nat, frac = struct.unpack('BB', data) + temp = int(nat + (frac / 100)) + return TargetTemperatureMessage(temp, seq=seq) + + def pack_data(self) -> bytes: + return bytes([self.temperature, 0]) + + def _repr_fields(self) -> ReprDict: + return {'temperature': self.temperature} + + +class PingMessage(CmdIncomingMessage, CmdOutgoingMessage): + TYPE = 255 + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> PingMessage: + assert len(data) == 0, 'no data expected' + return PingMessage(seq=seq) + + def pack_data(self) -> bytes: + return b'' + + def _repr_fields(self) -> ReprDict: + return {} + + +# This is the first protocol message. Sent by a client. +# Kettle usually ACKs this, but sometimes i don't get any ACK and the very next message is HandshakeResponseMessage. +class HandshakeMessage(CmdMessage): + TYPE = 0 + + def encrypt(self, + outkey: bytes, + inkey: bytes, + token: bytes, + pubkey: bytes): + cipher = ciphers.Cipher(algorithms.AES(outkey), modes.CBC(inkey)) + encryptor = cipher.encryptor() + + ciphertext = bytearray() + ciphertext.extend(encryptor.update(token)) + ciphertext.extend(encryptor.finalize()) + + pld = bytearray() + pld.append(0) + pld.extend(pubkey) + pld.extend(ciphertext) + + self.frame.setpayload(pld) + + +# Kettle either sends this right after the handshake, of first it ACKs the handshake then sends this. +class HandshakeResponseMessage(CmdIncomingMessage): + TYPE = 0 + + protocol: int + fw_major: int + fw_minor: int + mode: int + token: bytes + + def __init__(self, + protocol: int, + fw_major: int, + fw_minor: int, + mode: int, + token: bytes, + seq: Optional[int] = None): + super().__init__(seq) + self.protocol = protocol + self.fw_major = fw_major + self.fw_minor = fw_minor + self.mode = mode + self.token = token + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> HandshakeResponseMessage: + protocol, fw_major, fw_minor, mode = struct.unpack(' ReprDict: + return { + 'protocol': self.protocol, + 'fw': f'{self.fw_major}.{self.fw_minor}', + 'mode': self.mode, + 'token': self.token.hex() + } + + +# Apparently, some hardware info. +# On the other hand, if you look at com.syncleiot.iottransport.commands.CmdHardware, its mqtt topic says "mcu_firmware". +# My device returns 1.1.1. The kettle uses on ESP8266 ESP-12F MCU under the hood (or, more precisely, under a piece of +# cheap plastic), so maybe 1.1.1 is some MCU ROM version. +class DeviceHardwareMessage(CmdIncomingMessage): + TYPE = 143 # -113 + + hw: List[int] + + def __init__(self, hw: List[int], seq: Optional[int] = None): + super().__init__(seq) + self.hw = hw + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> DeviceHardwareMessage: + assert len(data) == 3, 'invalid data size, expected 3' + hw = list(struct.unpack(' ReprDict: + return {'device_hardware': '.'.join(map(str, self.hw))} + + +# This message is sent by kettle right after the HandshakeMessageResponse. +# The diagnostic data is supposed to be sent to vendor, which we, obviously, not going to do. +# So just ACK and skip it. +class DeviceDiagnosticMessage(CmdIncomingMessage): + TYPE = 145 # -111 + + diag_data: bytes + + def __init__(self, diag_data: bytes, seq: Optional[int] = None): + super().__init__(seq) + self.diag_data = diag_data + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> DeviceDiagnosticMessage: + return DeviceDiagnosticMessage(diag_data=data, seq=seq) + + def _repr_fields(self) -> ReprDict: + return {'diag_data': self.diag_data.hex()} + + +class SimpleBooleanMessage(ABC, CmdIncomingMessage): + value: bool + + def __init__(self, value: bool, seq: Optional[int] = None): + super().__init__(seq) + self.value = value + + @classmethod + def from_packed_data(cls, data: bytes, seq: Optional[int] = None): + assert len(data) == 1, 'invalid data size, expected 1' + enabled, = struct.unpack(' ReprDict: + pass + + +class AccessControlMessage(SimpleBooleanMessage): + TYPE = 133 # -123 + + def _repr_fields(self) -> ReprDict: + return {'acl_enabled': self.value} + + +class ErrorMessage(SimpleBooleanMessage): + TYPE = 7 + + def _repr_fields(self) -> ReprDict: + return {'error': self.value} + + +class ChildLockMessage(SimpleBooleanMessage): + TYPE = 30 + + def _repr_fields(self) -> ReprDict: + return {'child_lock': self.value} + + +class VolumeMessage(SimpleBooleanMessage): + TYPE = 9 + + def _repr_fields(self) -> ReprDict: + return {'volume': self.value} + + +class BacklightMessage(SimpleBooleanMessage): + TYPE = 28 + + def _repr_fields(self) -> ReprDict: + return {'backlight': self.value} + + +class CurrentTemperatureMessage(CmdIncomingMessage): + TYPE = 20 + + current_temperature: int + + def __init__(self, temp: int, seq: Optional[int] = None): + super().__init__(seq) + self.current_temperature = temp + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> CurrentTemperatureMessage: + assert len(data) == 2, 'data size expected to be 2' + nat, frac = struct.unpack('BB', data) + temp = int(nat + (frac / 100)) + return CurrentTemperatureMessage(temp, seq=seq) + + def pack_data(self) -> bytes: + return bytes([self.current_temperature, 0]) + + def _repr_fields(self) -> ReprDict: + return {'current_temperature': self.current_temperature} + + +class UnknownMessage(CmdIncomingMessage): + type: Optional[int] + data: bytes + + def __init__(self, data: bytes, **kwargs): + super().__init__(**kwargs) + self.type = None + self.data = data + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> UnknownMessage: + return UnknownMessage(data, seq=seq) + + def set_type(self, type: int): + self.type = type + + def get_type(self) -> int: + return self.type + + def _repr_fields(self) -> ReprDict: + return { + 'type': self.type, + 'data': self.data.hex() + } + + +class WrappedMessage: + _message: Message + _handler: Optional[callable] + _validator: Optional[callable] + _logger: Optional[logging.Logger] + _phase: MessagePhase + _phase_update_time: float + + def __init__(self, + message: Message, + handler: Optional[callable] = None, + validator: Optional[callable] = None, + ack=False): + self._message = message + self._handler = handler + self._validator = validator + self._logger = None + self._phase = MessagePhase.WAITING + self._phase_update_time = 0 + if not validator and ack: + self._validator = lambda m: isinstance(m, AckMessage) + + def setlogger(self, logger: logging.Logger): + self._logger = logger + + def validate(self, message: Message): + if not self._validator: + return True + return self._validator(message) + + def call(self, *args, error_message: str = None) -> None: + if not self._handler: + return + try: + self._handler(*args) + except Exception as exc: + logger = self._logger or logging.getLogger(self.__class__.__name__) + logger.error(f'{error_message}, see exception below:') + logger.exception(exc) + + @property + def phase(self) -> MessagePhase: + return self._phase + + @phase.setter + def phase(self, phase: MessagePhase): + self._phase = phase + self._phase_update_time = 0 if phase == MessagePhase.WAITING else time.time() + + @property + def phase_update_time(self) -> float: + return self._phase_update_time + + @property + def message(self) -> Message: + return self._message + + @property + def id(self) -> int: + return self._message.id + + @property + def seq(self) -> int: + return self._message.seq + + @seq.setter + def seq(self, seq: int): + self._message.seq = seq + + def __repr__(self): + return f'<{__name__}.{self.__class__.__name__} message={self._message.__repr__()}>' + + +# Connection stuff +# Well, strictly speaking, as it's UDP, there's no connection, but who cares. +# --------------------------------------------------------------------------- + +class IncomingMessageListener: + @abstractmethod + def incoming_message(self, message: Message) -> Optional[Message]: + pass + + +class ConnectionStatus(Enum): + NOT_CONNECTED = auto() + CONNECTING = auto() + CONNECTED = auto() + RECONNECTING = auto() + DISCONNECTED = auto() + + +class ConnectionStatusListener: + @abstractmethod + def connection_status_updated(self, status: ConnectionStatus): + pass + + +class UDPConnection(threading.Thread, ConnectionStatusListener): + inseq: int + outseq: int + source_port: int + device_addr: str + device_port: int + device_token: bytes + device_pubkey: bytes + interrupted: bool + response_handlers: Dict[int, WrappedMessage] + outgoing_queue: List[WrappedMessage] + pubkey: Optional[bytes] + encinkey: Optional[bytes] + encoutkey: Optional[bytes] + inc_listeners: List[IncomingMessageListener] + conn_listeners: List[ConnectionStatusListener] + outgoing_time: float + outgoing_time_1st: float + incoming_time: float + status: ConnectionStatus + reconnect_tries: int + read_timeout: int + + _addr_lock: threading.Lock + _iml_lock: threading.Lock + _csl_lock: threading.Lock + _st_lock: threading.Lock + + def __init__(self, + addr: Union[IPv4Address, IPv6Address], + port: int, + device_pubkey: bytes, + device_token: bytes, + read_timeout: int = 1): + super().__init__() + self._logger = logging.getLogger(f'{__name__}.{self.__class__.__name__} <{hex(id(self))}>') + self.setName(self.__class__.__name__) + + self.inseq = 0 + self.outseq = 0 + self.source_port = random.randint(1024, 65535) + self.device_addr = str(addr) + self.device_port = port + self.device_token = device_token + self.device_pubkey = device_pubkey + self.outgoing_queue = [] + self.response_handlers = {} + self.interrupted = False + self.outgoing_time = 0 + self.outgoing_time_1st = 0 + self.incoming_time = 0 + self.inc_listeners = [] + self.conn_listeners = [self] + self.status = ConnectionStatus.NOT_CONNECTED + self.reconnect_tries = 0 + self.read_timeout = read_timeout + + self._iml_lock = threading.Lock() + self._csl_lock = threading.Lock() + self._addr_lock = threading.Lock() + self._st_lock = threading.Lock() + + self.pubkey = None + self.encinkey = None + self.encoutkey = None + + def connection_status_updated(self, status: ConnectionStatus): + # self._logger.info(f'connection_status_updated: status = {status}') + with self._st_lock: + # self._logger.debug(f'connection_status_updated: lock acquired') + self.status = status + if status == ConnectionStatus.RECONNECTING: + self.reconnect_tries += 1 + if status in (ConnectionStatus.CONNECTED, ConnectionStatus.NOT_CONNECTED, ConnectionStatus.DISCONNECTED): + self.reconnect_tries = 0 + + def _cleanup(self): + # erase outgoing queue + for wm in self.outgoing_queue: + wm.call(False, + error_message=f'_cleanup: exception while calling cb(False) on message {wm.message}') + self.outgoing_queue = [] + self.response_handlers = {} + + # reset timestamps + self.incoming_time = 0 + self.outgoing_time = 0 + self.outgoing_time_1st = 0 + + self._logger.debug('_cleanup: done') + + def set_address(self, addr: Union[IPv4Address, IPv6Address], port: int): + with self._addr_lock: + if self.device_addr != str(addr) or self.device_port != port: + self.device_addr = str(addr) + self.device_port = port + self._logger.info(f'updated device network address: {self.device_addr}:{self.device_port}') + + def set_device_pubkey(self, pubkey: bytes): + if self.device_pubkey.hex() != pubkey.hex(): + self._logger.info(f'device pubkey has changed (old={self.device_pubkey.hex()}, new={pubkey.hex()})') + self.device_pubkey = pubkey + self._notify_cs(ConnectionStatus.RECONNECTING) + + def get_address(self) -> Tuple[str, int]: + with self._addr_lock: + return self.device_addr, self.device_port + + def add_incoming_message_listener(self, listener: IncomingMessageListener): + with self._iml_lock: + if listener not in self.inc_listeners: + self.inc_listeners.append(listener) + + def add_connection_status_listener(self, listener: ConnectionStatusListener): + with self._csl_lock: + if listener not in self.conn_listeners: + self.conn_listeners.append(listener) + + def _notify_cs(self, status: ConnectionStatus): + # self._logger.debug(f'_notify_cs: status={status}') + with self._csl_lock: + for obj in self.conn_listeners: + # self._logger.debug(f'_notify_cs: notifying {obj}') + obj.connection_status_updated(status) + + def _prepare_keys(self): + # generate key pair + privkey = X25519PrivateKey.generate() + + self.pubkey = bytes(reversed(privkey.public_key().public_bytes(encoding=srlz.Encoding.Raw, + format=srlz.PublicFormat.Raw))) + + # generate shared key + device_pubkey = X25519PublicKey.from_public_bytes( + bytes(reversed(self.device_pubkey)) + ) + shared_key = bytes(reversed( + privkey.exchange(device_pubkey) + )) + + # in/out encryption keys + digest = hashes.Hash(hashes.SHA256()) + digest.update(shared_key) + + shared_sha256 = digest.finalize() + + self.encinkey = shared_sha256[:16] + self.encoutkey = shared_sha256[16:] + + self._logger.info('encryption keys have been created') + + def _handshake_callback(self, r: MessageResponse): + # if got error for our HandshakeMessage, reset everything and try again + if r is False: + # self._logger.debug('_handshake_callback: set status=RECONNETING') + self._notify_cs(ConnectionStatus.RECONNECTING) + else: + # self._logger.debug('_handshake_callback: set status=CONNECTED') + self._notify_cs(ConnectionStatus.CONNECTED) + + def run(self): + self._logger.info('starting server loop') + + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.bind(('0.0.0.0', self.source_port)) + sock.settimeout(self.read_timeout) + + while not self.interrupted: + with self._st_lock: + status = self.status + + if status in (ConnectionStatus.DISCONNECTED, ConnectionStatus.RECONNECTING): + self._cleanup() + if status == ConnectionStatus.DISCONNECTED: + break + + # no activity for some time means connection is broken + fail = False + fail_path = 0 + if self.incoming_time > 0 and time.time() - self.incoming_time >= DISCONNECT_TIMEOUT: + fail = True + fail_path = 1 + elif self.outgoing_time_1st > 0 and self.incoming_time == 0 and time.time() - self.outgoing_time_1st >= DISCONNECT_TIMEOUT: + fail = True + fail_path = 2 + + if fail: + self._logger.debug(f'run: setting status=RECONNECTING because of long inactivity, fail_path={fail_path}') + self._notify_cs(ConnectionStatus.RECONNECTING) + + # establishing a connection + if status in (ConnectionStatus.RECONNECTING, ConnectionStatus.NOT_CONNECTED): + if status == ConnectionStatus.RECONNECTING and self.reconnect_tries >= 3: + self._notify_cs(ConnectionStatus.DISCONNECTED) + continue + + self._reset_outseq() + self._prepare_keys() + + # shake the imaginary kettle's hand + wrapped = WrappedMessage(HandshakeMessage(), + handler=self._handshake_callback, + validator=lambda m: isinstance(m, (AckMessage, HandshakeResponseMessage))) + self.enqueue_message(wrapped, prepend=True) + self._notify_cs(ConnectionStatus.CONNECTING) + + # pick next (wrapped) message to send + wm = self._get_next_message() # wm means "wrapped message" + if wm: + one_shot = isinstance(wm.message, (AckMessage, NakMessage)) + + if not isinstance(wm.message, (AckMessage, NakMessage)): + old_seq = wm.seq + wm.seq = self.outseq + self._set_response_handler(wm, old_seq=old_seq) + elif wm.seq is None: + # ack/nak is a response to some incoming message (and it must have the same seqno that incoming + # message had) + raise RuntimeError(f'run: seq must be set for {wm.__class__.__name__}') + + self._logger.debug(f'run: sending message: {wm.message}, one_shot={one_shot}, phase={wm.phase}') + encrypted = False + try: + wm.message.encrypt(outkey=self.encoutkey, inkey=self.encinkey, + token=self.device_token, pubkey=self.pubkey) + encrypted = True + except ValueError as exc: + # handle "ValueError: Invalid padding bytes." + self._logger.error('run: failed to encrypt the message.') + self._logger.exception(exc) + + if encrypted: + buf = wm.message.frame.pack() + # self._logger.debug(f'run: raw data to be sent: {buf.hex()}') + + # sending the first time + if wm.phase == MessagePhase.WAITING: + sock.sendto(buf, self.get_address()) + # resending + elif wm.phase == MessagePhase.SENT: + left = RESEND_ATTEMPTS + while left > 0: + sock.sendto(buf, self.get_address()) + left -= 1 + if left > 0: + time.sleep(0.05) + + if one_shot or wm.phase == MessagePhase.SENT: + wm.phase = MessagePhase.DONE + else: + wm.phase = MessagePhase.SENT + + now = time.time() + self.outgoing_time = now + if not self.outgoing_time_1st: + self.outgoing_time_1st = now + + # receiving data + try: + data = sock.recv(4096) + self._handle_incoming(data) + except (TimeoutError, socket.timeout): + pass + + self._logger.info('bye...') + + def _get_next_message(self) -> Optional[WrappedMessage]: + message = None + lpfx = '_get_next_message:' + remove_list = [] + for wm in self.outgoing_queue: + if wm.phase == MessagePhase.DONE: + if isinstance(wm.message, (AckMessage, NakMessage, PingMessage)) or time.time() - wm.phase_update_time >= MESSAGE_QUEUE_REMOVE_DELAY: + remove_list.append(wm) + continue + message = wm + break + + for wm in remove_list: + self._logger.debug(f'{lpfx} rm path: removing id={wm.id} seq={wm.seq}') + + # clear message handler + if wm.seq in self.response_handlers: + self.response_handlers[wm.seq].call( + False, error_message=f'{lpfx} rm path: error while calling callback for seq={wm.seq}') + del self.response_handlers[wm.seq] + + # remove from queue + try: + self.outgoing_queue.remove(wm) + except ValueError as exc: + self._logger.error(f'{lpfx} rm path: removing from outgoing_queue raised an exception: {str(exc)}') + + # ping pong + if not message and self.outgoing_time_1st != 0 and self.status == ConnectionStatus.CONNECTED: + now = time.time() + out_delta = now - self.outgoing_time + in_delta = now - self.incoming_time + if max(out_delta, in_delta) > PING_FREQUENCY: + self._logger.debug(f'{lpfx} no activity: in for {in_delta:.2f}s, out for {out_delta:.2f}s, time to ping the damn thing') + message = WrappedMessage(PingMessage(), ack=True) + # add it to outgoing_queue in order to be aggressively resent in future (if needed) + self.outgoing_queue.insert(0, message) + + return message + + def _handle_incoming(self, buf: bytes): + try: + incoming_message = Message.from_encrypted(buf, inkey=self.encinkey, outkey=self.encoutkey) + except ValueError as exc: + # handle "ValueError: Invalid padding bytes." + self._logger.error('_handle_incoming: failed to decrypt incoming frame:') + self._logger.exception(exc) + return + + self.incoming_time = time.time() + seq = incoming_message.seq + + lpfx = f'handle_incoming({incoming_message.id}):' + self._logger.debug(f'{lpfx} received: {incoming_message}') + + if isinstance(incoming_message, (AckMessage, NakMessage)): + seq_max = self.outseq + seq_name = 'outseq' + else: + seq_max = self.inseq + seq_name = 'inseq' + self.inseq = seq + + if seq < seq_max < 0xfd: + self._logger.debug(f'{lpfx} dropping: seq={seq}, {seq_name}={seq_max}') + return + + if seq not in self.response_handlers: + self._handle_incoming_cmd(incoming_message) + return + + callback_value = None # None means don't call a callback + handler = self.response_handlers[seq] + + if handler.validate(incoming_message): + self._logger.debug(f'{lpfx} response OK') + handler.phase = MessagePhase.DONE + callback_value = incoming_message + self._incr_outseq() + else: + self._logger.warning(f'{lpfx} response is INVALID') + + # It seems that we've received an incoming CmdMessage or PingMessage with the same seqno that our outgoing + # message had. Bad, but what can I say, this is quick-and-dirty made UDP based protocol and this sort of + # shit just happens. + + # (To be fair, maybe my implementation is not perfect either. But hey, what did you expect from a + # reverse-engineered re-implementation of custom UDP-based protocol that some noname vendor uses for their + # cheap IoT devices? I think _that_ is _the_ definition of shit. At least my implementation is FOSS, which + # is more than you'll ever be able to say about them.) + + # All this crapload of code below might not be needed at all, 'cause the protocol uses separate frame seq + # numbers for IN and OUT frames and this situation is not highly likely, as Theresa May could argue. + # After a handshake, a kettle sends us 10 or so CmdMessages, and then either we continuously ping it every + # 3 seconds, or kettle pings us. This in any case widens the gap between inseq and outseq. + + # But! the seqno is only 1 byte in size and once it reaches 0xff, it circles back to zero. And that (plus, + # perhaps, some bad luck) gives a chance for a collision. + + if handler.phase == MessagePhase.DONE or isinstance(handler.message, HandshakeMessage): + # no more attempts left, returning error back to user + # as to handshake, it cannot fail. + callback_value = False + + # else: + # # try resending the message + # handler.phase_reset() + # max_seq = self.outseq + # wait_remap = {} + # for m in self.outgoing_queue: + # if m.seq in self.waiting_for_response: + # wait_remap[m.seq] = (m.seq+1) % 256 + # m.set_seq((m.seq+1) % 256) + # if m.seq > max_seq: + # max_seq = m.seq + # if max_seq > self.outseq: + # self.outseq = max_seq % 256 + # if wait_remap: + # waiting_new = {} + # for old_seq, new_seq in wait_remap.items(): + # waiting_new[new_seq] = self.waiting_for_response[old_seq] + # self.waiting_for_response = waiting_new + + if isinstance(incoming_message, (PingMessage, CmdIncomingMessage)): + # handle incoming message as usual, as we need to ack/nak it anyway + self._handle_incoming_cmd(incoming_message) + + if callback_value is not None: + handler.call(callback_value, + error_message=f'{lpfx} error while calling callback for msg id={handler.message.id} seq={seq}') + del self.response_handlers[seq] + + def _handle_incoming_cmd(self, incoming_message: Message): + if isinstance(incoming_message, (AckMessage, NakMessage)): + self._logger.debug(f'_handle_incoming_cmd({incoming_message.id}, seq={incoming_message.seq}): it\'s {incoming_message.__class__.__name__}, ignoring') + return + + replied = False + with self._iml_lock: + for f in self.inc_listeners: + retval = safe_callback_call(f.incoming_message, incoming_message, + logger=self._logger, + error_message=f'_handle_incoming_cmd({incoming_message.id}, seq={incoming_message.seq}): error while calling message listener') + if isinstance(retval, Message): + if isinstance(retval, (AckMessage, NakMessage)): + retval.seq = incoming_message.seq + self.enqueue_message(WrappedMessage(retval), prepend=True) + replied = True + break + else: + raise RuntimeError('are you sure your response is correct? only ack/nak are allowed') + + if not replied: + self.enqueue_message(WrappedMessage(AckMessage(incoming_message.seq)), prepend=True) + + def enqueue_message(self, wrapped: WrappedMessage, prepend=False): + self._logger.debug(f'enqueue_message: {wrapped.message}') + if not prepend: + self.outgoing_queue.append(wrapped) + else: + self.outgoing_queue.insert(0, wrapped) + + def _set_response_handler(self, wm: WrappedMessage, old_seq=None): + if old_seq in self.response_handlers: + del self.response_handlers[old_seq] + + seq = wm.seq + assert seq is not None, 'seq is not set' + + if seq in self.response_handlers: + self._logger.debug(f'_set_response_handler(seq={seq}): handler is already set, cancelling it') + self.response_handlers[seq].call(False, + error_message=f'_set_response_handler({seq}): error while calling old callback') + self.response_handlers[seq] = wm + + def _incr_outseq(self) -> None: + self.outseq = (self.outseq + 1) % 256 + + def _reset_outseq(self): + self.outseq = 0 + self._logger.debug(f'_reset_outseq: set 0') + + +MessageResponse = Union[Message, bool] diff --git a/test/test_polaris_stuff.py b/test/test_polaris_stuff.py index f3c8e6a..b921891 100755 --- a/test/test_polaris_stuff.py +++ b/test/test_polaris_stuff.py @@ -7,13 +7,13 @@ sys.path.extend([ ) ]) -import src.polaris as polaris +import src.syncleo as polaris if __name__ == '__main__': - sc = [cl for cl in polaris.protocol.CmdIncomingMessage.__subclasses__() - if cl is not polaris.protocol.SimpleBooleanMessage] - sc.extend(polaris.protocol.SimpleBooleanMessage.__subclasses__()) + sc = [cl for cl in syncleo.protocol.CmdIncomingMessage.__subclasses__() + if cl is not syncleo.protocol.SimpleBooleanMessage] + sc.extend(syncleo.protocol.SimpleBooleanMessage.__subclasses__()) for cl in sc: # if cl == polaris.protocol.HandshakeMessage: # print('skip') -- cgit v1.2.3