diff options
Diffstat (limited to 'include/py/syncleo/protocol.py')
-rw-r--r-- | include/py/syncleo/protocol.py | 1169 |
1 files changed, 1169 insertions, 0 deletions
diff --git a/include/py/syncleo/protocol.py b/include/py/syncleo/protocol.py new file mode 100644 index 0000000..36a1a8f --- /dev/null +++ b/include/py/syncleo/protocol.py @@ -0,0 +1,1169 @@ +# Polaris PWK 1725CGLD "smart" kettle python library +# -------------------------------------------------- +# Copyright (C) Evgeny Zinoviev, 2022 +# License: BSD-3c + +from __future__ import annotations + +import logging +import socket +import random +import struct +import threading +import time + +from abc import abstractmethod, ABC +from enum import Enum, auto +from typing import Union, Optional, Dict, Tuple, List +from ipaddress import IPv4Address, IPv6Address + +import cryptography.hazmat.primitives._serialization as srlz + +from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey +from cryptography.hazmat.primitives import ciphers, padding, hashes +from cryptography.hazmat.primitives.ciphers import algorithms, modes + +ReprDict = Dict[str, Union[str, int, float, bool]] +_logger = logging.getLogger(__name__) + +PING_FREQUENCY = 3 +RESEND_ATTEMPTS = 5 +ERROR_TIMEOUT = 15 +MESSAGE_QUEUE_REMOVE_DELAY = 13 # after what time to delete (and pass False to handlers, if needed) messages with phase=DONE from queue +DISCONNECT_TIMEOUT = 15 + + +def safe_callback_call(f: callable, + *args, + logger: logging.Logger = None, + error_message: str = None): + try: + return f(*args) + except Exception as exc: + logger.error(f'{error_message}, see exception below:') + logger.exception(exc) + return None + + +# drop-in replacement for java.lang.System.arraycopy +# TODO: rewrite +def arraycopy(src, src_pos, dest, dest_pos, length): + for i in range(length): + dest[i + dest_pos] = src[i + src_pos] + + +# "convert" unsigned byte to signed +def u8_to_s8(b: int) -> int: + return struct.unpack('b', bytes([b]))[0] + + +class PowerType(Enum): + OFF = 0 # turn off + ON = 1 # turn on, set target temperature to 100 + CUSTOM = 3 # turn on, allows custom target temperature + # MYSTERY_MODE = 2 # don't know what 2 means, needs testing + # update: if I set it to '2', it just resets to '0' + + +# low-level protocol structures +# ----------------------------- + +class FrameType(Enum): + ACK = 0 + CMD = 1 + AUX = 2 + NAK = 3 + + +class FrameHead: + seq: Optional[int] # u8 + type: FrameType # u8 + length: int # u16. This is the length of FrameItem's payload + + @staticmethod + def from_bytes(buf: bytes) -> FrameHead: + seq, ft, length = struct.unpack('<BBH', buf) + return FrameHead(seq, FrameType(ft), length) + + def __init__(self, + seq: Optional[int], + frame_type: FrameType, + length: Optional[int] = None): + self.seq = seq + self.type = frame_type + self.length = length or 0 + + def pack(self) -> bytes: + assert self.length != 0, "FrameHead.length has not been set" + assert self.seq is not None, "FrameHead.seq has not been set" + return struct.pack('<BBH', self.seq, self.type.value, self.length) + + +class FrameItem: + head: FrameHead + payload: bytes + + def __init__(self, head: FrameHead, payload: Optional[bytes] = None): + self.head = head + self.payload = payload + + def setpayload(self, payload: Union[bytes, bytearray]): + if isinstance(payload, bytearray): + payload = bytes(payload) + self.payload = payload + self.head.length = len(payload) + + def pack(self) -> bytes: + ba = bytearray(self.head.pack()) + ba.extend(self.payload) + return bytes(ba) + + +# high-level wrappers around FrameItem +# ------------------------------------ + +class MessagePhase(Enum): + WAITING = 0 + SENT = 1 + DONE = 2 + + +class Message: + frame: Optional[FrameItem] + id: int + + _global_id = 0 + + def __init__(self): + self.frame = None + + # global internal message id, only useful for debugging purposes + self.id = self.next_id() + + def __repr__(self): + return f'<{self.__class__.__name__} id={self.id} seq={self.frame.head.seq}>' + + @staticmethod + def next_id(): + _id = Message._global_id + Message._global_id = (Message._global_id + 1) % 100000 + return _id + + @staticmethod + def from_encrypted(buf: bytes, inkey: bytes, outkey: bytes) -> Message: + _logger.debug(f'Message:from_encrypted: buf={buf.hex()}') + + assert len(buf) >= 4, 'invalid size' + head = FrameHead.from_bytes(buf[:4]) + + assert len(buf) == head.length + 4, f'invalid buf size ({len(buf)} != {head.length})' + payload = buf[4:] + b = head.seq + + j = b & 0xF + k = b >> 4 & 0xF + + key = bytearray(len(inkey)) + arraycopy(inkey, j, key, 0, len(inkey) - j) + arraycopy(inkey, 0, key, len(inkey) - j, j) + + iv = bytearray(len(outkey)) + arraycopy(outkey, k, iv, 0, len(outkey) - k) + arraycopy(outkey, 0, iv, len(outkey) - k, k) + + cipher = ciphers.Cipher(algorithms.AES(key), modes.CBC(iv)) + decryptor = cipher.decryptor() + decrypted_data = decryptor.update(payload) + decryptor.finalize() + + unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder() + decrypted_data = unpadder.update(decrypted_data) + decrypted_data += unpadder.finalize() + + assert len(decrypted_data) != 0, 'decrypted data is null' + assert head.seq == decrypted_data[0], f'decrypted seq mismatch {head.seq} != {decrypted_data[0]}' + + # _logger.debug('Message.from_encrypted: plaintext: '+decrypted_data.hex()) + + if head.type == FrameType.ACK: + return AckMessage(head.seq) + + elif head.type == FrameType.NAK: + return NakMessage(head.seq) + + elif head.type == FrameType.AUX: + # TODO implement AUX + raise NotImplementedError('FrameType AUX is not yet implemented') + + elif head.type == FrameType.CMD: + type = decrypted_data[1] + data = decrypted_data[2:] + + cl = UnknownMessage + + subclasses = [cl for cl in CmdIncomingMessage.__subclasses__() if cl is not SimpleBooleanMessage] + subclasses.extend(SimpleBooleanMessage.__subclasses__()) + + for _cl in subclasses: + # `UnknownMessage` is a special class that holds a packed command that we don't recognize. + # It will be used anyway if we don't find a match, so skip it here + if _cl == UnknownMessage: + continue + + if _cl.TYPE == type: + cl = _cl + break + + m = cl.from_packed_data(data, seq=head.seq) + if isinstance(m, UnknownMessage): + m.set_type(type) + return m + + else: + raise NotImplementedError(f'Unexpected frame type: {head.type}') + + def pack_data(self) -> bytes: + return b'' + + @property + def seq(self) -> Union[int, None]: + try: + return self.frame.head.seq + except: + return None + + @seq.setter + def seq(self, seq: int): + self.frame.head.seq = seq + + def encrypt(self, outkey: bytes, inkey: bytes, token: bytes, pubkey: bytes): + assert self.frame is not None + + data = self._get_data_to_encrypt() + assert data is not None + + b = self.frame.head.seq + i = b & 0xf + j = b >> 4 & 0xf + + outkey = bytearray(outkey) + + l = len(outkey) + key = bytearray(l) + + arraycopy(outkey, i, key, 0, l-i) + arraycopy(outkey, 0, key, l-i, i) + + inkey = bytearray(inkey) + + l = len(inkey) + iv = bytearray(l) + + arraycopy(inkey, j, iv, 0, l-j) + arraycopy(inkey, 0, iv, l-j, j) + + cipher = ciphers.Cipher(algorithms.AES(key), modes.CBC(iv)) + encryptor = cipher.encryptor() + + newdata = bytearray(len(data)+1) + newdata[0] = b + + arraycopy(data, 0, newdata, 1, len(data)) + + newdata = bytes(newdata) + _logger.debug('frame payload to be encrypted: ' + newdata.hex()) + + padder = padding.PKCS7(algorithms.AES.block_size).padder() + ciphertext = bytearray() + ciphertext.extend(encryptor.update(padder.update(newdata) + padder.finalize())) + ciphertext.extend(encryptor.finalize()) + + self.frame.setpayload(ciphertext) + + def _get_data_to_encrypt(self) -> bytes: + return self.pack_data() + + +class AckMessage(Message, ABC): + def __init__(self, seq: Optional[int] = None): + super().__init__() + self.frame = FrameItem(FrameHead(seq, FrameType.ACK, None)) + + +class NakMessage(Message, ABC): + def __init__(self, seq: Optional[int] = None): + super().__init__() + self.frame = FrameItem(FrameHead(seq, FrameType.NAK, None)) + + +class CmdMessage(Message): + type: Optional[int] + data: bytes + + TYPE = None + + def _get_data_to_encrypt(self) -> bytes: + buf = bytearray() + buf.append(self.get_type()) + buf.extend(self.pack_data()) + return bytes(buf) + + def __init__(self, seq: Optional[int] = None): + super().__init__() + self.frame = FrameItem(FrameHead(seq, FrameType.CMD)) + self.data = b'' + + def _repr_fields(self) -> ReprDict: + return { + 'cmd': self.get_type() + } + + def __repr__(self): + params = [ + __name__+'.'+self.__class__.__name__, + f'id={self.id}', + f'seq={self.seq}' + ] + fields = self._repr_fields() + if fields: + for k, v in fields.items(): + params.append(f'{k}={v}') + elif self.data: + params.append(f'data={self.data.hex()}') + return '<'+' '.join(params)+'>' + + def get_type(self) -> int: + return self.__class__.TYPE + + +class CmdIncomingMessage(CmdMessage): + @staticmethod + @abstractmethod + def from_packed_data(cls, data: bytes, seq: Optional[int] = None): + pass + + @abstractmethod + def _repr_fields(self) -> ReprDict: + pass + + +class CmdOutgoingMessage(CmdMessage): + @abstractmethod + def pack_data(self) -> bytes: + return b'' + + +class ModeMessage(CmdOutgoingMessage, CmdIncomingMessage): + TYPE = 1 + + pt: PowerType + + def __init__(self, power_type: PowerType, seq: Optional[int] = None): + super().__init__(seq) + self.pt = power_type + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> ModeMessage: + assert len(data) == 1, 'data size expected to be 1' + mode, = struct.unpack('B', data) + return ModeMessage(PowerType(mode), seq=seq) + + def pack_data(self) -> bytes: + return self.pt.value.to_bytes(1, byteorder='little') + + def _repr_fields(self) -> ReprDict: + return {'mode': self.pt.name} + + +class TargetTemperatureMessage(CmdOutgoingMessage, CmdIncomingMessage): + temperature: int + + TYPE = 2 + + def __init__(self, temp: int, seq: Optional[int] = None): + super().__init__(seq) + self.temperature = temp + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> TargetTemperatureMessage: + assert len(data) == 2, 'data size expected to be 2' + nat, frac = struct.unpack('BB', data) + temp = int(nat + (frac / 100)) + return TargetTemperatureMessage(temp, seq=seq) + + def pack_data(self) -> bytes: + return bytes([self.temperature, 0]) + + def _repr_fields(self) -> ReprDict: + return {'temperature': self.temperature} + + +class PingMessage(CmdIncomingMessage, CmdOutgoingMessage): + TYPE = 255 + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> PingMessage: + assert len(data) == 0, 'no data expected' + return PingMessage(seq=seq) + + def pack_data(self) -> bytes: + return b'' + + def _repr_fields(self) -> ReprDict: + return {} + + +# This is the first protocol message. Sent by a client. +# Kettle usually ACKs this, but sometimes i don't get any ACK and the very next message is HandshakeResponseMessage. +class HandshakeMessage(CmdMessage): + TYPE = 0 + + def encrypt(self, + outkey: bytes, + inkey: bytes, + token: bytes, + pubkey: bytes): + cipher = ciphers.Cipher(algorithms.AES(outkey), modes.CBC(inkey)) + encryptor = cipher.encryptor() + + ciphertext = bytearray() + ciphertext.extend(encryptor.update(token)) + ciphertext.extend(encryptor.finalize()) + + pld = bytearray() + pld.append(0) + pld.extend(pubkey) + pld.extend(ciphertext) + + self.frame.setpayload(pld) + + +# Kettle either sends this right after the handshake, of first it ACKs the handshake then sends this. +class HandshakeResponseMessage(CmdIncomingMessage): + TYPE = 0 + + protocol: int + fw_major: int + fw_minor: int + mode: int + token: bytes + + def __init__(self, + protocol: int, + fw_major: int, + fw_minor: int, + mode: int, + token: bytes, + seq: Optional[int] = None): + super().__init__(seq) + self.protocol = protocol + self.fw_major = fw_major + self.fw_minor = fw_minor + self.mode = mode + self.token = token + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> HandshakeResponseMessage: + protocol, fw_major, fw_minor, mode = struct.unpack('<HBBB', data[:5]) + return HandshakeResponseMessage(protocol, fw_major, fw_minor, mode, token=data[5:], seq=seq) + + def _repr_fields(self) -> ReprDict: + return { + 'protocol': self.protocol, + 'fw': f'{self.fw_major}.{self.fw_minor}', + 'mode': self.mode, + 'token': self.token.hex() + } + + +# Apparently, some hardware info. +# On the other hand, if you look at com.syncleiot.iottransport.commands.CmdHardware, its mqtt topic says "mcu_firmware". +# My device returns 1.1.1. The kettle uses on ESP8266 ESP-12F MCU under the hood (or, more precisely, under a piece of +# cheap plastic), so maybe 1.1.1 is some MCU ROM version. +class DeviceHardwareMessage(CmdIncomingMessage): + TYPE = 143 # -113 + + hw: List[int] + + def __init__(self, hw: List[int], seq: Optional[int] = None): + super().__init__(seq) + self.hw = hw + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> DeviceHardwareMessage: + assert len(data) == 3, 'invalid data size, expected 3' + hw = list(struct.unpack('<BBB', data)) + return DeviceHardwareMessage(hw, seq=seq) + + def _repr_fields(self) -> ReprDict: + return {'device_hardware': '.'.join(map(str, self.hw))} + + +# This message is sent by kettle right after the HandshakeMessageResponse. +# The diagnostic data is supposed to be sent to vendor, which we, obviously, not going to do. +# So just ACK and skip it. +class DeviceDiagnosticMessage(CmdIncomingMessage): + TYPE = 145 # -111 + + diag_data: bytes + + def __init__(self, diag_data: bytes, seq: Optional[int] = None): + super().__init__(seq) + self.diag_data = diag_data + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> DeviceDiagnosticMessage: + return DeviceDiagnosticMessage(diag_data=data, seq=seq) + + def _repr_fields(self) -> ReprDict: + return {'diag_data': self.diag_data.hex()} + + +class SimpleBooleanMessage(ABC, CmdIncomingMessage): + value: bool + + def __init__(self, value: bool, seq: Optional[int] = None): + super().__init__(seq) + self.value = value + + @classmethod + def from_packed_data(cls, data: bytes, seq: Optional[int] = None): + assert len(data) == 1, 'invalid data size, expected 1' + enabled, = struct.unpack('<B', data) + return cls(value=enabled == 1, seq=seq) + + @abstractmethod + def _repr_fields(self) -> ReprDict: + pass + + +class AccessControlMessage(SimpleBooleanMessage): + TYPE = 133 # -123 + + def _repr_fields(self) -> ReprDict: + return {'acl_enabled': self.value} + + +class ErrorMessage(SimpleBooleanMessage): + TYPE = 7 + + def _repr_fields(self) -> ReprDict: + return {'error': self.value} + + +class ChildLockMessage(SimpleBooleanMessage): + TYPE = 30 + + def _repr_fields(self) -> ReprDict: + return {'child_lock': self.value} + + +class VolumeMessage(SimpleBooleanMessage): + TYPE = 9 + + def _repr_fields(self) -> ReprDict: + return {'volume': self.value} + + +class BacklightMessage(SimpleBooleanMessage): + TYPE = 28 + + def _repr_fields(self) -> ReprDict: + return {'backlight': self.value} + + +class CurrentTemperatureMessage(CmdIncomingMessage): + TYPE = 20 + + current_temperature: int + + def __init__(self, temp: int, seq: Optional[int] = None): + super().__init__(seq) + self.current_temperature = temp + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> CurrentTemperatureMessage: + assert len(data) == 2, 'data size expected to be 2' + nat, frac = struct.unpack('BB', data) + temp = int(nat + (frac / 100)) + return CurrentTemperatureMessage(temp, seq=seq) + + def pack_data(self) -> bytes: + return bytes([self.current_temperature, 0]) + + def _repr_fields(self) -> ReprDict: + return {'current_temperature': self.current_temperature} + + +class UnknownMessage(CmdIncomingMessage): + type: Optional[int] + data: bytes + + def __init__(self, data: bytes, **kwargs): + super().__init__(**kwargs) + self.type = None + self.data = data + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> UnknownMessage: + return UnknownMessage(data, seq=seq) + + def set_type(self, type: int): + self.type = type + + def get_type(self) -> int: + return self.type + + def _repr_fields(self) -> ReprDict: + return { + 'type': self.type, + 'data': self.data.hex() + } + + +class WrappedMessage: + _message: Message + _handler: Optional[callable] + _validator: Optional[callable] + _logger: Optional[logging.Logger] + _phase: MessagePhase + _phase_update_time: float + + def __init__(self, + message: Message, + handler: Optional[callable] = None, + validator: Optional[callable] = None, + ack=False): + self._message = message + self._handler = handler + self._validator = validator + self._logger = None + self._phase = MessagePhase.WAITING + self._phase_update_time = 0 + if not validator and ack: + self._validator = lambda m: isinstance(m, AckMessage) + + def setlogger(self, logger: logging.Logger): + self._logger = logger + + def validate(self, message: Message): + if not self._validator: + return True + return self._validator(message) + + def call(self, *args, error_message: str = None) -> None: + if not self._handler: + return + try: + self._handler(*args) + except Exception as exc: + logger = self._logger or logging.getLogger(self.__class__.__name__) + logger.error(f'{error_message}, see exception below:') + logger.exception(exc) + + @property + def phase(self) -> MessagePhase: + return self._phase + + @phase.setter + def phase(self, phase: MessagePhase): + self._phase = phase + self._phase_update_time = 0 if phase == MessagePhase.WAITING else time.time() + + @property + def phase_update_time(self) -> float: + return self._phase_update_time + + @property + def message(self) -> Message: + return self._message + + @property + def id(self) -> int: + return self._message.id + + @property + def seq(self) -> int: + return self._message.seq + + @seq.setter + def seq(self, seq: int): + self._message.seq = seq + + def __repr__(self): + return f'<{__name__}.{self.__class__.__name__} message={self._message.__repr__()}>' + + +# Connection stuff +# Well, strictly speaking, as it's UDP, there's no connection, but who cares. +# --------------------------------------------------------------------------- + +class IncomingMessageListener: + @abstractmethod + def incoming_message(self, message: Message) -> Optional[Message]: + pass + + +class ConnectionStatus(Enum): + NOT_CONNECTED = auto() + CONNECTING = auto() + CONNECTED = auto() + RECONNECTING = auto() + DISCONNECTED = auto() + + +class ConnectionStatusListener: + @abstractmethod + def connection_status_updated(self, status: ConnectionStatus): + pass + + +class UDPConnection(threading.Thread, ConnectionStatusListener): + inseq: int + outseq: int + source_port: int + device_addr: str + device_port: int + device_token: bytes + device_pubkey: bytes + interrupted: bool + response_handlers: Dict[int, WrappedMessage] + outgoing_queue: List[WrappedMessage] + pubkey: Optional[bytes] + encinkey: Optional[bytes] + encoutkey: Optional[bytes] + inc_listeners: List[IncomingMessageListener] + conn_listeners: List[ConnectionStatusListener] + outgoing_time: float + outgoing_time_1st: float + incoming_time: float + status: ConnectionStatus + reconnect_tries: int + read_timeout: int + + _addr_lock: threading.Lock + _iml_lock: threading.Lock + _csl_lock: threading.Lock + _st_lock: threading.Lock + + def __init__(self, + addr: Union[IPv4Address, IPv6Address], + port: int, + device_pubkey: bytes, + device_token: bytes, + read_timeout: int = 1): + super().__init__() + self._logger = logging.getLogger(f'{__name__}.{self.__class__.__name__} <{hex(id(self))}>') + self.setName(self.__class__.__name__) + + self.inseq = 0 + self.outseq = 0 + self.source_port = random.randint(1024, 65535) + self.device_addr = str(addr) + self.device_port = port + self.device_token = device_token + self.device_pubkey = device_pubkey + self.outgoing_queue = [] + self.response_handlers = {} + self.interrupted = False + self.outgoing_time = 0 + self.outgoing_time_1st = 0 + self.incoming_time = 0 + self.inc_listeners = [] + self.conn_listeners = [self] + self.status = ConnectionStatus.NOT_CONNECTED + self.reconnect_tries = 0 + self.read_timeout = read_timeout + + self._iml_lock = threading.Lock() + self._csl_lock = threading.Lock() + self._addr_lock = threading.Lock() + self._st_lock = threading.Lock() + + self.pubkey = None + self.encinkey = None + self.encoutkey = None + + def connection_status_updated(self, status: ConnectionStatus): + # self._logger.info(f'connection_status_updated: status = {status}') + with self._st_lock: + # self._logger.debug(f'connection_status_updated: lock acquired') + self.status = status + if status == ConnectionStatus.RECONNECTING: + self.reconnect_tries += 1 + if status in (ConnectionStatus.CONNECTED, ConnectionStatus.NOT_CONNECTED, ConnectionStatus.DISCONNECTED): + self.reconnect_tries = 0 + + def _cleanup(self): + # erase outgoing queue + for wm in self.outgoing_queue: + wm.call(False, + error_message=f'_cleanup: exception while calling cb(False) on message {wm.message}') + self.outgoing_queue = [] + self.response_handlers = {} + + # reset timestamps + self.incoming_time = 0 + self.outgoing_time = 0 + self.outgoing_time_1st = 0 + + self._logger.debug('_cleanup: done') + + def set_address(self, addr: Union[IPv4Address, IPv6Address], port: int): + with self._addr_lock: + if self.device_addr != str(addr) or self.device_port != port: + self.device_addr = str(addr) + self.device_port = port + self._logger.info(f'updated device network address: {self.device_addr}:{self.device_port}') + + def set_device_pubkey(self, pubkey: bytes): + if self.device_pubkey.hex() != pubkey.hex(): + self._logger.info(f'device pubkey has changed (old={self.device_pubkey.hex()}, new={pubkey.hex()})') + self.device_pubkey = pubkey + self._notify_cs(ConnectionStatus.RECONNECTING) + + def get_address(self) -> Tuple[str, int]: + with self._addr_lock: + return self.device_addr, self.device_port + + def add_incoming_message_listener(self, listener: IncomingMessageListener): + with self._iml_lock: + if listener not in self.inc_listeners: + self.inc_listeners.append(listener) + + def add_connection_status_listener(self, listener: ConnectionStatusListener): + with self._csl_lock: + if listener not in self.conn_listeners: + self.conn_listeners.append(listener) + + def _notify_cs(self, status: ConnectionStatus): + # self._logger.debug(f'_notify_cs: status={status}') + with self._csl_lock: + for obj in self.conn_listeners: + # self._logger.debug(f'_notify_cs: notifying {obj}') + obj.connection_status_updated(status) + + def _prepare_keys(self): + # generate key pair + privkey = X25519PrivateKey.generate() + + self.pubkey = bytes(reversed(privkey.public_key().public_bytes(encoding=srlz.Encoding.Raw, + format=srlz.PublicFormat.Raw))) + + # generate shared key + device_pubkey = X25519PublicKey.from_public_bytes( + bytes(reversed(self.device_pubkey)) + ) + shared_key = bytes(reversed( + privkey.exchange(device_pubkey) + )) + + # in/out encryption keys + digest = hashes.Hash(hashes.SHA256()) + digest.update(shared_key) + + shared_sha256 = digest.finalize() + + self.encinkey = shared_sha256[:16] + self.encoutkey = shared_sha256[16:] + + self._logger.info('encryption keys have been created') + + def _handshake_callback(self, r: MessageResponse): + # if got error for our HandshakeMessage, reset everything and try again + if r is False: + # self._logger.debug('_handshake_callback: set status=RECONNETING') + self._notify_cs(ConnectionStatus.RECONNECTING) + else: + # self._logger.debug('_handshake_callback: set status=CONNECTED') + self._notify_cs(ConnectionStatus.CONNECTED) + + def run(self): + self._logger.info('starting server loop') + + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.bind(('0.0.0.0', self.source_port)) + sock.settimeout(self.read_timeout) + + while not self.interrupted: + with self._st_lock: + status = self.status + + if status in (ConnectionStatus.DISCONNECTED, ConnectionStatus.RECONNECTING): + self._cleanup() + if status == ConnectionStatus.DISCONNECTED: + break + + # no activity for some time means connection is broken + fail = False + fail_path = 0 + if self.incoming_time > 0 and time.time() - self.incoming_time >= DISCONNECT_TIMEOUT: + fail = True + fail_path = 1 + elif self.outgoing_time_1st > 0 and self.incoming_time == 0 and time.time() - self.outgoing_time_1st >= DISCONNECT_TIMEOUT: + fail = True + fail_path = 2 + + if fail: + self._logger.debug(f'run: setting status=RECONNECTING because of long inactivity, fail_path={fail_path}') + self._notify_cs(ConnectionStatus.RECONNECTING) + + # establishing a connection + if status in (ConnectionStatus.RECONNECTING, ConnectionStatus.NOT_CONNECTED): + if status == ConnectionStatus.RECONNECTING and self.reconnect_tries >= 3: + self._notify_cs(ConnectionStatus.DISCONNECTED) + continue + + self._reset_outseq() + self._prepare_keys() + + # shake the imaginary kettle's hand + wrapped = WrappedMessage(HandshakeMessage(), + handler=self._handshake_callback, + validator=lambda m: isinstance(m, (AckMessage, HandshakeResponseMessage))) + self.enqueue_message(wrapped, prepend=True) + self._notify_cs(ConnectionStatus.CONNECTING) + + # pick next (wrapped) message to send + wm = self._get_next_message() # wm means "wrapped message" + if wm: + one_shot = isinstance(wm.message, (AckMessage, NakMessage)) + + if not isinstance(wm.message, (AckMessage, NakMessage)): + old_seq = wm.seq + wm.seq = self.outseq + self._set_response_handler(wm, old_seq=old_seq) + elif wm.seq is None: + # ack/nak is a response to some incoming message (and it must have the same seqno that incoming + # message had) + raise RuntimeError(f'run: seq must be set for {wm.__class__.__name__}') + + self._logger.debug(f'run: sending message: {wm.message}, one_shot={one_shot}, phase={wm.phase}') + encrypted = False + try: + wm.message.encrypt(outkey=self.encoutkey, inkey=self.encinkey, + token=self.device_token, pubkey=self.pubkey) + encrypted = True + except ValueError as exc: + # handle "ValueError: Invalid padding bytes." + self._logger.error('run: failed to encrypt the message.') + self._logger.exception(exc) + + if encrypted: + buf = wm.message.frame.pack() + # self._logger.debug(f'run: raw data to be sent: {buf.hex()}') + + # sending the first time + if wm.phase == MessagePhase.WAITING: + sock.sendto(buf, self.get_address()) + # resending + elif wm.phase == MessagePhase.SENT: + left = RESEND_ATTEMPTS + while left > 0: + sock.sendto(buf, self.get_address()) + left -= 1 + if left > 0: + time.sleep(0.05) + + if one_shot or wm.phase == MessagePhase.SENT: + wm.phase = MessagePhase.DONE + else: + wm.phase = MessagePhase.SENT + + now = time.time() + self.outgoing_time = now + if not self.outgoing_time_1st: + self.outgoing_time_1st = now + + # receiving data + try: + data = sock.recv(4096) + self._handle_incoming(data) + except (TimeoutError, socket.timeout): + pass + + self._logger.info('bye...') + + def _get_next_message(self) -> Optional[WrappedMessage]: + message = None + lpfx = '_get_next_message:' + remove_list = [] + for wm in self.outgoing_queue: + if wm.phase == MessagePhase.DONE: + if isinstance(wm.message, (AckMessage, NakMessage, PingMessage)) or time.time() - wm.phase_update_time >= MESSAGE_QUEUE_REMOVE_DELAY: + remove_list.append(wm) + continue + message = wm + break + + for wm in remove_list: + self._logger.debug(f'{lpfx} rm path: removing id={wm.id} seq={wm.seq}') + + # clear message handler + if wm.seq in self.response_handlers: + self.response_handlers[wm.seq].call( + False, error_message=f'{lpfx} rm path: error while calling callback for seq={wm.seq}') + del self.response_handlers[wm.seq] + + # remove from queue + try: + self.outgoing_queue.remove(wm) + except ValueError as exc: + self._logger.error(f'{lpfx} rm path: removing from outgoing_queue raised an exception: {str(exc)}') + + # ping pong + if not message and self.outgoing_time_1st != 0 and self.status == ConnectionStatus.CONNECTED: + now = time.time() + out_delta = now - self.outgoing_time + in_delta = now - self.incoming_time + if max(out_delta, in_delta) > PING_FREQUENCY: + self._logger.debug(f'{lpfx} no activity: in for {in_delta:.2f}s, out for {out_delta:.2f}s, time to ping the damn thing') + message = WrappedMessage(PingMessage(), ack=True) + # add it to outgoing_queue in order to be aggressively resent in future (if needed) + self.outgoing_queue.insert(0, message) + + return message + + def _handle_incoming(self, buf: bytes): + try: + incoming_message = Message.from_encrypted(buf, inkey=self.encinkey, outkey=self.encoutkey) + except ValueError as exc: + # handle "ValueError: Invalid padding bytes." + self._logger.error('_handle_incoming: failed to decrypt incoming frame:') + self._logger.exception(exc) + return + + self.incoming_time = time.time() + seq = incoming_message.seq + + lpfx = f'handle_incoming({incoming_message.id}):' + self._logger.debug(f'{lpfx} received: {incoming_message}') + + if isinstance(incoming_message, (AckMessage, NakMessage)): + seq_max = self.outseq + seq_name = 'outseq' + else: + seq_max = self.inseq + seq_name = 'inseq' + self.inseq = seq + + if seq < seq_max < 0xfd: + self._logger.debug(f'{lpfx} dropping: seq={seq}, {seq_name}={seq_max}') + return + + if seq not in self.response_handlers: + self._handle_incoming_cmd(incoming_message) + return + + callback_value = None # None means don't call a callback + handler = self.response_handlers[seq] + + if handler.validate(incoming_message): + self._logger.debug(f'{lpfx} response OK') + handler.phase = MessagePhase.DONE + callback_value = incoming_message + self._incr_outseq() + else: + self._logger.warning(f'{lpfx} response is INVALID') + + # It seems that we've received an incoming CmdMessage or PingMessage with the same seqno that our outgoing + # message had. Bad, but what can I say, this is quick-and-dirty made UDP based protocol and this sort of + # shit just happens. + + # (To be fair, maybe my implementation is not perfect either. But hey, what did you expect from a + # reverse-engineered re-implementation of custom UDP-based protocol that some noname vendor uses for their + # cheap IoT devices? I think _that_ is _the_ definition of shit. At least my implementation is FOSS, which + # is more than you'll ever be able to say about them.) + + # All this crapload of code below might not be needed at all, 'cause the protocol uses separate frame seq + # numbers for IN and OUT frames and this situation is not highly likely, as Theresa May could argue. + # After a handshake, a kettle sends us 10 or so CmdMessages, and then either we continuously ping it every + # 3 seconds, or kettle pings us. This in any case widens the gap between inseq and outseq. + + # But! the seqno is only 1 byte in size and once it reaches 0xff, it circles back to zero. And that (plus, + # perhaps, some bad luck) gives a chance for a collision. + + if handler.phase == MessagePhase.DONE or isinstance(handler.message, HandshakeMessage): + # no more attempts left, returning error back to user + # as to handshake, it cannot fail. + callback_value = False + + # else: + # # try resending the message + # handler.phase_reset() + # max_seq = self.outseq + # wait_remap = {} + # for m in self.outgoing_queue: + # if m.seq in self.waiting_for_response: + # wait_remap[m.seq] = (m.seq+1) % 256 + # m.set_seq((m.seq+1) % 256) + # if m.seq > max_seq: + # max_seq = m.seq + # if max_seq > self.outseq: + # self.outseq = max_seq % 256 + # if wait_remap: + # waiting_new = {} + # for old_seq, new_seq in wait_remap.items(): + # waiting_new[new_seq] = self.waiting_for_response[old_seq] + # self.waiting_for_response = waiting_new + + if isinstance(incoming_message, (PingMessage, CmdIncomingMessage)): + # handle incoming message as usual, as we need to ack/nak it anyway + self._handle_incoming_cmd(incoming_message) + + if callback_value is not None: + handler.call(callback_value, + error_message=f'{lpfx} error while calling callback for msg id={handler.message.id} seq={seq}') + del self.response_handlers[seq] + + def _handle_incoming_cmd(self, incoming_message: Message): + if isinstance(incoming_message, (AckMessage, NakMessage)): + self._logger.debug(f'_handle_incoming_cmd({incoming_message.id}, seq={incoming_message.seq}): it\'s {incoming_message.__class__.__name__}, ignoring') + return + + replied = False + with self._iml_lock: + for f in self.inc_listeners: + retval = safe_callback_call(f.incoming_message, incoming_message, + logger=self._logger, + error_message=f'_handle_incoming_cmd({incoming_message.id}, seq={incoming_message.seq}): error while calling message listener') + if isinstance(retval, Message): + if isinstance(retval, (AckMessage, NakMessage)): + retval.seq = incoming_message.seq + self.enqueue_message(WrappedMessage(retval), prepend=True) + replied = True + break + else: + raise RuntimeError('are you sure your response is correct? only ack/nak are allowed') + + if not replied: + self.enqueue_message(WrappedMessage(AckMessage(incoming_message.seq)), prepend=True) + + def enqueue_message(self, wrapped: WrappedMessage, prepend=False): + self._logger.debug(f'enqueue_message: {wrapped.message}') + if not prepend: + self.outgoing_queue.append(wrapped) + else: + self.outgoing_queue.insert(0, wrapped) + + def _set_response_handler(self, wm: WrappedMessage, old_seq=None): + if old_seq in self.response_handlers: + del self.response_handlers[old_seq] + + seq = wm.seq + assert seq is not None, 'seq is not set' + + if seq in self.response_handlers: + self._logger.debug(f'_set_response_handler(seq={seq}): handler is already set, cancelling it') + self.response_handlers[seq].call(False, + error_message=f'_set_response_handler({seq}): error while calling old callback') + self.response_handlers[seq] = wm + + def _incr_outseq(self) -> None: + self.outseq = (self.outseq + 1) % 256 + + def _reset_outseq(self): + self.outseq = 0 + self._logger.debug(f'_reset_outseq: set 0') + + +MessageResponse = Union[Message, bool] |