# Polaris PWK 1725CGLD "smart" kettle python library # -------------------------------------------------- # Copyright (C) Evgeny Zinoviev, 2022 # License: BSD-3c from __future__ import annotations import logging import socket import random import struct import threading import time from abc import abstractmethod, ABC from enum import Enum, auto from typing import Union, Optional, Dict, Tuple, List from ipaddress import IPv4Address, IPv6Address import cryptography.hazmat.primitives._serialization as srlz from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey from cryptography.hazmat.primitives import ciphers, padding, hashes from cryptography.hazmat.primitives.ciphers import algorithms, modes ReprDict = Dict[str, Union[str, int, float, bool]] _logger = logging.getLogger(__name__) PING_FREQUENCY = 3 RESEND_ATTEMPTS = 5 READ_TIMEOUT = 1 ERROR_TIMEOUT = 15 MESSAGE_QUEUE_REMOVE_DELAY = 13 # after what time to delete (and pass False to handlers, if needed) messages with phase=DONE from queue DISCONNECT_TIMEOUT = 15 def safe_callback_call(f: callable, *args, logger: logging.Logger = None, error_message: str = None): try: return f(*args) except Exception as exc: logger.error(f'{error_message}, see exception below:') logger.exception(exc) return None # drop-in replacement for java.lang.System.arraycopy # TODO: rewrite def arraycopy(src, src_pos, dest, dest_pos, length): for i in range(length): dest[i + dest_pos] = src[i + src_pos] # "convert" unsigned byte to signed def u8_to_s8(b: int) -> int: return struct.unpack('b', bytes([b]))[0] class PowerType(Enum): OFF = 0 # turn off ON = 1 # turn on, set target temperature to 100 CUSTOM = 3 # turn on, allows custom target temperature # MYSTERY_MODE = 2 # don't know what 2 means, needs testing # update: if I set it to '2', it just resets to '0' # low-level protocol structures # ----------------------------- class FrameType(Enum): ACK = 0 CMD = 1 AUX = 2 NAK = 3 class FrameHead: seq: 
Optional[int] # u8 type: FrameType # u8 length: int # u16. This is the length of FrameItem's payload @staticmethod def from_bytes(buf: bytes) -> FrameHead: seq, ft, length = struct.unpack(' bytes: assert self.length != 0, "FrameHead.length has not been set" assert self.seq is not None, "FrameHead.seq has not been set" return struct.pack(' bytes: ba = bytearray(self.head.pack()) ba.extend(self.payload) return bytes(ba) # high-level wrappers around FrameItem # ------------------------------------ class MessagePhase(Enum): WAITING = 0 SENT = 1 DONE = 2 class Message: frame: Optional[FrameItem] id: int _global_id = 0 def __init__(self): self.frame = None # global internal message id, only useful for debugging purposes self.id = self.next_id() def __repr__(self): return f'<{self.__class__.__name__} id={self.id} seq={self.frame.head.seq}>' @staticmethod def next_id(): _id = Message._global_id Message._global_id = (Message._global_id + 1) % 100000 return _id @staticmethod def from_encrypted(buf: bytes, inkey: bytes, outkey: bytes) -> Message: _logger.debug(f'Message:from_encrypted: buf={buf.hex()}') assert len(buf) >= 4, 'invalid size' head = FrameHead.from_bytes(buf[:4]) assert len(buf) == head.length + 4, f'invalid buf size ({len(buf)} != {head.length})' payload = buf[4:] b = head.seq j = b & 0xF k = b >> 4 & 0xF key = bytearray(len(inkey)) arraycopy(inkey, j, key, 0, len(inkey) - j) arraycopy(inkey, 0, key, len(inkey) - j, j) iv = bytearray(len(outkey)) arraycopy(outkey, k, iv, 0, len(outkey) - k) arraycopy(outkey, 0, iv, len(outkey) - k, k) cipher = ciphers.Cipher(algorithms.AES(key), modes.CBC(iv)) decryptor = cipher.decryptor() decrypted_data = decryptor.update(payload) + decryptor.finalize() unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder() decrypted_data = unpadder.update(decrypted_data) decrypted_data += unpadder.finalize() assert len(decrypted_data) != 0, 'decrypted data is null' assert head.seq == decrypted_data[0], f'decrypted seq mismatch 
{head.seq} != {decrypted_data[0]}' # _logger.debug('Message.from_encrypted: plaintext: '+decrypted_data.hex()) if head.type == FrameType.ACK: return AckMessage(head.seq) elif head.type == FrameType.NAK: return NakMessage(head.seq) elif head.type == FrameType.AUX: # TODO implement AUX raise NotImplementedError('FrameType AUX is not yet implemented') elif head.type == FrameType.CMD: type = decrypted_data[1] data = decrypted_data[2:] cl = UnknownMessage subclasses = [cl for cl in CmdIncomingMessage.__subclasses__() if cl is not SimpleBooleanMessage] subclasses.extend(SimpleBooleanMessage.__subclasses__()) for _cl in subclasses: # `UnknownMessage` is a special class that holds a packed command that we don't recognize. # It will be used anyway if we don't find a match, so skip it here if _cl == UnknownMessage: continue if _cl.TYPE == type: cl = _cl break m = cl.from_packed_data(data, seq=head.seq) if isinstance(m, UnknownMessage): m.set_type(type) return m else: raise NotImplementedError(f'Unexpected frame type: {head.type}') def pack_data(self) -> bytes: return b'' @property def seq(self) -> Union[int, None]: try: return self.frame.head.seq except: return None @seq.setter def seq(self, seq: int): self.frame.head.seq = seq def encrypt(self, outkey: bytes, inkey: bytes, token: bytes, pubkey: bytes): assert self.frame is not None data = self._get_data_to_encrypt() assert data is not None b = self.frame.head.seq i = b & 0xf j = b >> 4 & 0xf outkey = bytearray(outkey) l = len(outkey) key = bytearray(l) arraycopy(outkey, i, key, 0, l-i) arraycopy(outkey, 0, key, l-i, i) inkey = bytearray(inkey) l = len(inkey) iv = bytearray(l) arraycopy(inkey, j, iv, 0, l-j) arraycopy(inkey, 0, iv, l-j, j) cipher = ciphers.Cipher(algorithms.AES(key), modes.CBC(iv)) encryptor = cipher.encryptor() newdata = bytearray(len(data)+1) newdata[0] = b arraycopy(data, 0, newdata, 1, len(data)) newdata = bytes(newdata) _logger.debug('frame payload to be encrypted: ' + newdata.hex()) padder = 
padding.PKCS7(algorithms.AES.block_size).padder() ciphertext = bytearray() ciphertext.extend(encryptor.update(padder.update(newdata) + padder.finalize())) ciphertext.extend(encryptor.finalize()) self.frame.setpayload(ciphertext) def _get_data_to_encrypt(self) -> bytes: return self.pack_data() class AckMessage(Message, ABC): def __init__(self, seq: Optional[int] = None): super().__init__() self.frame = FrameItem(FrameHead(seq, FrameType.ACK, None)) class NakMessage(Message, ABC): def __init__(self, seq: Optional[int] = None): super().__init__() self.frame = FrameItem(FrameHead(seq, FrameType.NAK, None)) class CmdMessage(Message): type: Optional[int] data: bytes TYPE = None def _get_data_to_encrypt(self) -> bytes: buf = bytearray() buf.append(self.get_type()) buf.extend(self.pack_data()) return bytes(buf) def __init__(self, seq: Optional[int] = None): super().__init__() self.frame = FrameItem(FrameHead(seq, FrameType.CMD)) self.data = b'' def _repr_fields(self) -> ReprDict: return { 'cmd': self.get_type() } def __repr__(self): params = [ __name__+'.'+self.__class__.__name__, f'id={self.id}', f'seq={self.seq}' ] fields = self._repr_fields() if fields: for k, v in fields.items(): params.append(f'{k}={v}') elif self.data: params.append(f'data={self.data.hex()}') return '<'+' '.join(params)+'>' def get_type(self) -> int: return self.__class__.TYPE class CmdIncomingMessage(CmdMessage): @staticmethod @abstractmethod def from_packed_data(cls, data: bytes, seq: Optional[int] = None): pass @abstractmethod def _repr_fields(self) -> ReprDict: pass class CmdOutgoingMessage(CmdMessage): @abstractmethod def pack_data(self) -> bytes: return b'' class ModeMessage(CmdOutgoingMessage, CmdIncomingMessage): TYPE = 1 pt: PowerType def __init__(self, power_type: PowerType, seq: Optional[int] = None): super().__init__(seq) self.pt = power_type @classmethod def from_packed_data(cls, data: bytes, seq=0) -> ModeMessage: assert len(data) == 1, 'data size expected to be 1' mode, = 
struct.unpack('B', data) return ModeMessage(PowerType(mode), seq=seq) def pack_data(self) -> bytes: return self.pt.value.to_bytes(1, byteorder='little') def _repr_fields(self) -> ReprDict: return {'mode': self.pt.name} class TargetTemperatureMessage(CmdOutgoingMessage, CmdIncomingMessage): temperature: int TYPE = 2 def __init__(self, temp: int, seq: Optional[int] = None): super().__init__(seq) self.temperature = temp @classmethod def from_packed_data(cls, data: bytes, seq=0) -> TargetTemperatureMessage: assert len(data) == 2, 'data size expected to be 2' nat, frac = struct.unpack('BB', data) temp = int(nat + (frac / 100)) return TargetTemperatureMessage(temp, seq=seq) def pack_data(self) -> bytes: return bytes([self.temperature, 0]) def _repr_fields(self) -> ReprDict: return {'temperature': self.temperature} class PingMessage(CmdIncomingMessage, CmdOutgoingMessage): TYPE = 255 @classmethod def from_packed_data(cls, data: bytes, seq=0) -> PingMessage: assert len(data) == 0, 'no data expected' return PingMessage(seq=seq) def pack_data(self) -> bytes: return b'' def _repr_fields(self) -> ReprDict: return {} # This is the first protocol message. Sent by a client. # Kettle usually ACKs this, but sometimes i don't get any ACK and the very next message is HandshakeResponseMessage. class HandshakeMessage(CmdMessage): TYPE = 0 def encrypt(self, outkey: bytes, inkey: bytes, token: bytes, pubkey: bytes): cipher = ciphers.Cipher(algorithms.AES(outkey), modes.CBC(inkey)) encryptor = cipher.encryptor() ciphertext = bytearray() ciphertext.extend(encryptor.update(token)) ciphertext.extend(encryptor.finalize()) pld = bytearray() pld.append(0) pld.extend(pubkey) pld.extend(ciphertext) self.frame.setpayload(pld) # Kettle either sends this right after the handshake, of first it ACKs the handshake then sends this. 
class HandshakeResponseMessage(CmdIncomingMessage):
    """Command 0 as received from the device: protocol/firmware info and a token."""

    TYPE = 0

    protocol: int
    fw_major: int
    fw_minor: int
    mode: int
    token: bytes

    def __init__(self,
                 protocol: int,
                 fw_major: int,
                 fw_minor: int,
                 mode: int,
                 token: bytes,
                 seq: Optional[int] = None):
        super().__init__(seq)
        self.protocol = protocol
        self.fw_major = fw_major
        self.fw_minor = fw_minor
        self.mode = mode
        self.token = token

    @classmethod
    def from_packed_data(cls, data: bytes, seq=0) -> HandshakeResponseMessage:
        # NOTE(review): the format string and the return statement were
        # truncated in the reviewed source. Only the leading '<' and the four
        # unpacked names are certain; the u16+u8+u8+u8 layout and "the rest is
        # the token" are a reconstruction -- confirm against the upstream file.
        protocol, fw_major, fw_minor, mode = struct.unpack('<HBBB', data[:5])
        return HandshakeResponseMessage(protocol, fw_major, fw_minor, mode, token=data[5:], seq=seq)

    def _repr_fields(self) -> ReprDict:
        return {
            'protocol': self.protocol,
            'fw': f'{self.fw_major}.{self.fw_minor}',
            'mode': self.mode,
            'token': self.token.hex()
        }


# Apparently, some hardware info.
# On the other hand, if you look at com.syncleiot.iottransport.commands.CmdHardware, its mqtt topic says "mcu_firmware".
# My device returns 1.1.1. The kettle uses an ESP8266 ESP-12F MCU under the hood (or, more precisely, under a piece of
# cheap plastic), so maybe 1.1.1 is some MCU ROM version.
class DeviceHardwareMessage(CmdIncomingMessage):
    TYPE = 143  # -113

    hw: List[int]

    def __init__(self, hw: List[int], seq: Optional[int] = None):
        super().__init__(seq)
        self.hw = hw

    @classmethod
    def from_packed_data(cls, data: bytes, seq=0) -> DeviceHardwareMessage:
        assert len(data) == 3, 'invalid data size, expected 3'
        # NOTE(review): format string and return were truncated in the
        # reviewed source; three unsigned bytes follow from the size check
        # above and the List[int] annotation -- confirm signedness.
        hw = list(struct.unpack('<BBB', data))
        return DeviceHardwareMessage(hw, seq=seq)

    def _repr_fields(self) -> ReprDict:
        return {'device_hardware': '.'.join(map(str, self.hw))}


# This message is sent by kettle right after the HandshakeResponseMessage.
# The diagnostic data is supposed to be sent to vendor, which we, obviously, are not going to do.
# So just ACK and skip it.
class DeviceDiagnosticMessage(CmdIncomingMessage):
    """Command 145: opaque diagnostic blob, kept as raw bytes."""

    TYPE = 145  # -111

    diag_data: bytes

    def __init__(self, diag_data: bytes, seq: Optional[int] = None):
        super().__init__(seq)
        self.diag_data = diag_data

    @classmethod
    def from_packed_data(cls, data: bytes, seq=0) -> DeviceDiagnosticMessage:
        return DeviceDiagnosticMessage(diag_data=data, seq=seq)

    def _repr_fields(self) -> ReprDict:
        return {'diag_data': self.diag_data.hex()}


class SimpleBooleanMessage(ABC, CmdIncomingMessage):
    """Base for incoming commands whose payload is a single boolean byte."""

    value: bool

    def __init__(self, value: bool, seq: Optional[int] = None):
        super().__init__(seq)
        self.value = value

    @classmethod
    def from_packed_data(cls, data: bytes, seq: Optional[int] = None):
        assert len(data) == 1, 'invalid data size, expected 1'
        # NOTE(review): the format string and the return statement were
        # truncated in the reviewed source; a single unsigned byte follows
        # from the size check, the "== 1" conversion is a reconstruction --
        # confirm against the upstream file.
        enabled, = struct.unpack('<B', data)
        return cls(value=enabled == 1, seq=seq)

    @abstractmethod
    def _repr_fields(self) -> ReprDict:
        pass


class AccessControlMessage(SimpleBooleanMessage):
    TYPE = 133  # -123

    def _repr_fields(self) -> ReprDict:
        return {'acl_enabled': self.value}


class ErrorMessage(SimpleBooleanMessage):
    TYPE = 7

    def _repr_fields(self) -> ReprDict:
        return {'error': self.value}


class ChildLockMessage(SimpleBooleanMessage):
    TYPE = 30

    def _repr_fields(self) -> ReprDict:
        return {'child_lock': self.value}


class VolumeMessage(SimpleBooleanMessage):
    TYPE = 9

    def _repr_fields(self) -> ReprDict:
        return {'volume': self.value}


class BacklightMessage(SimpleBooleanMessage):
    TYPE = 28

    def _repr_fields(self) -> ReprDict:
        return {'backlight': self.value}


class CurrentTemperatureMessage(CmdIncomingMessage):
    """Command 20: current temperature, two bytes (integer part, hundredths)."""

    TYPE = 20

    current_temperature: int

    def __init__(self, temp: int, seq: Optional[int] = None):
        super().__init__(seq)
        self.current_temperature = temp

    @classmethod
    def from_packed_data(cls, data: bytes, seq=0) -> CurrentTemperatureMessage:
        assert len(data) == 2, 'data size expected to be 2'
        nat, frac = struct.unpack('BB', data)
        temp = int(nat + (frac / 100))
        return CurrentTemperatureMessage(temp, seq=seq)

    def pack_data(self) -> bytes:
        return bytes([self.current_temperature, 0])

    def _repr_fields(self) -> ReprDict:
        return {'current_temperature': self.current_temperature}


class UnknownMessage(CmdIncomingMessage):
    """Holds a command whose type byte matches no known message class."""

    type: Optional[int]
    data: bytes

    def __init__(self, data: bytes, **kwargs):
        super().__init__(**kwargs)
        self.type = None
        self.data = data

    @classmethod
    def from_packed_data(cls, data: bytes, seq=0) -> UnknownMessage:
        return UnknownMessage(data, seq=seq)

    def set_type(self, type: int):
        self.type = type

    def get_type(self) -> int:
        return self.type

    def _repr_fields(self) -> ReprDict:
        return {
            'type': self.type,
            'data': self.data.hex()
        }


class WrappedMessage:
    """An outgoing Message plus its send phase, response validator and callback."""

    _message: Message
    _handler: Optional[callable]
    _validator: Optional[callable]
    _logger: Optional[logging.Logger]
    _phase: MessagePhase
    _phase_update_time: float

    def __init__(self,
                 message: Message,
                 handler: Optional[callable] = None,
                 validator: Optional[callable] = None,
                 ack=False):
        """
        :param message: message to send
        :param handler: callback invoked with the response message, or with
            False on failure
        :param validator: predicate applied to a response candidate
        :param ack: when no validator is given, accept only AckMessage
        """
        self._message = message
        self._handler = handler
        self._validator = validator
        self._logger = None
        self._phase = MessagePhase.WAITING
        self._phase_update_time = 0
        if not validator and ack:
            self._validator = lambda m: isinstance(m, AckMessage)

    def setlogger(self, logger: logging.Logger):
        self._logger = logger

    def validate(self, message: Message):
        """Return the validator's verdict for ``message``; True if there is no validator."""
        if not self._validator:
            return True
        return self._validator(message)

    def call(self, *args, error_message: str = None) -> None:
        """Invoke the handler (if any) with ``args``; exceptions are logged, not raised."""
        if not self._handler:
            return
        try:
            self._handler(*args)
        except Exception as exc:
            logger = self._logger or logging.getLogger(self.__class__.__name__)
            logger.error(f'{error_message}, see exception below:')
            logger.exception(exc)

    @property
    def phase(self) -> MessagePhase:
        return self._phase

    @phase.setter
    def phase(self, phase: MessagePhase):
        self._phase = phase
        # the timestamp is reset to 0 when going back to WAITING
        self._phase_update_time = 0 if phase == MessagePhase.WAITING else time.time()

    @property
    def phase_update_time(self) -> float:
        return self._phase_update_time

    @property
    def message(self) -> Message:
        return self._message

    @property
    def id(self) -> int:
        return self._message.id

    @property
    def seq(self) -> int:
        return self._message.seq

    @seq.setter
    def seq(self, seq: int):
        self._message.seq = seq

    def __repr__(self):
        return f'<{__name__}.{self.__class__.__name__} message={self._message.__repr__()}>'


# Connection stuff
# Well, strictly speaking, as it's UDP, there's no connection, but who cares.
# ---------------------------------------------------------------------------

class IncomingMessageListener:
    @abstractmethod
    def incoming_message(self, message: Message) -> Optional[Message]:
        pass


class ConnectionStatus(Enum):
    NOT_CONNECTED = auto()
    CONNECTING = auto()
    CONNECTED = auto()
    RECONNECTING = auto()
    DISCONNECTED = auto()


class ConnectionStatusListener:
    @abstractmethod
    def connection_status_updated(self, status: ConnectionStatus):
        pass


class UDPConnection(threading.Thread, ConnectionStatusListener):
    """Thread that performs the handshake and then exchanges encrypted frames
    with the device over a UDP socket."""

    inseq: int
    outseq: int
    source_port: int
    device_addr: str
    device_port: int
    device_token: bytes
    device_pubkey: bytes
    interrupted: bool
    response_handlers: Dict[int, WrappedMessage]
    outgoing_queue: List[WrappedMessage]
    pubkey: Optional[bytes]
    encinkey: Optional[bytes]
    encoutkey: Optional[bytes]
    inc_listeners: List[IncomingMessageListener]
    conn_listeners: List[ConnectionStatusListener]
    outgoing_time: float
    outgoing_time_1st: float
    incoming_time: float
    status: ConnectionStatus
    reconnect_tries: int

    _addr_lock: threading.Lock
    _iml_lock: threading.Lock
    _csl_lock: threading.Lock
    _st_lock: threading.Lock

    def __init__(self,
                 addr: Union[IPv4Address, IPv6Address],
                 port: int,
                 device_pubkey: bytes,
                 device_token: bytes):
        super().__init__()
        self._logger = logging.getLogger(f'{__name__}.{self.__class__.__name__} <{hex(id(self))}>')
        # Thread.setName() is deprecated; the `name` attribute is the supported way
        self.name = self.__class__.__name__

        self.inseq = 0
        self.outseq = 0
        self.source_port = random.randint(1024, 65535)
        self.device_addr = str(addr)
        self.device_port = port
        self.device_token = device_token
        self.device_pubkey = device_pubkey
        self.outgoing_queue = []
        self.response_handlers = {}
        self.interrupted = False
        self.outgoing_time = 0
        self.outgoing_time_1st = 0
        self.incoming_time = 0
        self.inc_listeners = []
        # the connection listens to its own status changes
        self.conn_listeners = [self]
        self.status = ConnectionStatus.NOT_CONNECTED
        self.reconnect_tries = 0

        self._iml_lock = threading.Lock()
        self._csl_lock = threading.Lock()
        self._addr_lock = threading.Lock()
        self._st_lock = threading.Lock()

        self.pubkey = None
        self.encinkey = None
        self.encoutkey = None

    def connection_status_updated(self, status: ConnectionStatus):
        """Record the new status and maintain the reconnect attempt counter."""
        # self._logger.info(f'connection_status_updated: status = {status}')
        with self._st_lock:
            # self._logger.debug(f'connection_status_updated: lock acquired')
            self.status = status
            if status == ConnectionStatus.RECONNECTING:
                self.reconnect_tries += 1
            if status in (ConnectionStatus.CONNECTED, ConnectionStatus.NOT_CONNECTED, ConnectionStatus.DISCONNECTED):
                self.reconnect_tries = 0

    def _cleanup(self):
        """Fail every queued message, then drop the queue, handlers and timestamps."""
        # erase outgoing queue
        for wm in self.outgoing_queue:
            wm.call(False,
                    error_message=f'_cleanup: exception while calling cb(False) on message {wm.message}')
        self.outgoing_queue = []
        self.response_handlers = {}

        # reset timestamps
        self.incoming_time = 0
        self.outgoing_time = 0
        self.outgoing_time_1st = 0

        self._logger.debug('_cleanup: done')

    def set_address(self, addr: Union[IPv4Address, IPv6Address], port: int):
        """Update the device address/port if either of them changed."""
        with self._addr_lock:
            if self.device_addr != str(addr) or self.device_port != port:
                self.device_addr = str(addr)
                self.device_port = port
                self._logger.info(f'updated device network address: {self.device_addr}:{self.device_port}')

    def set_device_pubkey(self, pubkey: bytes):
        """Store a changed device public key and switch to RECONNECTING."""
        if self.device_pubkey.hex() != pubkey.hex():
            self._logger.info(f'device pubkey has changed (old={self.device_pubkey.hex()}, new={pubkey.hex()})')
            self.device_pubkey = pubkey
            self._notify_cs(ConnectionStatus.RECONNECTING)

    def get_address(self) -> Tuple[str, int]:
        with self._addr_lock:
            return self.device_addr, self.device_port

    def add_incoming_message_listener(self, listener: IncomingMessageListener):
        with self._iml_lock:
            if listener not in self.inc_listeners:
                self.inc_listeners.append(listener)

    def add_connection_status_listener(self, listener: ConnectionStatusListener):
        with self._csl_lock:
            if listener not in self.conn_listeners:
                self.conn_listeners.append(listener)

    def _notify_cs(self, status: ConnectionStatus):
        """Deliver ``status`` to every connection status listener."""
        # self._logger.debug(f'_notify_cs: status={status}')
        with self._csl_lock:
            for obj in self.conn_listeners:
                # self._logger.debug(f'_notify_cs: notifying {obj}')
                obj.connection_status_updated(status)

    def _prepare_keys(self):
        """Generate a fresh X25519 key pair and derive the in/out keys.

        Public keys and the shared secret are byte-reversed; the SHA-256 of
        the reversed shared secret is split into ``encinkey`` (first 16 bytes)
        and ``encoutkey`` (last 16 bytes).
        """
        # generate key pair
        privkey = X25519PrivateKey.generate()

        self.pubkey = bytes(reversed(privkey.public_key().public_bytes(encoding=srlz.Encoding.Raw,
                                                                       format=srlz.PublicFormat.Raw)))

        # generate shared key
        device_pubkey = X25519PublicKey.from_public_bytes(
            bytes(reversed(self.device_pubkey))
        )
        shared_key = bytes(reversed(
            privkey.exchange(device_pubkey)
        ))

        # in/out encryption keys
        digest = hashes.Hash(hashes.SHA256())
        digest.update(shared_key)

        shared_sha256 = digest.finalize()

        self.encinkey = shared_sha256[:16]
        self.encoutkey = shared_sha256[16:]

        self._logger.info('encryption keys have been created')

    def _handshake_callback(self, r: MessageResponse):
        # if got error for our HandshakeMessage, reset everything and try again
        if r is False:
            # self._logger.debug('_handshake_callback: set status=RECONNETING')
            self._notify_cs(ConnectionStatus.RECONNECTING)
        else:
            # self._logger.debug('_handshake_callback: set status=CONNECTED')
            self._notify_cs(ConnectionStatus.CONNECTED)

    def run(self):
        """Thread body: connect, send queued messages and receive datagrams
        until interrupted or DISCONNECTED."""
        self._logger.info('starting server loop')

        # the context manager closes the socket however the loop ends
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.bind(('0.0.0.0', self.source_port))
            sock.settimeout(READ_TIMEOUT)

            while not self.interrupted:
                with self._st_lock:
                    status = self.status

                if status in (ConnectionStatus.DISCONNECTED, ConnectionStatus.RECONNECTING):
                    self._cleanup()
                    if status == ConnectionStatus.DISCONNECTED:
                        break

                # no activity for some time means connection is broken
                fail = False
                fail_path = 0
                if self.incoming_time > 0 and time.time() - self.incoming_time >= DISCONNECT_TIMEOUT:
                    fail = True
                    fail_path = 1
                elif self.outgoing_time_1st > 0 and self.incoming_time == 0 and time.time() - self.outgoing_time_1st >= DISCONNECT_TIMEOUT:
                    fail = True
                    fail_path = 2

                if fail:
                    self._logger.debug(f'run: setting status=RECONNECTING because of long inactivity, fail_path={fail_path}')
                    self._notify_cs(ConnectionStatus.RECONNECTING)

                # establishing a connection
                if status in (ConnectionStatus.RECONNECTING, ConnectionStatus.NOT_CONNECTED):
                    if status == ConnectionStatus.RECONNECTING and self.reconnect_tries >= 3:
                        self._notify_cs(ConnectionStatus.DISCONNECTED)
                        continue

                    self._reset_outseq()
                    self._prepare_keys()

                    # shake the imaginary kettle's hand
                    wrapped = WrappedMessage(HandshakeMessage(),
                                             handler=self._handshake_callback,
                                             validator=lambda m: isinstance(m, (AckMessage, HandshakeResponseMessage)))
                    self.enqueue_message(wrapped, prepend=True)
                    self._notify_cs(ConnectionStatus.CONNECTING)

                # pick next (wrapped) message to send
                wm = self._get_next_message()  # wm means "wrapped message"
                if wm:
                    self._send(sock, wm)

                # receiving data
                try:
                    data = sock.recv(4096)
                except (TimeoutError, socket.timeout):
                    continue

                try:
                    self._handle_incoming(data)
                except Exception as exc:
                    # A datagram that fails to parse or decrypt used to
                    # propagate out of run() and silently end the thread.
                    self._logger.error('run: failed to handle incoming datagram, see exception below:')
                    self._logger.exception(exc)

        self._logger.info('bye...')

    def _send(self, sock: socket.socket, wm: WrappedMessage):
        """Assign a seq if needed, encrypt ``wm`` and put it on the wire."""
        one_shot = isinstance(wm.message, (AckMessage, NakMessage))

        if not one_shot:
            old_seq = wm.seq
            wm.seq = self.outseq
            self._set_response_handler(wm, old_seq=old_seq)
        elif wm.seq is None:
            # ack/nak is a response to some incoming message (and it must have the same seqno that incoming
            # message had)
            raise RuntimeError(f'run: seq must be set for {wm.__class__.__name__}')

        self._logger.debug(f'run: sending message: {wm.message}')
        wm.message.encrypt(outkey=self.encoutkey, inkey=self.encinkey,
                           token=self.device_token, pubkey=self.pubkey)
        buf = wm.message.frame.pack()
        # self._logger.debug(f'run: raw data to be sent: {buf.hex()}')

        # sending the first time
        if wm.phase == MessagePhase.WAITING:
            sock.sendto(buf, self.get_address())
        # resending
        elif wm.phase == MessagePhase.SENT:
            left = RESEND_ATTEMPTS
            while left > 0:
                sock.sendto(buf, self.get_address())
                left -= 1
                if left > 0:
                    time.sleep(0.05)

        # ack/nak are never resent; everything else gets one resend burst
        if one_shot or wm.phase == MessagePhase.SENT:
            wm.phase = MessagePhase.DONE
        else:
            wm.phase = MessagePhase.SENT

        now = time.time()
        self.outgoing_time = now
        if not self.outgoing_time_1st:
            self.outgoing_time_1st = now

    def _get_next_message(self) -> Optional[WrappedMessage]:
        """Return the first queued message that is not DONE, purge expired DONE
        messages, and synthesize a ping when the link has been idle."""
        message = None
        lpfx = '_get_next_message:'
        remove_list = []
        for wm in self.outgoing_queue:
            if wm.phase == MessagePhase.DONE:
                if isinstance(wm.message, (AckMessage, NakMessage)) or time.time() - wm.phase_update_time >= MESSAGE_QUEUE_REMOVE_DELAY:
                    remove_list.append(wm)
                continue
            message = wm
            break

        for wm in remove_list:
            self._logger.debug(f'{lpfx} rm path: removing id={wm.id} seq={wm.seq}')

            # clear message handler
            if wm.seq in self.response_handlers:
                self.response_handlers[wm.seq].call(
                    False, error_message=f'{lpfx} rm path: error while calling callback for seq={wm.seq}')
                del self.response_handlers[wm.seq]

            # remove from queue
            try:
                self.outgoing_queue.remove(wm)
            except ValueError as exc:
                self._logger.error(f'{lpfx} rm path: removing from outgoing_queue raised an exception: {str(exc)}')

        # ping pong
        if self.outgoing_time_1st != 0 and self.status == ConnectionStatus.CONNECTED:
            now = time.time()
            out_delta = now - self.outgoing_time
            in_delta = now - self.incoming_time
            if not message and max(out_delta, in_delta) > PING_FREQUENCY:
                self._logger.debug(f'{lpfx} no activity: in for {in_delta:.2f}s, out for {out_delta:.2f}s, time to ping the damn thing')
                message = WrappedMessage(PingMessage(), ack=True)

        return message

    def _handle_incoming(self, buf: bytes):
        """Decrypt a datagram and route it to a pending response handler or to
        the incoming-command path."""
        self.incoming_time = time.time()
        incoming_message = Message.from_encrypted(buf, inkey=self.encinkey, outkey=self.encoutkey)

        seq = incoming_message.seq

        lpfx = f'handle_incoming({incoming_message.id}):'
        self._logger.debug(f'{lpfx} received: {incoming_message}')

        if isinstance(incoming_message, (AckMessage, NakMessage)):
            seq_max = self.outseq
            seq_name = 'outseq'
        else:
            seq_max = self.inseq
            seq_name = 'inseq'
            self.inseq = seq

        if seq < seq_max < 0xfd:
            # the counter's value is logged here (the name used to be printed twice)
            self._logger.warning(f'{lpfx} dropping: seq={seq}, {seq_name}={seq_max}')
            return

        if seq not in self.response_handlers:
            self._handle_incoming_cmd(incoming_message)
            return

        callback_value = None  # None means don't call a callback
        handler = self.response_handlers[seq]

        if handler.validate(incoming_message):
            self._logger.debug(f'{lpfx} response OK')
            handler.phase = MessagePhase.DONE
            callback_value = incoming_message
            self._incr_outseq()
        else:
            self._logger.warning(f'{lpfx} response is INVALID')

            # It seems that we've received an incoming CmdMessage or PingMessage with the same seqno that our outgoing
            # message had. Bad, but what can I say, this is quick-and-dirty made UDP based protocol and this sort of
            # shit just happens.

            # (To be fair, maybe my implementation is not perfect either. But hey, what did you expect from a
            # reverse-engineered re-implementation of custom UDP-based protocol that some noname vendor uses for their
            # cheap IoT devices? I think _that_ is _the_ definition of shit. At least my implementation is FOSS, which
            # is more than you'll ever be able to say about them.)

            # All this crapload of code below might not be needed at all, 'cause the protocol uses separate frame seq
            # numbers for IN and OUT frames and this situation is not highly likely, as Theresa May could argue.
            # After a handshake, a kettle sends us 10 or so CmdMessages, and then either we continuously ping it every
            # 3 seconds, or kettle pings us. This in any case widens the gap between inseq and outseq.

            # But! the seqno is only 1 byte in size and once it reaches 0xff, it circles back to zero. And that (plus,
            # perhaps, some bad luck) gives a chance for a collision.

            if handler.phase == MessagePhase.DONE or isinstance(handler.message, HandshakeMessage):
                # no more attempts left, returning error back to user
                # as to handshake, it cannot fail.
                callback_value = False

            # else:
            #     # try resending the message
            #     handler.phase_reset()
            #     max_seq = self.outseq
            #     wait_remap = {}
            #     for m in self.outgoing_queue:
            #         if m.seq in self.waiting_for_response:
            #             wait_remap[m.seq] = (m.seq+1) % 256
            #             m.set_seq((m.seq+1) % 256)
            #             if m.seq > max_seq:
            #                 max_seq = m.seq
            #     if max_seq > self.outseq:
            #         self.outseq = max_seq % 256
            #     if wait_remap:
            #         waiting_new = {}
            #         for old_seq, new_seq in wait_remap.items():
            #             waiting_new[new_seq] = self.waiting_for_response[old_seq]
            #         self.waiting_for_response = waiting_new

        # NOTE(review): in the whitespace-collapsed source it is not possible
        # to tell whether this check belongs to the `else` branch above or to
        # the function level. It is placed at function level because the
        # comment says the message must be acked "anyway" -- confirm.
        if isinstance(incoming_message, (PingMessage, CmdIncomingMessage)):
            # handle incoming message as usual, as we need to ack/nak it anyway
            self._handle_incoming_cmd(incoming_message)

        if callback_value is not None:
            handler.call(callback_value,
                         error_message=f'{lpfx} error while calling callback for msg id={handler.message.id} seq={seq}')
            del self.response_handlers[seq]

    def _handle_incoming_cmd(self, incoming_message: Message):
        """Offer an incoming command to the listeners and enqueue the reply
        (the first listener's ack/nak, or a default ACK)."""
        if isinstance(incoming_message, (AckMessage, NakMessage)):
            self._logger.debug(f'_handle_incoming_cmd({incoming_message.id}, seq={incoming_message.seq}): it\'s {incoming_message.__class__.__name__}, ignoring')
            return

        replied = False
        with self._iml_lock:
            for f in self.inc_listeners:
                retval = safe_callback_call(f.incoming_message, incoming_message,
                                            logger=self._logger,
                                            error_message=f'_handle_incoming_cmd({incoming_message.id}, seq={incoming_message.seq}): error while calling message listener')
                if not isinstance(retval, Message):
                    continue
                if not isinstance(retval, (AckMessage, NakMessage)):
                    raise RuntimeError('are you sure your response is correct? only ack/nak are allowed')

                # the reply must carry the seq of the message it answers
                retval.seq = incoming_message.seq
                self.enqueue_message(WrappedMessage(retval), prepend=True)
                replied = True
                break

        if not replied:
            self.enqueue_message(WrappedMessage(AckMessage(incoming_message.seq)), prepend=True)

    def enqueue_message(self, wrapped: WrappedMessage, prepend=False):
        """Add ``wrapped`` to the end of the outgoing queue, or to its front if ``prepend``."""
        self._logger.debug(f'enqueue_message: {wrapped.message}')
        if not prepend:
            self.outgoing_queue.append(wrapped)
        else:
            self.outgoing_queue.insert(0, wrapped)

    def _set_response_handler(self, wm: WrappedMessage, old_seq=None):
        """Register ``wm`` as the response handler for its seq, dropping the
        entry under ``old_seq`` and cancelling any handler already on that seq."""
        if old_seq in self.response_handlers:
            del self.response_handlers[old_seq]

        seq = wm.seq
        assert seq is not None, 'seq is not set'

        if seq in self.response_handlers:
            self._logger.warning(f'_set_response_handler(seq={seq}): handler is already set, cancelling it')
            self.response_handlers[seq].call(False,
                                             error_message=f'_set_response_handler({seq}): error while calling old callback')
        self.response_handlers[seq] = wm

    def _incr_outseq(self) -> None:
        # seq is one byte on the wire, so it wraps at 256
        self.outseq = (self.outseq + 1) % 256

    def _reset_outseq(self):
        self.outseq = 0
        self._logger.debug(f'_reset_outseq: set 0')


MessageResponse = Union[Message, bool]