summaryrefslogtreecommitdiff
path: root/src/polaris
diff options
context:
space:
mode:
authorEvgeny Zinoviev <me@ch1p.io>2022-06-28 03:22:30 +0300
committerEvgeny Zinoviev <me@ch1p.io>2022-06-30 03:47:49 +0300
commit8f20c9b825cabab7a3f0f5dd2cfe000cc7f72c28 (patch)
treeb5d7446e7b2fcfd42b1e5029aeef33ecb5f9715f /src/polaris
parentee09bc98aedfc6a65a5026432b399345a30a39c8 (diff)
polaris pwk 1725cgld full support
- significant improvements, correctnesses and stability fixes in protocol implementation - correct handling of device appearances and disappearances - flawlessly functioning telegram bot that re-renders kettle's state (temperature and other) in real time
Diffstat (limited to 'src/polaris')
-rw-r--r--src/polaris/__init__.py14
-rw-r--r--src/polaris/kettle.py271
-rw-r--r--src/polaris/protocol.py1015
3 files changed, 1093 insertions, 207 deletions
diff --git a/src/polaris/__init__.py b/src/polaris/__init__.py
index aa077ce..f1a7c1d 100644
--- a/src/polaris/__init__.py
+++ b/src/polaris/__init__.py
@@ -1,4 +1,12 @@
-# SPDX-License-Identifier: BSD-3-Clause
+# Polaris PWK 1725CGLD "smart" kettle python library
+# --------------------------------------------------
+# Copyright (C) Evgeny Zinoviev, 2022
+# License: BSD-3c
-from .kettle import Kettle
-from .protocol import Message, FrameType, PowerType \ No newline at end of file
+from .kettle import Kettle, DeviceListener
+from .protocol import (
+ PowerType,
+ IncomingMessageListener,
+ ConnectionStatusListener,
+ ConnectionStatus
+) \ No newline at end of file
diff --git a/src/polaris/kettle.py b/src/polaris/kettle.py
index 37f6813..1b995fd 100644
--- a/src/polaris/kettle.py
+++ b/src/polaris/kettle.py
@@ -1,97 +1,238 @@
-# SPDX-License-Identifier: BSD-3-Clause
+# Polaris PWK 1725CGLD smart kettle python library
+# ------------------------------------------------
+# Copyright (C) Evgeny Zinoviev, 2022
+# License: BSD-3c
from __future__ import annotations
+import threading
import logging
import zeroconf
-import cryptography.hazmat.primitives._serialization
-from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
-from cryptography.hazmat.primitives import hashes
-
-from functools import partial
-from abc import ABC
-from ipaddress import ip_address
-from typing import Optional
+from abc import abstractmethod
+from ipaddress import ip_address, IPv4Address, IPv6Address
+from typing import Optional, List, Union
from .protocol import (
- Connection,
+ UDPConnection,
ModeMessage,
- HandshakeMessage,
TargetTemperatureMessage,
- Message,
- PowerType
+ PowerType,
+ ConnectionStatus,
+ ConnectionStatusListener,
+ WrappedMessage
)
-_logger = logging.getLogger(__name__)
-
-# Polaris PWK 1725CGLD IoT kettle
-class Kettle(zeroconf.ServiceListener, ABC):
- macaddr: str
- device_token: str
- sb: Optional[zeroconf.ServiceBrowser]
- found_device: Optional[zeroconf.ServiceInfo]
- conn: Optional[Connection]
+class DeviceDiscover(threading.Thread, zeroconf.ServiceListener):
+ si: Optional[zeroconf.ServiceInfo]
+ _mac: str
+ _sb: Optional[zeroconf.ServiceBrowser]
+ _zc: Optional[zeroconf.Zeroconf]
+ _listeners: List[DeviceListener]
+ _valid_addresses: List[Union[IPv4Address, IPv6Address]]
+ _only_ipv4: bool
- def __init__(self, mac: str, device_token: str):
+ def __init__(self, mac: str,
+ listener: Optional[DeviceListener] = None,
+ only_ipv4=True):
super().__init__()
- self.zeroconf = zeroconf.Zeroconf()
- self.sb = None
- self.macaddr = mac
- self.device_token = device_token
- self.found_device = None
- self.conn = None
+ self.si = None
+ self._mac = mac
+ self._zc = None
+ self._sb = None
+ self._only_ipv4 = only_ipv4
+ self._valid_addresses = []
+ self._listeners = []
+ if isinstance(listener, DeviceListener):
+ self._listeners.append(listener)
+ self._logger = logging.getLogger(f'{__name__}.{self.__class__.__name__}')
+
+ def add_listener(self, listener: DeviceListener):
+ if listener not in self._listeners:
+ self._listeners.append(listener)
+ else:
+ self._logger.warning(f'add_listener: listener {listener} already in the listeners list')
+
+ def set_info(self, info: zeroconf.ServiceInfo):
+ valid_addresses = self._get_valid_addresses(info)
+ if not valid_addresses:
+ raise ValueError('no valid addresses')
+ self._valid_addresses = valid_addresses
+ self.si = info
+ for f in self._listeners:
+ try:
+ f.device_updated()
+ except Exception as exc:
+ self._logger.error(f'set_info: error while calling device_updated on {f}')
+ self._logger.exception(exc)
- def find(self) -> zeroconf.ServiceInfo:
- self.sb = zeroconf.ServiceBrowser(self.zeroconf, "_syncleo._udp.local.", self)
- self.sb.join()
+ def add_service(self, zc: zeroconf.Zeroconf, type_: str, name: str) -> None:
+ self._add_update_service('add_service', zc, type_, name)
- return self.found_device
+ def update_service(self, zc: zeroconf.Zeroconf, type_: str, name: str) -> None:
+ self._add_update_service('update_service', zc, type_, name)
- # zeroconf.ServiceListener implementation
- def add_service(self,
- zc: zeroconf.Zeroconf,
- type_: str,
- name: str) -> None:
- if name.startswith(f'{self.macaddr}.'):
- info = zc.get_service_info(type_, name)
+ def _add_update_service(self, method: str, zc: zeroconf.Zeroconf, type_: str, name: str) -> None:
+ info = zc.get_service_info(type_, name)
+ if name.startswith(f'{self._mac}.'):
+ self._logger.info(f'{method}: type={type_} name={name}')
+ try:
+ self.set_info(info)
+ except ValueError as exc:
+ self._logger.error(f'{method}: rejected: {str(exc)}')
+ else:
+ self._logger.debug(f'{method}: mac not matched: {info}')
+
+ def remove_service(self, zc: zeroconf.Zeroconf, type_: str, name: str) -> None:
+ if name.startswith(f'{self._mac}.'):
+ self._logger.info(f'remove_service: type={type_} name={name}')
+ # TODO what to do here?!
+
+ def run(self):
+ self._logger.info('starting zeroconf service browser')
+ ip_version = zeroconf.IPVersion.V4Only if self._only_ipv4 else zeroconf.IPVersion.All
+ self._zc = zeroconf.Zeroconf(ip_version=ip_version)
+ self._sb = zeroconf.ServiceBrowser(self._zc, "_syncleo._udp.local.", self)
+ self._sb.join()
+
+ def stop(self):
+ if self._sb:
try:
- self.sb.cancel()
+ self._sb.cancel()
except RuntimeError:
pass
- self.zeroconf.close()
- self.found_device = info
+ self._sb = None
+ self._zc.close()
+ self._zc = None
+
+ def _get_valid_addresses(self, si: zeroconf.ServiceInfo) -> List[Union[IPv4Address, IPv6Address]]:
+ valid = []
+ for addr in map(ip_address, si.addresses):
+ if self._only_ipv4 and not isinstance(addr, IPv4Address):
+ continue
+ if isinstance(addr, IPv4Address) and str(addr).startswith('169.254.'):
+ continue
+ valid.append(addr)
+ return valid
- assert self.device_curve == 29, f'curve type {self.device_curve} is not implemented'
-
- def start_server(self, callback: callable):
- addresses = list(map(ip_address, self.found_device.addresses))
- self.conn = Connection(addr=addresses[0],
- port=int(self.found_device.port),
- device_pubkey=self.device_pubkey,
- device_token=bytes.fromhex(self.device_token))
+ @property
+ def pubkey(self) -> bytes:
+ return bytes.fromhex(self.si.properties[b'public'].decode())
- # shake the kettle's hand
- self._pass_message(HandshakeMessage(), callback)
- self.conn.start()
+ @property
+ def curve(self) -> int:
+ return int(self.si.properties[b'curve'].decode())
- def stop_server(self):
- self.conn.interrupted = True
+ @property
+ def addr(self) -> Union[IPv4Address, IPv6Address]:
+ return self._valid_addresses[0]
@property
- def device_pubkey(self) -> bytes:
- return bytes.fromhex(self.found_device.properties[b'public'].decode())
+ def port(self) -> int:
+ return int(self.si.port)
@property
- def device_curve(self) -> int:
- return int(self.found_device.properties[b'curve'].decode())
+ def protocol(self) -> int:
+ return int(self.si.properties[b'protocol'].decode())
+
+
+class DeviceListener:
+ @abstractmethod
+ def device_updated(self):
+ pass
+
+
+class Kettle(DeviceListener, ConnectionStatusListener):
+ mac: str
+ device: Optional[DeviceDiscover]
+ device_token: str
+ conn: Optional[UDPConnection]
+ conn_status: Optional[ConnectionStatus]
+ _logger: logging.Logger
+ _find_evt: threading.Event
+
+ def __init__(self, mac: str, device_token: str):
+ super().__init__()
+ self.mac = mac
+ self.device = None
+ self.device_token = device_token
+ self.conn = None
+ self.conn_status = None
+ self._find_evt = threading.Event()
+ self._logger = logging.getLogger(f'{__name__}.{self.__class__.__name__}')
+
+ def device_updated(self):
+ self._find_evt.set()
+ self._logger.info(f'device updated, service info: {self.device.si}')
+
+ def connection_status_updated(self, status: ConnectionStatus):
+ self.conn_status = status
+
+ def discover(self, wait=True, timeout=None, listener=None) -> Optional[zeroconf.ServiceInfo]:
+ do_start = False
+ if not self.device:
+ self.device = DeviceDiscover(self.mac, listener=self, only_ipv4=True)
+ do_start = True
+ self._logger.debug('discover: started device discovery')
+ else:
+ self._logger.warning('discover: already started')
+
+ if listener is not None:
+ self.device.add_listener(listener)
+
+ if do_start:
+ self.device.start()
+
+ if wait:
+ self._find_evt.clear()
+ try:
+ self._find_evt.wait(timeout=timeout)
+ except KeyboardInterrupt:
+ self.device.stop()
+ return None
+ return self.device.si
+
+ def start_server_if_needed(self,
+ incoming_message_listener=None,
+ connection_status_listener=None):
+ if self.conn:
+ self._logger.warning('start_server_if_needed: server is already started!')
+ self.conn.set_address(self.device.addr, self.device.port)
+ self.conn.set_device_pubkey(self.device.pubkey)
+ return
+
+ assert self.device.curve == 29, f'curve type {self.device.curve} is not implemented'
+ assert self.device.protocol == 2, f'protocol {self.device.protocol} is not supported'
+
+ self.conn = UDPConnection(addr=self.device.addr,
+ port=self.device.port,
+ device_pubkey=self.device.pubkey,
+ device_token=bytes.fromhex(self.device_token))
+ if incoming_message_listener:
+ self.conn.add_incoming_message_listener(incoming_message_listener)
+
+ self.conn.add_connection_status_listener(self)
+ if connection_status_listener:
+ self.conn.add_connection_status_listener(connection_status_listener)
+
+ self.conn.start()
+
+ def stop_all(self):
+ # when we stop server, we should also stop device discovering service
+ if self.conn:
+ self.conn.interrupted = True
+ self.conn = None
+ self.device.stop()
+ self.device = None
+
+ def is_connected(self) -> bool:
+ return self.conn is not None and self.conn_status == ConnectionStatus.CONNECTED
def set_power(self, power_type: PowerType, callback: callable):
- self._pass_message(ModeMessage(power_type), callback)
+ message = ModeMessage(power_type)
+ self.conn.enqueue_message(WrappedMessage(message, handler=callback, ack=True))
def set_target_temperature(self, temp: int, callback: callable):
- self._pass_message(TargetTemperatureMessage(temp), callback)
-
- def _pass_message(self, message: Message, callback: callable):
- self.conn.send_message(message, partial(callback, self))
+ message = TargetTemperatureMessage(temp)
+ self.conn.enqueue_message(WrappedMessage(message, handler=callback, ack=True))
diff --git a/src/polaris/protocol.py b/src/polaris/protocol.py
index cc4e36a..5d7390f 100644
--- a/src/polaris/protocol.py
+++ b/src/polaris/protocol.py
@@ -1,39 +1,61 @@
-# SPDX-License-Identifier: BSD-3-Clause
+# Polaris PWK 1725CGLD "smart" kettle python library
+# --------------------------------------------------
+# Copyright (C) Evgeny Zinoviev, 2022
+# License: BSD-3c
from __future__ import annotations
+
import logging
import socket
import random
import struct
import threading
-import queue
+import time
-from enum import Enum
-from typing import Union, Optional, Any
+from abc import abstractmethod, ABC
+from enum import Enum, auto
+from typing import Union, Optional, Dict, Tuple, List
from ipaddress import IPv4Address, IPv6Address
-import cryptography.hazmat.primitives._serialization
+import cryptography.hazmat.primitives._serialization as srlz
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
from cryptography.hazmat.primitives import ciphers, padding, hashes
from cryptography.hazmat.primitives.ciphers import algorithms, modes
-
+ReprDict = Dict[str, Union[str, int, float, bool]]
_logger = logging.getLogger(__name__)
+PING_FREQUENCY = 3
+RESEND_ATTEMPTS = 5
+READ_TIMEOUT = 1
+ERROR_TIMEOUT = 15
+MESSAGE_QUEUE_REMOVE_DELAY = 13 # after what time to delete (and pass False to handlers, if needed) messages with phase=DONE from queue
+DISCONNECT_TIMEOUT = 15
+
-# drop-in replacement for Java API
+def safe_callback_call(f: callable,
+ *args,
+ logger: logging.Logger = None,
+ error_message: str = None):
+ try:
+ return f(*args)
+ except Exception as exc:
+ logger.error(f'{error_message}, see exception below:')
+ logger.exception(exc)
+ return None
+
+
+# drop-in replacement for java.lang.System.arraycopy
# TODO: rewrite
def arraycopy(src, src_pos, dest, dest_pos, length):
for i in range(length):
dest[i + dest_pos] = src[i + src_pos]
-class FrameType(Enum):
- ACK = 0
- CMD = 1
- AUX = 2
- NAK = 3
+# "convert" unsigned byte to signed
+def u8_to_s8(b: int) -> int:
+ return struct.unpack('b', bytes([b]))[0]
class PowerType(Enum):
@@ -44,23 +66,37 @@ class PowerType(Enum):
# update: if I set it to '2', it just resets to '0'
+# low-level protocol structures
+# -----------------------------
+
+class FrameType(Enum):
+ ACK = 0
+ CMD = 1
+ AUX = 2
+ NAK = 3
+
+
class FrameHead:
- seq: int # u8
+ seq: Optional[int] # u8
type: FrameType # u8
- length: int # u16
+ length: int # u16. This is the length of FrameItem's payload
@staticmethod
def from_bytes(buf: bytes) -> FrameHead:
seq, ft, length = struct.unpack('<BBH', buf)
return FrameHead(seq, FrameType(ft), length)
- def __init__(self, seq: int, frame_type: FrameType, length: Optional[int] = None):
+ def __init__(self,
+ seq: Optional[int],
+ frame_type: FrameType,
+ length: Optional[int] = None):
self.seq = seq
self.type = frame_type
self.length = length or 0
def pack(self) -> bytes:
assert self.length != 0, "FrameHead.length has not been set"
+ assert self.seq is not None, "FrameHead.seq has not been set"
return struct.pack('<BBH', self.seq, self.type.value, self.length)
@@ -84,30 +120,47 @@ class FrameItem:
return bytes(ba)
+# high-level wrappers around FrameItem
+# ------------------------------------
+
+class MessagePhase(Enum):
+ WAITING = 0
+ SENT = 1
+ DONE = 2
+
+
class Message:
frame: Optional[FrameItem]
+ id: int
+
+ _global_id = 0
def __init__(self):
self.frame = None
+ # global internal message id, only useful for debugging purposes
+ self.id = self.next_id()
+
+ def __repr__(self):
+ return f'<{self.__class__.__name__} id={self.id} seq={self.frame.head.seq}>'
+
@staticmethod
- def from_encrypted(buf: bytes,
- inkey: bytes,
- outkey: bytes) -> Message:
- # _logger.debug('[from_encrypted] buf='+buf.hex())
- # print(f'buf len={len(buf)}')
+ def next_id():
+ _id = Message._global_id
+ Message._global_id += 1
+ return _id
+
+ @staticmethod
+ def from_encrypted(buf: bytes, inkey: bytes, outkey: bytes) -> Message:
+ _logger.debug(f'Message:from_encrypted: buf={buf.hex()}')
+
assert len(buf) >= 4, 'invalid size'
head = FrameHead.from_bytes(buf[:4])
assert len(buf) == head.length + 4, f'invalid buf size ({len(buf)} != {head.length})'
-
payload = buf[4:]
-
- # byte b = paramFrameHead.seq;
b = head.seq
- # TODO check if protocol is 2, otherwise raise an exception
-
j = b & 0xF
k = b >> 4 & 0xF
@@ -123,15 +176,9 @@ class Message:
decryptor = cipher.decryptor()
decrypted_data = decryptor.update(payload) + decryptor.finalize()
- # print(f'head.length={head.length} len(decr)={len(decrypted_data)}')
- # if len(decrypted_data) > head.length:
unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
decrypted_data = unpadder.update(decrypted_data)
- # try:
decrypted_data += unpadder.finalize()
- # except ValueError as exc:
- # _logger.exception(exc)
- # pass
assert len(decrypted_data) != 0, 'decrypted data is null'
assert head.seq == decrypted_data[0], f'decrypted seq mismatch {head.seq} != {decrypted_data[0]}'
@@ -145,29 +192,54 @@ class Message:
return NakMessage(head.seq)
elif head.type == FrameType.AUX:
+ # TODO implement AUX
raise NotImplementedError('FrameType AUX is not yet implemented')
elif head.type == FrameType.CMD:
- cmd = decrypted_data[0]
+ type = decrypted_data[1]
data = decrypted_data[2:]
- return CmdMessage(head.seq, cmd, data)
+
+ cl = UnknownMessage
+
+ subclasses = [cl for cl in CmdIncomingMessage.__subclasses__() if cl is not SimpleBooleanMessage]
+ subclasses.extend(SimpleBooleanMessage.__subclasses__())
+
+ for _cl in subclasses:
+ # `UnknownMessage` is a special class that holds a packed command that we don't recognize.
+ # It will be used anyway if we don't find a match, so skip it here
+ if _cl == UnknownMessage:
+ continue
+
+ if _cl.TYPE == type:
+ cl = _cl
+ break
+
+ m = cl.from_packed_data(data, seq=head.seq)
+ if isinstance(m, UnknownMessage):
+ m.set_type(type)
+ return m
else:
raise NotImplementedError(f'Unexpected frame type: {head.type}')
- @property
- def data(self) -> bytes:
+ def pack_data(self) -> bytes:
return b''
- def encrypt(self,
- outkey: bytes,
- inkey: bytes,
- token: bytes,
- pubkey: bytes):
+ @property
+ def seq(self) -> Union[int, None]:
+ try:
+ return self.frame.head.seq
+ except:
+ return None
+
+ @seq.setter
+ def seq(self, seq: int):
+ self.frame.head.seq = seq
+ def encrypt(self, outkey: bytes, inkey: bytes, token: bytes, pubkey: bytes):
assert self.frame is not None
- data = self.data
+ data = self._get_data_to_encrypt()
assert data is not None
b = self.frame.head.seq
@@ -199,7 +271,7 @@ class Message:
arraycopy(data, 0, newdata, 1, len(data))
newdata = bytes(newdata)
- _logger.debug('payload to be sent: ' + newdata.hex())
+ _logger.debug('frame payload to be encrypted: ' + newdata.hex())
padder = padding.PKCS7(algorithms.AES.block_size).padder()
ciphertext = bytearray()
@@ -208,74 +280,143 @@ class Message:
self.frame.setpayload(ciphertext)
- def set_seq(self, seq: int):
- self.frame.head.seq = seq
-
- def __repr__(self):
- return f'<{self.__class__.__name__} seq={self.frame.head.seq}>'
+ def _get_data_to_encrypt(self) -> bytes:
+ return self.pack_data()
-class AckMessage(Message):
- def __init__(self, seq: int = 0):
+class AckMessage(Message, ABC):
+ def __init__(self, seq: Optional[int] = None):
super().__init__()
- self.frame = FrameItem(FrameHead(seq, FrameType.ACK, 0))
+ self.frame = FrameItem(FrameHead(seq, FrameType.ACK, None))
-class NakMessage(Message):
- def __init__(self, seq: int = 0):
+class NakMessage(Message, ABC):
+ def __init__(self, seq: Optional[int] = None):
super().__init__()
- self.frame = FrameItem(FrameHead(seq, FrameType.NAK, 0))
+ self.frame = FrameItem(FrameHead(seq, FrameType.NAK, None))
class CmdMessage(Message):
- _type: Optional[int]
- _data: bytes
+ type: Optional[int]
+ data: bytes
- def __init__(self, seq=0,
- type: Optional[int] = None,
- data: bytes = b''):
- super().__init__()
- self._data = data
- if type is not None:
- self.frame = FrameItem(FrameHead(seq, FrameType.CMD))
- self._type = type
- else:
- self._type = None
+ TYPE = None
- @property
- def data(self) -> bytes:
+ def _get_data_to_encrypt(self) -> bytes:
buf = bytearray()
- buf.append(self._type)
- buf.extend(self._data)
+ buf.append(self.get_type())
+ buf.extend(self.pack_data())
return bytes(buf)
+ def __init__(self, seq: Optional[int] = None):
+ super().__init__()
+ self.frame = FrameItem(FrameHead(seq, FrameType.CMD))
+ self.data = b''
+
+ def _repr_fields(self) -> ReprDict:
+ return {
+ 'cmd': self.get_type()
+ }
+
def __repr__(self):
params = [
__name__+'.'+self.__class__.__name__,
- f'seq={self.frame.head.seq}',
- # f'type={self.frame.head.type}',
- f'cmd={self._type}'
+ f'id={self.id}',
+ f'seq={self.seq}'
]
- if self._data:
- params.append(f'data={self._data.hex()}')
+ fields = self._repr_fields()
+ if fields:
+ for k, v in fields.items():
+ params.append(f'{k}={v}')
+ elif self.data:
+ params.append(f'data={self.data.hex()}')
return '<'+' '.join(params)+'>'
+ def get_type(self) -> int:
+ return self.__class__.TYPE
+
+
+class CmdIncomingMessage(CmdMessage):
+ @staticmethod
+ @abstractmethod
+ def from_packed_data(cls, data: bytes, seq: Optional[int] = None):
+ pass
+
+ @abstractmethod
+ def _repr_fields(self) -> ReprDict:
+ pass
+
+
+class CmdOutgoingMessage(CmdMessage):
+ @abstractmethod
+ def pack_data(self) -> bytes:
+ return b''
+
+
+class ModeMessage(CmdOutgoingMessage, CmdIncomingMessage):
+ TYPE = 1
+
+ pt: PowerType
+
+ def __init__(self, power_type: PowerType, seq: Optional[int] = None):
+ super().__init__(seq)
+ self.pt = power_type
+
+ @classmethod
+ def from_packed_data(cls, data: bytes, seq=0) -> ModeMessage:
+ assert len(data) == 1, 'data size expected to be 1'
+ mode, = struct.unpack('B', data)
+ return ModeMessage(PowerType(mode), seq=seq)
+
+ def pack_data(self) -> bytes:
+ return self.pt.value.to_bytes(1, byteorder='little')
-class ModeMessage(CmdMessage):
- def __init__(self, power_type: PowerType):
- super().__init__(type=1,
- data=(power_type.value).to_bytes(1, byteorder='little'))
+ def _repr_fields(self) -> ReprDict:
+ return {'mode': self.pt.name}
-class TargetTemperatureMessage(CmdMessage):
- def __init__(self, temp: int):
- super().__init__(type=2,
- data=bytes(bytearray([temp, 0])))
+class TargetTemperatureMessage(CmdOutgoingMessage, CmdIncomingMessage):
+ temperature: int
+ TYPE = 2
+ def __init__(self, temp: int, seq: Optional[int] = None):
+ super().__init__(seq)
+ self.temperature = temp
+
+ @classmethod
+ def from_packed_data(cls, data: bytes, seq=0) -> TargetTemperatureMessage:
+ assert len(data) == 2, 'data size expected to be 2'
+ nat, frac = struct.unpack('BB', data)
+ temp = int(nat + (frac / 100))
+ return TargetTemperatureMessage(temp, seq=seq)
+
+ def pack_data(self) -> bytes:
+ return bytes([self.temperature, 0])
+
+ def _repr_fields(self) -> ReprDict:
+ return {'temperature': self.temperature}
+
+
+class PingMessage(CmdIncomingMessage, CmdOutgoingMessage):
+ TYPE = 255
+
+ @classmethod
+ def from_packed_data(cls, data: bytes, seq=0) -> PingMessage:
+ assert len(data) == 0, 'no data expected'
+ return PingMessage(seq=seq)
+
+ def pack_data(self) -> bytes:
+ return b''
+
+ def _repr_fields(self) -> ReprDict:
+ return {}
+
+
+# This is the first protocol message. Sent by a client.
+# Kettle usually ACKs this, but sometimes i don't get any ACK and the very next message is HandshakeResponseMessage.
class HandshakeMessage(CmdMessage):
- def __init__(self):
- super().__init__(type=0)
+ TYPE = 0
def encrypt(self,
outkey: bytes,
@@ -297,20 +438,312 @@ class HandshakeMessage(CmdMessage):
self.frame.setpayload(pld)
-# TODO
-# implement resending UDP messages if no answer has been received in a second
-# try at least 5 times, then give up
-class Connection(threading.Thread):
- seq_no: int
+# Kettle either sends this right after the handshake, of first it ACKs the handshake then sends this.
+class HandshakeResponseMessage(CmdIncomingMessage):
+ TYPE = 0
+
+ protocol: int
+ fw_major: int
+ fw_minor: int
+ mode: int
+ token: bytes
+
+ def __init__(self,
+ protocol: int,
+ fw_major: int,
+ fw_minor: int,
+ mode: int,
+ token: bytes,
+ seq: Optional[int] = None):
+ super().__init__(seq)
+ self.protocol = protocol
+ self.fw_major = fw_major
+ self.fw_minor = fw_minor
+ self.mode = mode
+ self.token = token
+
+ @classmethod
+ def from_packed_data(cls, data: bytes, seq=0) -> HandshakeResponseMessage:
+ protocol, fw_major, fw_minor, mode = struct.unpack('<HBBB', data[:5])
+ return HandshakeResponseMessage(protocol, fw_major, fw_minor, mode, token=data[5:], seq=seq)
+
+ def _repr_fields(self) -> ReprDict:
+ return {
+ 'protocol': self.protocol,
+ 'fw': f'{self.fw_major}.{self.fw_minor}',
+ 'mode': self.mode,
+ 'token': self.token.hex()
+ }
+
+
+# Apparently, some hardware info.
+# On the other hand, if you look at com.syncleiot.iottransport.commands.CmdHardware, its mqtt topic is "mcu_firmware".
+# My device returns 1.1.1. The thing uses on ESP8266 MCU under the hood (or, more precisely, under a piece of cheap
+# plastic), so maybe 1.1.1 is the MCU fw revision.
+class DeviceHardwareMessage(CmdIncomingMessage):
+ TYPE = 143 # -113
+
+ hw: List[int]
+
+ def __init__(self, hw: List[int], seq: Optional[int] = None):
+ super().__init__(seq)
+ self.hw = hw
+
+ @classmethod
+ def from_packed_data(cls, data: bytes, seq=0) -> DeviceHardwareMessage:
+ assert len(data) == 3, 'invalid data size, expected 3'
+ hw = list(struct.unpack('<BBB', data))
+ return DeviceHardwareMessage(hw, seq=seq)
+
+ def _repr_fields(self) -> ReprDict:
+ return {'device_hardware': '.'.join(map(str, self.hw))}
+
+
+# This message is sent by kettle right after the HandshakeMessageResponse.
+# The diagnostic data is supposed to be sent to vendor, which we, obviously, not going to do.
+# So just ACK and skip it.
+class DeviceDiagnosticMessage(CmdIncomingMessage):
+ TYPE = 145 # -111
+
+ diag_data: bytes
+
+ def __init__(self, diag_data: bytes, seq: Optional[int] = None):
+ super().__init__(seq)
+ self.diag_data = diag_data
+
+ @classmethod
+ def from_packed_data(cls, data: bytes, seq=0) -> DeviceDiagnosticMessage:
+ return DeviceDiagnosticMessage(diag_data=data, seq=seq)
+
+ def _repr_fields(self) -> ReprDict:
+ return {'diag_data': self.diag_data.hex()}
+
+
+class SimpleBooleanMessage(ABC, CmdIncomingMessage):
+ value: bool
+
+ def __init__(self, value: bool, seq: Optional[int] = None):
+ super().__init__(seq)
+ self.value = value
+
+ @classmethod
+ def from_packed_data(cls, data: bytes, seq: Optional[int] = None):
+ assert len(data) == 1, 'invalid data size, expected 1'
+ enabled, = struct.unpack('<B', data)
+ return cls(value=enabled == 1, seq=seq)
+
+ @abstractmethod
+ def _repr_fields(self) -> ReprDict:
+ pass
+
+
+class AccessControlMessage(SimpleBooleanMessage):
+ TYPE = 133 # -123
+
+ def _repr_fields(self) -> ReprDict:
+ return {'acl_enabled': self.value}
+
+
+class ErrorMessage(SimpleBooleanMessage):
+ TYPE = 7
+
+ def _repr_fields(self) -> ReprDict:
+ return {'error': self.value}
+
+
+class ChildLockMessage(SimpleBooleanMessage):
+ TYPE = 30
+
+ def _repr_fields(self) -> ReprDict:
+ return {'child_lock': self.value}
+
+
+class VolumeMessage(SimpleBooleanMessage):
+ TYPE = 9
+
+ def _repr_fields(self) -> ReprDict:
+ return {'volume': self.value}
+
+
+class BacklightMessage(SimpleBooleanMessage):
+ TYPE = 28
+
+ def _repr_fields(self) -> ReprDict:
+ return {'backlight': self.value}
+
+
+class CurrentTemperatureMessage(CmdIncomingMessage):
+ TYPE = 20
+
+ current_temperature: int
+
+ def __init__(self, temp: int, seq: Optional[int] = None):
+ super().__init__(seq)
+ self.current_temperature = temp
+
+ @classmethod
+ def from_packed_data(cls, data: bytes, seq=0) -> CurrentTemperatureMessage:
+ assert len(data) == 2, 'data size expected to be 2'
+ nat, frac = struct.unpack('BB', data)
+ temp = int(nat + (frac / 100))
+ return CurrentTemperatureMessage(temp, seq=seq)
+
+ def pack_data(self) -> bytes:
+ return bytes([self.current_temperature, 0])
+
+ def _repr_fields(self) -> ReprDict:
+ return {'current_temperature': self.current_temperature}
+
+
+class UnknownMessage(CmdIncomingMessage):
+ type: Optional[int]
+ data: bytes
+
+ def __init__(self, data: bytes, **kwargs):
+ super().__init__(**kwargs)
+ self.type = None
+ self.data = data
+
+ @classmethod
+ def from_packed_data(cls, data: bytes, seq=0) -> UnknownMessage:
+ return UnknownMessage(data, seq=seq)
+
+ def set_type(self, type: int):
+ self.type = type
+
+ def get_type(self) -> int:
+ return self.type
+
+ def _repr_fields(self) -> ReprDict:
+ return {
+ 'type': self.type,
+ 'data': self.data.hex()
+ }
+
+
+class WrappedMessage:
+ _message: Message
+ _handler: Optional[callable]
+ _validator: Optional[callable]
+ _logger: Optional[logging.Logger]
+ _phase: MessagePhase
+ _phase_update_time: float
+
+ def __init__(self,
+ message: Message,
+ handler: Optional[callable] = None,
+ validator: Optional[callable] = None,
+ ack=False):
+ self._message = message
+ self._handler = handler
+ self._validator = validator
+ self._logger = None
+ self._phase = MessagePhase.WAITING
+ self._phase_update_time = 0
+ if not validator and ack:
+ self._validator = lambda m: isinstance(m, AckMessage)
+
+ def setlogger(self, logger: logging.Logger):
+ self._logger = logger
+
+ def validate(self, message: Message):
+ if not self._validator:
+ return True
+ return self._validator(message)
+
+ def call(self, *args, error_message: str = None) -> None:
+ if not self._handler:
+ return
+ try:
+ self._handler(*args)
+ except Exception as exc:
+ logger = self._logger or logging.getLogger(self.__class__.__name__)
+ logger.error(f'{error_message}, see exception below:')
+ logger.exception(exc)
+
+ @property
+ def phase(self) -> MessagePhase:
+ return self._phase
+
+ @phase.setter
+ def phase(self, phase: MessagePhase):
+ self._phase = phase
+ self._phase_update_time = 0 if phase == MessagePhase.WAITING else time.time()
+
+ @property
+ def phase_update_time(self) -> float:
+ return self._phase_update_time
+
+ @property
+ def message(self) -> Message:
+ return self._message
+
+ @property
+ def id(self) -> int:
+ return self._message.id
+
+ @property
+ def seq(self) -> int:
+ return self._message.seq
+
+ @seq.setter
+ def seq(self, seq: int):
+ self._message.seq = seq
+
+ def __repr__(self):
+ return f'<{__name__}.{self.__class__.__name__} message={self._message.__repr__()}>'
+
+
+# Connection stuff
+# Well, strictly speaking, as it's UDP, there's no connection, but who cares.
+# ---------------------------------------------------------------------------
+
+class IncomingMessageListener:
+ @abstractmethod
+ def incoming_message(self, message: Message) -> Optional[Message]:
+ pass
+
+
+class ConnectionStatus(Enum):
+ NOT_CONNECTED = auto()
+ CONNECTING = auto()
+ CONNECTED = auto()
+ RECONNECTING = auto()
+ DISCONNECTED = auto()
+
+
+class ConnectionStatusListener:
+ @abstractmethod
+ def connection_status_updated(self, status: ConnectionStatus):
+ pass
+
+
+class UDPConnection(threading.Thread, ConnectionStatusListener):
+ inseq: int
+ outseq: int
source_port: int
device_addr: str
device_port: int
device_token: bytes
+ device_pubkey: bytes
interrupted: bool
- waiting_for_response: dict[int, callable]
+ response_handlers: Dict[int, WrappedMessage]
+ outgoing_queue: List[WrappedMessage]
pubkey: Optional[bytes]
encinkey: Optional[bytes]
encoutkey: Optional[bytes]
+ inc_listeners: List[IncomingMessageListener]
+ conn_listeners: List[ConnectionStatusListener]
+ outgoing_time: float
+ outgoing_time_1st: float
+ incoming_time: float
+ status: ConnectionStatus
+ reconnect_tries: int
+
+ _addr_lock: threading.Lock
+ _iml_lock: threading.Lock
+ _csl_lock: threading.Lock
+ _st_lock: threading.Lock
def __init__(self,
addr: Union[IPv4Address, IPv6Address],
@@ -318,36 +751,105 @@ class Connection(threading.Thread):
device_pubkey: bytes,
device_token: bytes):
super().__init__()
- self.logger = logging.getLogger(__name__+'.'+self.__class__.__name__)
+ self._logger = logging.getLogger(f'{__name__}.{self.__class__.__name__} <{hex(id(self))}>')
self.setName(self.__class__.__name__)
- # self.daemon = True
- self.seq_no = -1
+ self.inseq = 0
+ self.outseq = 0
self.source_port = random.randint(1024, 65535)
self.device_addr = str(addr)
self.device_port = port
self.device_token = device_token
- self.lock = threading.Lock()
- self.outgoing_queue = queue.SimpleQueue()
- self.waiting_for_response = {}
+ self.device_pubkey = device_pubkey
+ self.outgoing_queue = []
+ self.response_handlers = {}
self.interrupted = False
+ self.outgoing_time = 0
+ self.outgoing_time_1st = 0
+ self.incoming_time = 0
+ self.inc_listeners = []
+ self.conn_listeners = [self]
+ self.status = ConnectionStatus.NOT_CONNECTED
+ self.reconnect_tries = 0
+
+ self._iml_lock = threading.Lock()
+ self._csl_lock = threading.Lock()
+ self._addr_lock = threading.Lock()
+ self._st_lock = threading.Lock()
self.pubkey = None
self.encinkey = None
self.encoutkey = None
- self.prepare_keys(device_pubkey)
-
- def prepare_keys(self, device_pubkey: bytes):
+ def connection_status_updated(self, status: ConnectionStatus):
+ # self._logger.info(f'connection_status_updated: status = {status}')
+ with self._st_lock:
+ # self._logger.debug(f'connection_status_updated: lock acquired')
+ self.status = status
+ if status == ConnectionStatus.RECONNECTING:
+ self.reconnect_tries += 1
+ if status in (ConnectionStatus.CONNECTED, ConnectionStatus.NOT_CONNECTED, ConnectionStatus.DISCONNECTED):
+ self.reconnect_tries = 0
+
+ def _cleanup(self):
+ # erase outgoing queue
+ for wm in self.outgoing_queue:
+ wm.call(False,
+ error_message=f'_cleanup: exception while calling cb(False) on message {wm.message}')
+ self.outgoing_queue = []
+ self.response_handlers = {}
+
+ # reset timestamps
+ self.incoming_time = 0
+ self.outgoing_time = 0
+ self.outgoing_time_1st = 0
+
+ self._logger.info('_cleanup: done')
+
+ def set_address(self, addr: Union[IPv4Address, IPv6Address], port: int):
+ with self._addr_lock:
+ if self.device_addr != str(addr) or self.device_port != port:
+ self.device_addr = str(addr)
+ self.device_port = port
+ self._logger.info(f'updated device network address: {self.device_addr}:{self.device_port}')
+
+ def set_device_pubkey(self, pubkey: bytes):
+ if self.device_pubkey.hex() != pubkey.hex():
+ self._logger.info(f'device pubkey has changed (old={self.device_pubkey.hex()}, new={pubkey.hex()})')
+ self.device_pubkey = pubkey
+ self._notify_cs(ConnectionStatus.RECONNECTING)
+
+ def get_address(self) -> Tuple[str, int]:
+ with self._addr_lock:
+ return self.device_addr, self.device_port
+
+ def add_incoming_message_listener(self, listener: IncomingMessageListener):
+ with self._iml_lock:
+ if listener not in self.inc_listeners:
+ self.inc_listeners.append(listener)
+
+ def add_connection_status_listener(self, listener: ConnectionStatusListener):
+ with self._csl_lock:
+ if listener not in self.conn_listeners:
+ self.conn_listeners.append(listener)
+
+ def _notify_cs(self, status: ConnectionStatus):
+ # self._logger.debug(f'_notify_cs: status={status}')
+ with self._csl_lock:
+ for obj in self.conn_listeners:
+ # self._logger.debug(f'_notify_cs: notifying {obj}')
+ obj.connection_status_updated(status)
+
+ def _prepare_keys(self):
# generate key pair
privkey = X25519PrivateKey.generate()
- self.pubkey = bytes(reversed(privkey.public_key().public_bytes(encoding=cryptography.hazmat.primitives._serialization.Encoding.Raw,
- format=cryptography.hazmat.primitives._serialization.PublicFormat.Raw)))
+ self.pubkey = bytes(reversed(privkey.public_key().public_bytes(encoding=srlz.Encoding.Raw,
+ format=srlz.PublicFormat.Raw)))
# generate shared key
device_pubkey = X25519PublicKey.from_public_bytes(
- bytes(reversed(device_pubkey))
+ bytes(reversed(self.device_pubkey))
)
shared_key = bytes(reversed(
privkey.exchange(device_pubkey)
@@ -362,51 +864,286 @@ class Connection(threading.Thread):
self.encinkey = shared_sha256[:16]
self.encoutkey = shared_sha256[16:]
+ self._logger.info('encryption keys have been created')
+
+ def _handshake_callback(self, r: MessageResponse):
+ # if got error for our HandshakeMessage, reset everything and try again
+ if r is False:
+ # self._logger.debug('_handshake_callback: set status=RECONNETING')
+ self._notify_cs(ConnectionStatus.RECONNECTING)
+ else:
+ # self._logger.debug('_handshake_callback: set status=CONNECTED')
+ self._notify_cs(ConnectionStatus.CONNECTED)
+
def run(self):
+ self._logger.info('starting server loop')
+
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.bind(('0.0.0.0', self.source_port))
- sock.settimeout(1)
+ sock.settimeout(READ_TIMEOUT)
while not self.interrupted:
- if not self.outgoing_queue.empty():
- message = self.outgoing_queue.get()
- message.encrypt(outkey=self.encoutkey,
- inkey=self.encinkey,
- token=self.device_token,
- pubkey=self.pubkey)
- buf = message.frame.pack()
- self.logger.debug('send: '+buf.hex())
- self.logger.debug(f'sendto: {self.device_addr}:{self.device_port}')
- sock.sendto(buf, (self.device_addr, self.device_port))
+ with self._st_lock:
+ status = self.status
+
+ if status in (ConnectionStatus.DISCONNECTED, ConnectionStatus.RECONNECTING):
+ self._cleanup()
+ if status == ConnectionStatus.DISCONNECTED:
+ break
+
+ # no activity for some time means connection is broken
+ fail = False
+ fail_path = 0
+ if self.incoming_time > 0 and time.time() - self.incoming_time >= DISCONNECT_TIMEOUT:
+ fail = True
+ fail_path = 1
+ elif self.outgoing_time_1st > 0 and self.incoming_time == 0 and time.time() - self.outgoing_time_1st >= DISCONNECT_TIMEOUT:
+ fail = True
+ fail_path = 2
+
+ if fail:
+ self._logger.debug(f'run: setting status=RECONNECTING because of long inactivity, fail_path={fail_path}')
+ self._notify_cs(ConnectionStatus.RECONNECTING)
+
+ # establishing a connection
+ if status in (ConnectionStatus.RECONNECTING, ConnectionStatus.NOT_CONNECTED):
+ if status == ConnectionStatus.RECONNECTING and self.reconnect_tries >= 3:
+ self._notify_cs(ConnectionStatus.DISCONNECTED)
+ continue
+
+ self._reset_outseq()
+ self._prepare_keys()
+
+ # shake the imaginary kettle's hand
+ wrapped = WrappedMessage(HandshakeMessage(),
+ handler=self._handshake_callback,
+ validator=lambda m: isinstance(m, (AckMessage, HandshakeResponseMessage)))
+ self.enqueue_message(wrapped, prepend=True)
+ self._notify_cs(ConnectionStatus.CONNECTING)
+
+ # pick next (wrapped) message to send
+ wm = self._get_next_message() # wm means "wrapped message"
+ if wm:
+ if not isinstance(wm.message, (AckMessage, NakMessage)):
+ old_seq = wm.seq
+ wm.seq = self.outseq
+ self._set_response_handler(wm, old_seq=old_seq)
+ elif wm.seq is None:
+ # ack/nak is a response to some incoming message (and it must have the same seqno that incoming
+ # message had)
+ raise RuntimeError(f'run: seq must be set for {wm.__class__.__name__}')
+
+ self._logger.debug(f'run: sending message: {wm.message}')
+ wm.message.encrypt(outkey=self.encoutkey, inkey=self.encinkey,
+ token=self.device_token, pubkey=self.pubkey)
+ buf = wm.message.frame.pack()
+ one_shot = isinstance(wm.message, (AckMessage, NakMessage, PingMessage))
+ # self._logger.debug(f'run: raw data to be sent: {buf.hex()}')
+
+ # sending the first time
+ if wm.phase == MessagePhase.WAITING:
+ sock.sendto(buf, self.get_address())
+ # resending
+ elif wm.phase == MessagePhase.SENT:
+ left = RESEND_ATTEMPTS
+ while left > 0:
+ sock.sendto(buf, self.get_address())
+ left -= 1
+ if left > 0:
+ time.sleep(0.05)
+
+ if one_shot or wm.phase == MessagePhase.SENT:
+ wm.phase = MessagePhase.DONE
+ else:
+ wm.phase = MessagePhase.SENT
+
+ now = time.time()
+ self.outgoing_time = now
+ if not self.outgoing_time_1st:
+ self.outgoing_time_1st = now
+
+ # receiving data
try:
data = sock.recv(4096)
- self.handle_incoming(data)
- except TimeoutError:
+ self._handle_incoming(data)
+ except (TimeoutError, socket.timeout):
pass
- def handle_incoming(self, buf: bytes):
- self.logger.debug('handle_incoming: '+buf.hex())
- message = Message.from_encrypted(buf, inkey=self.encinkey, outkey=self.encoutkey)
- if message.frame.head.seq in self.waiting_for_response:
- self.logger.info(f'received awaited message: {message}')
+ self._logger.info('bye...')
+
+ def _get_next_message(self) -> Optional[WrappedMessage]:
+ message = None
+ lpfx = '_get_next_message:'
+ remove_list = []
+ for wm in self.outgoing_queue:
+ if wm.phase == MessagePhase.DONE:
+ if isinstance(wm.message, (AckMessage, NakMessage)) or time.time() - wm.phase_update_time >= MESSAGE_QUEUE_REMOVE_DELAY:
+ remove_list.append(wm)
+ continue
+ message = wm
+ break
+
+ for wm in remove_list:
+ self._logger.debug(f'{lpfx} rm path: removing id={wm.id} seq={wm.seq}')
+
+ # clear message handler
+ if wm.seq in self.response_handlers:
+ self.response_handlers[wm.seq].call(
+ False, error_message=f'{lpfx} rm path: error while calling callback for seq={wm.seq}')
+ del self.response_handlers[wm.seq]
+
+ # remove from queue
try:
- f = self.waiting_for_response[message.frame.head.seq]
- f(message)
- except Exception as exc:
- self.logger.exception(exc)
- finally:
- del self.waiting_for_response[message.frame.head.seq]
+ self.outgoing_queue.remove(wm)
+ except ValueError as exc:
+ self._logger.error(f'{lpfx} rm path: removing from outgoing_queue raised an exception: {str(exc)}')
+
+ # ping pong
+ if self.outgoing_time_1st != 0 and self.status == ConnectionStatus.CONNECTED:
+ now = time.time()
+ out_delta = now - self.outgoing_time
+ in_delta = now - self.incoming_time
+ if not message and max(out_delta, in_delta) > PING_FREQUENCY:
+ self._logger.debug(f'{lpfx} no activity: in for {in_delta:.2f}s, out for {out_delta:.2f}s, time to ping the damn thing')
+ message = WrappedMessage(PingMessage(), ack=True)
+
+ return message
+
+ def _handle_incoming(self, buf: bytes):
+ self.incoming_time = time.time()
+
+ incoming_message = Message.from_encrypted(buf, inkey=self.encinkey, outkey=self.encoutkey)
+ seq = incoming_message.seq
+
+ lpfx = f'handle_incoming({incoming_message.id}):'
+ self._logger.debug(f'{lpfx} received: {incoming_message}')
+
+ if isinstance(incoming_message, (AckMessage, NakMessage)):
+ seq_max = self.outseq
+ seq_name = 'outseq'
+ else:
+ seq_max = self.inseq
+ seq_name = 'inseq'
+ self.inseq = seq
+
+ if seq < seq_max < 0xfd:
+ self._logger.warning(f'{lpfx} dropping: seq={seq}, {seq_name}={seq_name}')
+ return
+
+ if seq not in self.response_handlers:
+ self._handle_incoming_cmd(incoming_message)
+ return
+
+ callback_value = None # None means don't call a callback
+ handler = self.response_handlers[seq]
+
+ if handler.validate(incoming_message):
+ self._logger.info(f'{lpfx} response OK')
+ handler.phase = MessagePhase.DONE
+ callback_value = incoming_message
+ self._incr_outseq()
else:
- self.logger.info(f'received message (not awaited): {message}')
-
- def send_message(self, message: Message, callback: callable):
- seq = self.next_seqno()
- message.set_seq(seq)
- self.outgoing_queue.put(message)
- self.waiting_for_response[seq] = callback
-
- def next_seqno(self) -> int:
- with self.lock:
- self.seq_no += 1
- self.logger.debug(f'next_seqno: set to {self.seq_no}')
- return self.seq_no
+ self._logger.info(f'{lpfx} response is INVALID')
+
+ # It seems that we've received an incoming CmdMessage or PingMessage with the same seqno that our outgoing
+ # message had. Bad, but what can I say, this is quick-and-dirty made UDP based protocol and this sort of
+ # shit just happens.
+
+ # (To be fair, maybe my implementation is not perfect either. But hey, what did you expect from a
+ # reverse-engineered re-implementation of custom UDP-based protocol that some noname vendor uses for their
+ # cheap IoT devices? I think _that_ is _the_ definition of shit. At least my implementation is FOSS, which
+ # is more than you'll ever be able to say about them.)
+
+ # All this crapload of code below might not be needed at all, 'cause the protocol uses separate frame seq
+ # numbers for IN and OUT frames and this situation is not highly likely, as Theresa May could argue.
+ # After a handshake, a kettle sends us 10 or so CmdMessages, and then either we continuously ping it every
+ # 3 seconds, or kettle pings us. This in any case widens the gap between inseq and outseq.
+
+ # But! the seqno is only 1 byte in size and once it reaches 0xff, it circles back to zero. And that (plus,
+ # perhaps, some bad luck) gives a chance for a collision.
+
+ if handler.phase == MessagePhase.DONE or isinstance(handler.message, HandshakeMessage):
+ # no more attempts left, returning error back to user
+ # as to handshake, it cannot fail.
+ callback_value = False
+
+ # else:
+ # # try resending the message
+ # handler.phase_reset()
+ # max_seq = self.outseq
+ # wait_remap = {}
+ # for m in self.outgoing_queue:
+ # if m.seq in self.waiting_for_response:
+ # wait_remap[m.seq] = (m.seq+1) % 256
+ # m.set_seq((m.seq+1) % 256)
+ # if m.seq > max_seq:
+ # max_seq = m.seq
+ # if max_seq > self.outseq:
+ # self.outseq = max_seq % 256
+ # if wait_remap:
+ # waiting_new = {}
+ # for old_seq, new_seq in wait_remap.items():
+ # waiting_new[new_seq] = self.waiting_for_response[old_seq]
+ # self.waiting_for_response = waiting_new
+
+ if isinstance(incoming_message, (PingMessage, CmdIncomingMessage)):
+ # handle incoming message as usual, as we need to ack/nak it anyway
+ self._handle_incoming_cmd(incoming_message)
+
+ if callback_value is not None:
+ handler.call(callback_value,
+ error_message=f'{lpfx} error while calling callback for msg id={handler.message.id} seq={seq}')
+ del self.response_handlers[seq]
+
+ def _handle_incoming_cmd(self, incoming_message: Message):
+ if isinstance(incoming_message, (AckMessage, NakMessage)):
+ self._logger.debug(f'_handle_incoming_cmd({incoming_message.id}, seq={incoming_message.seq}): it\'s {incoming_message.__class__.__name__}, ignoring')
+ return
+
+ replied = False
+ with self._iml_lock:
+ for f in self.inc_listeners:
+ retval = safe_callback_call(f.incoming_message, incoming_message,
+ logger=self._logger,
+ error_message=f'_handle_incoming_cmd({incoming_message.id}, seq={incoming_message.seq}): error while calling message listener')
+ if isinstance(retval, Message):
+ if isinstance(retval, (AckMessage, NakMessage)):
+ retval.seq = incoming_message.seq
+ self.enqueue_message(WrappedMessage(retval), prepend=True)
+ replied = True
+ break
+ else:
+ raise RuntimeError('are you sure your response is correct? only ack/nak are allowed')
+
+ if not replied:
+ self.enqueue_message(WrappedMessage(AckMessage(incoming_message.seq)), prepend=True)
+
+ def enqueue_message(self, wrapped: WrappedMessage, prepend=False):
+ self._logger.debug(f'enqueue_message: {wrapped.message}')
+ if not prepend:
+ self.outgoing_queue.append(wrapped)
+ else:
+ self.outgoing_queue.insert(0, wrapped)
+
+ def _set_response_handler(self, wm: WrappedMessage, old_seq=None):
+ if old_seq in self.response_handlers:
+ del self.response_handlers[old_seq]
+
+ seq = wm.seq
+ assert seq is not None, 'seq is not set'
+
+ if seq in self.response_handlers:
+ self._logger.warning(f'_set_response_handler(seq={seq}): handler is already set, cancelling it')
+ self.response_handlers[seq].call(False,
+ error_message=f'_set_response_handler({seq}): error while calling old callback')
+ self.response_handlers[seq] = wm
+
+ def _incr_outseq(self) -> None:
+ self.outseq = (self.outseq + 1) % 256
+
+ def _reset_outseq(self):
+ self.outseq = 0
+ self._logger.debug(f'_reset_outseq: set 0')
+
+
+MessageResponse = Union[Message, bool]