summaryrefslogtreecommitdiff
path: root/src/syncleo/protocol.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/syncleo/protocol.py')
-rw-r--r--src/syncleo/protocol.py1169
1 files changed, 0 insertions, 1169 deletions
diff --git a/src/syncleo/protocol.py b/src/syncleo/protocol.py
deleted file mode 100644
index 36a1a8f..0000000
--- a/src/syncleo/protocol.py
+++ /dev/null
@@ -1,1169 +0,0 @@
-# Polaris PWK 1725CGLD "smart" kettle python library
-# --------------------------------------------------
-# Copyright (C) Evgeny Zinoviev, 2022
-# License: BSD-3c
-
-from __future__ import annotations
-
-import logging
-import socket
-import random
-import struct
-import threading
-import time
-
-from abc import abstractmethod, ABC
-from enum import Enum, auto
-from typing import Union, Optional, Dict, Tuple, List
-from ipaddress import IPv4Address, IPv6Address
-
-import cryptography.hazmat.primitives._serialization as srlz
-
-from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
-from cryptography.hazmat.primitives import ciphers, padding, hashes
-from cryptography.hazmat.primitives.ciphers import algorithms, modes
-
-ReprDict = Dict[str, Union[str, int, float, bool]]
-_logger = logging.getLogger(__name__)
-
-PING_FREQUENCY = 3
-RESEND_ATTEMPTS = 5
-ERROR_TIMEOUT = 15
-MESSAGE_QUEUE_REMOVE_DELAY = 13 # after what time to delete (and pass False to handlers, if needed) messages with phase=DONE from queue
-DISCONNECT_TIMEOUT = 15
-
-
-def safe_callback_call(f: callable,
- *args,
- logger: logging.Logger = None,
- error_message: str = None):
- try:
- return f(*args)
- except Exception as exc:
- logger.error(f'{error_message}, see exception below:')
- logger.exception(exc)
- return None
-
-
-# drop-in replacement for java.lang.System.arraycopy
-# TODO: rewrite
-def arraycopy(src, src_pos, dest, dest_pos, length):
- for i in range(length):
- dest[i + dest_pos] = src[i + src_pos]
-
-
-# "convert" unsigned byte to signed
-def u8_to_s8(b: int) -> int:
- return struct.unpack('b', bytes([b]))[0]
-
-
-class PowerType(Enum):
- OFF = 0 # turn off
- ON = 1 # turn on, set target temperature to 100
- CUSTOM = 3 # turn on, allows custom target temperature
- # MYSTERY_MODE = 2 # don't know what 2 means, needs testing
- # update: if I set it to '2', it just resets to '0'
-
-
-# low-level protocol structures
-# -----------------------------
-
-class FrameType(Enum):
- ACK = 0
- CMD = 1
- AUX = 2
- NAK = 3
-
-
-class FrameHead:
- seq: Optional[int] # u8
- type: FrameType # u8
- length: int # u16. This is the length of FrameItem's payload
-
- @staticmethod
- def from_bytes(buf: bytes) -> FrameHead:
- seq, ft, length = struct.unpack('<BBH', buf)
- return FrameHead(seq, FrameType(ft), length)
-
- def __init__(self,
- seq: Optional[int],
- frame_type: FrameType,
- length: Optional[int] = None):
- self.seq = seq
- self.type = frame_type
- self.length = length or 0
-
- def pack(self) -> bytes:
- assert self.length != 0, "FrameHead.length has not been set"
- assert self.seq is not None, "FrameHead.seq has not been set"
- return struct.pack('<BBH', self.seq, self.type.value, self.length)
-
-
-class FrameItem:
- head: FrameHead
- payload: bytes
-
- def __init__(self, head: FrameHead, payload: Optional[bytes] = None):
- self.head = head
- self.payload = payload
-
- def setpayload(self, payload: Union[bytes, bytearray]):
- if isinstance(payload, bytearray):
- payload = bytes(payload)
- self.payload = payload
- self.head.length = len(payload)
-
- def pack(self) -> bytes:
- ba = bytearray(self.head.pack())
- ba.extend(self.payload)
- return bytes(ba)
-
-
-# high-level wrappers around FrameItem
-# ------------------------------------
-
-class MessagePhase(Enum):
- WAITING = 0
- SENT = 1
- DONE = 2
-
-
-class Message:
- frame: Optional[FrameItem]
- id: int
-
- _global_id = 0
-
- def __init__(self):
- self.frame = None
-
- # global internal message id, only useful for debugging purposes
- self.id = self.next_id()
-
- def __repr__(self):
- return f'<{self.__class__.__name__} id={self.id} seq={self.frame.head.seq}>'
-
- @staticmethod
- def next_id():
- _id = Message._global_id
- Message._global_id = (Message._global_id + 1) % 100000
- return _id
-
- @staticmethod
- def from_encrypted(buf: bytes, inkey: bytes, outkey: bytes) -> Message:
- _logger.debug(f'Message:from_encrypted: buf={buf.hex()}')
-
- assert len(buf) >= 4, 'invalid size'
- head = FrameHead.from_bytes(buf[:4])
-
- assert len(buf) == head.length + 4, f'invalid buf size ({len(buf)} != {head.length})'
- payload = buf[4:]
- b = head.seq
-
- j = b & 0xF
- k = b >> 4 & 0xF
-
- key = bytearray(len(inkey))
- arraycopy(inkey, j, key, 0, len(inkey) - j)
- arraycopy(inkey, 0, key, len(inkey) - j, j)
-
- iv = bytearray(len(outkey))
- arraycopy(outkey, k, iv, 0, len(outkey) - k)
- arraycopy(outkey, 0, iv, len(outkey) - k, k)
-
- cipher = ciphers.Cipher(algorithms.AES(key), modes.CBC(iv))
- decryptor = cipher.decryptor()
- decrypted_data = decryptor.update(payload) + decryptor.finalize()
-
- unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
- decrypted_data = unpadder.update(decrypted_data)
- decrypted_data += unpadder.finalize()
-
- assert len(decrypted_data) != 0, 'decrypted data is null'
- assert head.seq == decrypted_data[0], f'decrypted seq mismatch {head.seq} != {decrypted_data[0]}'
-
- # _logger.debug('Message.from_encrypted: plaintext: '+decrypted_data.hex())
-
- if head.type == FrameType.ACK:
- return AckMessage(head.seq)
-
- elif head.type == FrameType.NAK:
- return NakMessage(head.seq)
-
- elif head.type == FrameType.AUX:
- # TODO implement AUX
- raise NotImplementedError('FrameType AUX is not yet implemented')
-
- elif head.type == FrameType.CMD:
- type = decrypted_data[1]
- data = decrypted_data[2:]
-
- cl = UnknownMessage
-
- subclasses = [cl for cl in CmdIncomingMessage.__subclasses__() if cl is not SimpleBooleanMessage]
- subclasses.extend(SimpleBooleanMessage.__subclasses__())
-
- for _cl in subclasses:
- # `UnknownMessage` is a special class that holds a packed command that we don't recognize.
- # It will be used anyway if we don't find a match, so skip it here
- if _cl == UnknownMessage:
- continue
-
- if _cl.TYPE == type:
- cl = _cl
- break
-
- m = cl.from_packed_data(data, seq=head.seq)
- if isinstance(m, UnknownMessage):
- m.set_type(type)
- return m
-
- else:
- raise NotImplementedError(f'Unexpected frame type: {head.type}')
-
- def pack_data(self) -> bytes:
- return b''
-
- @property
- def seq(self) -> Union[int, None]:
- try:
- return self.frame.head.seq
- except:
- return None
-
- @seq.setter
- def seq(self, seq: int):
- self.frame.head.seq = seq
-
- def encrypt(self, outkey: bytes, inkey: bytes, token: bytes, pubkey: bytes):
- assert self.frame is not None
-
- data = self._get_data_to_encrypt()
- assert data is not None
-
- b = self.frame.head.seq
- i = b & 0xf
- j = b >> 4 & 0xf
-
- outkey = bytearray(outkey)
-
- l = len(outkey)
- key = bytearray(l)
-
- arraycopy(outkey, i, key, 0, l-i)
- arraycopy(outkey, 0, key, l-i, i)
-
- inkey = bytearray(inkey)
-
- l = len(inkey)
- iv = bytearray(l)
-
- arraycopy(inkey, j, iv, 0, l-j)
- arraycopy(inkey, 0, iv, l-j, j)
-
- cipher = ciphers.Cipher(algorithms.AES(key), modes.CBC(iv))
- encryptor = cipher.encryptor()
-
- newdata = bytearray(len(data)+1)
- newdata[0] = b
-
- arraycopy(data, 0, newdata, 1, len(data))
-
- newdata = bytes(newdata)
- _logger.debug('frame payload to be encrypted: ' + newdata.hex())
-
- padder = padding.PKCS7(algorithms.AES.block_size).padder()
- ciphertext = bytearray()
- ciphertext.extend(encryptor.update(padder.update(newdata) + padder.finalize()))
- ciphertext.extend(encryptor.finalize())
-
- self.frame.setpayload(ciphertext)
-
- def _get_data_to_encrypt(self) -> bytes:
- return self.pack_data()
-
-
-class AckMessage(Message, ABC):
- def __init__(self, seq: Optional[int] = None):
- super().__init__()
- self.frame = FrameItem(FrameHead(seq, FrameType.ACK, None))
-
-
-class NakMessage(Message, ABC):
- def __init__(self, seq: Optional[int] = None):
- super().__init__()
- self.frame = FrameItem(FrameHead(seq, FrameType.NAK, None))
-
-
-class CmdMessage(Message):
- type: Optional[int]
- data: bytes
-
- TYPE = None
-
- def _get_data_to_encrypt(self) -> bytes:
- buf = bytearray()
- buf.append(self.get_type())
- buf.extend(self.pack_data())
- return bytes(buf)
-
- def __init__(self, seq: Optional[int] = None):
- super().__init__()
- self.frame = FrameItem(FrameHead(seq, FrameType.CMD))
- self.data = b''
-
- def _repr_fields(self) -> ReprDict:
- return {
- 'cmd': self.get_type()
- }
-
- def __repr__(self):
- params = [
- __name__+'.'+self.__class__.__name__,
- f'id={self.id}',
- f'seq={self.seq}'
- ]
- fields = self._repr_fields()
- if fields:
- for k, v in fields.items():
- params.append(f'{k}={v}')
- elif self.data:
- params.append(f'data={self.data.hex()}')
- return '<'+' '.join(params)+'>'
-
- def get_type(self) -> int:
- return self.__class__.TYPE
-
-
-class CmdIncomingMessage(CmdMessage):
- @staticmethod
- @abstractmethod
- def from_packed_data(cls, data: bytes, seq: Optional[int] = None):
- pass
-
- @abstractmethod
- def _repr_fields(self) -> ReprDict:
- pass
-
-
-class CmdOutgoingMessage(CmdMessage):
- @abstractmethod
- def pack_data(self) -> bytes:
- return b''
-
-
-class ModeMessage(CmdOutgoingMessage, CmdIncomingMessage):
- TYPE = 1
-
- pt: PowerType
-
- def __init__(self, power_type: PowerType, seq: Optional[int] = None):
- super().__init__(seq)
- self.pt = power_type
-
- @classmethod
- def from_packed_data(cls, data: bytes, seq=0) -> ModeMessage:
- assert len(data) == 1, 'data size expected to be 1'
- mode, = struct.unpack('B', data)
- return ModeMessage(PowerType(mode), seq=seq)
-
- def pack_data(self) -> bytes:
- return self.pt.value.to_bytes(1, byteorder='little')
-
- def _repr_fields(self) -> ReprDict:
- return {'mode': self.pt.name}
-
-
-class TargetTemperatureMessage(CmdOutgoingMessage, CmdIncomingMessage):
- temperature: int
-
- TYPE = 2
-
- def __init__(self, temp: int, seq: Optional[int] = None):
- super().__init__(seq)
- self.temperature = temp
-
- @classmethod
- def from_packed_data(cls, data: bytes, seq=0) -> TargetTemperatureMessage:
- assert len(data) == 2, 'data size expected to be 2'
- nat, frac = struct.unpack('BB', data)
- temp = int(nat + (frac / 100))
- return TargetTemperatureMessage(temp, seq=seq)
-
- def pack_data(self) -> bytes:
- return bytes([self.temperature, 0])
-
- def _repr_fields(self) -> ReprDict:
- return {'temperature': self.temperature}
-
-
-class PingMessage(CmdIncomingMessage, CmdOutgoingMessage):
- TYPE = 255
-
- @classmethod
- def from_packed_data(cls, data: bytes, seq=0) -> PingMessage:
- assert len(data) == 0, 'no data expected'
- return PingMessage(seq=seq)
-
- def pack_data(self) -> bytes:
- return b''
-
- def _repr_fields(self) -> ReprDict:
- return {}
-
-
-# This is the first protocol message. Sent by a client.
-# Kettle usually ACKs this, but sometimes i don't get any ACK and the very next message is HandshakeResponseMessage.
-class HandshakeMessage(CmdMessage):
- TYPE = 0
-
- def encrypt(self,
- outkey: bytes,
- inkey: bytes,
- token: bytes,
- pubkey: bytes):
- cipher = ciphers.Cipher(algorithms.AES(outkey), modes.CBC(inkey))
- encryptor = cipher.encryptor()
-
- ciphertext = bytearray()
- ciphertext.extend(encryptor.update(token))
- ciphertext.extend(encryptor.finalize())
-
- pld = bytearray()
- pld.append(0)
- pld.extend(pubkey)
- pld.extend(ciphertext)
-
- self.frame.setpayload(pld)
-
-
-# Kettle either sends this right after the handshake, of first it ACKs the handshake then sends this.
-class HandshakeResponseMessage(CmdIncomingMessage):
- TYPE = 0
-
- protocol: int
- fw_major: int
- fw_minor: int
- mode: int
- token: bytes
-
- def __init__(self,
- protocol: int,
- fw_major: int,
- fw_minor: int,
- mode: int,
- token: bytes,
- seq: Optional[int] = None):
- super().__init__(seq)
- self.protocol = protocol
- self.fw_major = fw_major
- self.fw_minor = fw_minor
- self.mode = mode
- self.token = token
-
- @classmethod
- def from_packed_data(cls, data: bytes, seq=0) -> HandshakeResponseMessage:
- protocol, fw_major, fw_minor, mode = struct.unpack('<HBBB', data[:5])
- return HandshakeResponseMessage(protocol, fw_major, fw_minor, mode, token=data[5:], seq=seq)
-
- def _repr_fields(self) -> ReprDict:
- return {
- 'protocol': self.protocol,
- 'fw': f'{self.fw_major}.{self.fw_minor}',
- 'mode': self.mode,
- 'token': self.token.hex()
- }
-
-
-# Apparently, some hardware info.
-# On the other hand, if you look at com.syncleiot.iottransport.commands.CmdHardware, its mqtt topic says "mcu_firmware".
-# My device returns 1.1.1. The kettle uses on ESP8266 ESP-12F MCU under the hood (or, more precisely, under a piece of
-# cheap plastic), so maybe 1.1.1 is some MCU ROM version.
-class DeviceHardwareMessage(CmdIncomingMessage):
- TYPE = 143 # -113
-
- hw: List[int]
-
- def __init__(self, hw: List[int], seq: Optional[int] = None):
- super().__init__(seq)
- self.hw = hw
-
- @classmethod
- def from_packed_data(cls, data: bytes, seq=0) -> DeviceHardwareMessage:
- assert len(data) == 3, 'invalid data size, expected 3'
- hw = list(struct.unpack('<BBB', data))
- return DeviceHardwareMessage(hw, seq=seq)
-
- def _repr_fields(self) -> ReprDict:
- return {'device_hardware': '.'.join(map(str, self.hw))}
-
-
-# This message is sent by kettle right after the HandshakeMessageResponse.
-# The diagnostic data is supposed to be sent to vendor, which we, obviously, not going to do.
-# So just ACK and skip it.
-class DeviceDiagnosticMessage(CmdIncomingMessage):
- TYPE = 145 # -111
-
- diag_data: bytes
-
- def __init__(self, diag_data: bytes, seq: Optional[int] = None):
- super().__init__(seq)
- self.diag_data = diag_data
-
- @classmethod
- def from_packed_data(cls, data: bytes, seq=0) -> DeviceDiagnosticMessage:
- return DeviceDiagnosticMessage(diag_data=data, seq=seq)
-
- def _repr_fields(self) -> ReprDict:
- return {'diag_data': self.diag_data.hex()}
-
-
-class SimpleBooleanMessage(ABC, CmdIncomingMessage):
- value: bool
-
- def __init__(self, value: bool, seq: Optional[int] = None):
- super().__init__(seq)
- self.value = value
-
- @classmethod
- def from_packed_data(cls, data: bytes, seq: Optional[int] = None):
- assert len(data) == 1, 'invalid data size, expected 1'
- enabled, = struct.unpack('<B', data)
- return cls(value=enabled == 1, seq=seq)
-
- @abstractmethod
- def _repr_fields(self) -> ReprDict:
- pass
-
-
-class AccessControlMessage(SimpleBooleanMessage):
- TYPE = 133 # -123
-
- def _repr_fields(self) -> ReprDict:
- return {'acl_enabled': self.value}
-
-
-class ErrorMessage(SimpleBooleanMessage):
- TYPE = 7
-
- def _repr_fields(self) -> ReprDict:
- return {'error': self.value}
-
-
-class ChildLockMessage(SimpleBooleanMessage):
- TYPE = 30
-
- def _repr_fields(self) -> ReprDict:
- return {'child_lock': self.value}
-
-
-class VolumeMessage(SimpleBooleanMessage):
- TYPE = 9
-
- def _repr_fields(self) -> ReprDict:
- return {'volume': self.value}
-
-
-class BacklightMessage(SimpleBooleanMessage):
- TYPE = 28
-
- def _repr_fields(self) -> ReprDict:
- return {'backlight': self.value}
-
-
-class CurrentTemperatureMessage(CmdIncomingMessage):
- TYPE = 20
-
- current_temperature: int
-
- def __init__(self, temp: int, seq: Optional[int] = None):
- super().__init__(seq)
- self.current_temperature = temp
-
- @classmethod
- def from_packed_data(cls, data: bytes, seq=0) -> CurrentTemperatureMessage:
- assert len(data) == 2, 'data size expected to be 2'
- nat, frac = struct.unpack('BB', data)
- temp = int(nat + (frac / 100))
- return CurrentTemperatureMessage(temp, seq=seq)
-
- def pack_data(self) -> bytes:
- return bytes([self.current_temperature, 0])
-
- def _repr_fields(self) -> ReprDict:
- return {'current_temperature': self.current_temperature}
-
-
-class UnknownMessage(CmdIncomingMessage):
- type: Optional[int]
- data: bytes
-
- def __init__(self, data: bytes, **kwargs):
- super().__init__(**kwargs)
- self.type = None
- self.data = data
-
- @classmethod
- def from_packed_data(cls, data: bytes, seq=0) -> UnknownMessage:
- return UnknownMessage(data, seq=seq)
-
- def set_type(self, type: int):
- self.type = type
-
- def get_type(self) -> int:
- return self.type
-
- def _repr_fields(self) -> ReprDict:
- return {
- 'type': self.type,
- 'data': self.data.hex()
- }
-
-
-class WrappedMessage:
- _message: Message
- _handler: Optional[callable]
- _validator: Optional[callable]
- _logger: Optional[logging.Logger]
- _phase: MessagePhase
- _phase_update_time: float
-
- def __init__(self,
- message: Message,
- handler: Optional[callable] = None,
- validator: Optional[callable] = None,
- ack=False):
- self._message = message
- self._handler = handler
- self._validator = validator
- self._logger = None
- self._phase = MessagePhase.WAITING
- self._phase_update_time = 0
- if not validator and ack:
- self._validator = lambda m: isinstance(m, AckMessage)
-
- def setlogger(self, logger: logging.Logger):
- self._logger = logger
-
- def validate(self, message: Message):
- if not self._validator:
- return True
- return self._validator(message)
-
- def call(self, *args, error_message: str = None) -> None:
- if not self._handler:
- return
- try:
- self._handler(*args)
- except Exception as exc:
- logger = self._logger or logging.getLogger(self.__class__.__name__)
- logger.error(f'{error_message}, see exception below:')
- logger.exception(exc)
-
- @property
- def phase(self) -> MessagePhase:
- return self._phase
-
- @phase.setter
- def phase(self, phase: MessagePhase):
- self._phase = phase
- self._phase_update_time = 0 if phase == MessagePhase.WAITING else time.time()
-
- @property
- def phase_update_time(self) -> float:
- return self._phase_update_time
-
- @property
- def message(self) -> Message:
- return self._message
-
- @property
- def id(self) -> int:
- return self._message.id
-
- @property
- def seq(self) -> int:
- return self._message.seq
-
- @seq.setter
- def seq(self, seq: int):
- self._message.seq = seq
-
- def __repr__(self):
- return f'<{__name__}.{self.__class__.__name__} message={self._message.__repr__()}>'
-
-
-# Connection stuff
-# Well, strictly speaking, as it's UDP, there's no connection, but who cares.
-# ---------------------------------------------------------------------------
-
-class IncomingMessageListener:
- @abstractmethod
- def incoming_message(self, message: Message) -> Optional[Message]:
- pass
-
-
-class ConnectionStatus(Enum):
- NOT_CONNECTED = auto()
- CONNECTING = auto()
- CONNECTED = auto()
- RECONNECTING = auto()
- DISCONNECTED = auto()
-
-
-class ConnectionStatusListener:
- @abstractmethod
- def connection_status_updated(self, status: ConnectionStatus):
- pass
-
-
-class UDPConnection(threading.Thread, ConnectionStatusListener):
- inseq: int
- outseq: int
- source_port: int
- device_addr: str
- device_port: int
- device_token: bytes
- device_pubkey: bytes
- interrupted: bool
- response_handlers: Dict[int, WrappedMessage]
- outgoing_queue: List[WrappedMessage]
- pubkey: Optional[bytes]
- encinkey: Optional[bytes]
- encoutkey: Optional[bytes]
- inc_listeners: List[IncomingMessageListener]
- conn_listeners: List[ConnectionStatusListener]
- outgoing_time: float
- outgoing_time_1st: float
- incoming_time: float
- status: ConnectionStatus
- reconnect_tries: int
- read_timeout: int
-
- _addr_lock: threading.Lock
- _iml_lock: threading.Lock
- _csl_lock: threading.Lock
- _st_lock: threading.Lock
-
- def __init__(self,
- addr: Union[IPv4Address, IPv6Address],
- port: int,
- device_pubkey: bytes,
- device_token: bytes,
- read_timeout: int = 1):
- super().__init__()
- self._logger = logging.getLogger(f'{__name__}.{self.__class__.__name__} <{hex(id(self))}>')
- self.setName(self.__class__.__name__)
-
- self.inseq = 0
- self.outseq = 0
- self.source_port = random.randint(1024, 65535)
- self.device_addr = str(addr)
- self.device_port = port
- self.device_token = device_token
- self.device_pubkey = device_pubkey
- self.outgoing_queue = []
- self.response_handlers = {}
- self.interrupted = False
- self.outgoing_time = 0
- self.outgoing_time_1st = 0
- self.incoming_time = 0
- self.inc_listeners = []
- self.conn_listeners = [self]
- self.status = ConnectionStatus.NOT_CONNECTED
- self.reconnect_tries = 0
- self.read_timeout = read_timeout
-
- self._iml_lock = threading.Lock()
- self._csl_lock = threading.Lock()
- self._addr_lock = threading.Lock()
- self._st_lock = threading.Lock()
-
- self.pubkey = None
- self.encinkey = None
- self.encoutkey = None
-
- def connection_status_updated(self, status: ConnectionStatus):
- # self._logger.info(f'connection_status_updated: status = {status}')
- with self._st_lock:
- # self._logger.debug(f'connection_status_updated: lock acquired')
- self.status = status
- if status == ConnectionStatus.RECONNECTING:
- self.reconnect_tries += 1
- if status in (ConnectionStatus.CONNECTED, ConnectionStatus.NOT_CONNECTED, ConnectionStatus.DISCONNECTED):
- self.reconnect_tries = 0
-
- def _cleanup(self):
- # erase outgoing queue
- for wm in self.outgoing_queue:
- wm.call(False,
- error_message=f'_cleanup: exception while calling cb(False) on message {wm.message}')
- self.outgoing_queue = []
- self.response_handlers = {}
-
- # reset timestamps
- self.incoming_time = 0
- self.outgoing_time = 0
- self.outgoing_time_1st = 0
-
- self._logger.debug('_cleanup: done')
-
- def set_address(self, addr: Union[IPv4Address, IPv6Address], port: int):
- with self._addr_lock:
- if self.device_addr != str(addr) or self.device_port != port:
- self.device_addr = str(addr)
- self.device_port = port
- self._logger.info(f'updated device network address: {self.device_addr}:{self.device_port}')
-
- def set_device_pubkey(self, pubkey: bytes):
- if self.device_pubkey.hex() != pubkey.hex():
- self._logger.info(f'device pubkey has changed (old={self.device_pubkey.hex()}, new={pubkey.hex()})')
- self.device_pubkey = pubkey
- self._notify_cs(ConnectionStatus.RECONNECTING)
-
- def get_address(self) -> Tuple[str, int]:
- with self._addr_lock:
- return self.device_addr, self.device_port
-
- def add_incoming_message_listener(self, listener: IncomingMessageListener):
- with self._iml_lock:
- if listener not in self.inc_listeners:
- self.inc_listeners.append(listener)
-
- def add_connection_status_listener(self, listener: ConnectionStatusListener):
- with self._csl_lock:
- if listener not in self.conn_listeners:
- self.conn_listeners.append(listener)
-
- def _notify_cs(self, status: ConnectionStatus):
- # self._logger.debug(f'_notify_cs: status={status}')
- with self._csl_lock:
- for obj in self.conn_listeners:
- # self._logger.debug(f'_notify_cs: notifying {obj}')
- obj.connection_status_updated(status)
-
- def _prepare_keys(self):
- # generate key pair
- privkey = X25519PrivateKey.generate()
-
- self.pubkey = bytes(reversed(privkey.public_key().public_bytes(encoding=srlz.Encoding.Raw,
- format=srlz.PublicFormat.Raw)))
-
- # generate shared key
- device_pubkey = X25519PublicKey.from_public_bytes(
- bytes(reversed(self.device_pubkey))
- )
- shared_key = bytes(reversed(
- privkey.exchange(device_pubkey)
- ))
-
- # in/out encryption keys
- digest = hashes.Hash(hashes.SHA256())
- digest.update(shared_key)
-
- shared_sha256 = digest.finalize()
-
- self.encinkey = shared_sha256[:16]
- self.encoutkey = shared_sha256[16:]
-
- self._logger.info('encryption keys have been created')
-
- def _handshake_callback(self, r: MessageResponse):
- # if got error for our HandshakeMessage, reset everything and try again
- if r is False:
- # self._logger.debug('_handshake_callback: set status=RECONNETING')
- self._notify_cs(ConnectionStatus.RECONNECTING)
- else:
- # self._logger.debug('_handshake_callback: set status=CONNECTED')
- self._notify_cs(ConnectionStatus.CONNECTED)
-
- def run(self):
- self._logger.info('starting server loop')
-
- sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
- sock.bind(('0.0.0.0', self.source_port))
- sock.settimeout(self.read_timeout)
-
- while not self.interrupted:
- with self._st_lock:
- status = self.status
-
- if status in (ConnectionStatus.DISCONNECTED, ConnectionStatus.RECONNECTING):
- self._cleanup()
- if status == ConnectionStatus.DISCONNECTED:
- break
-
- # no activity for some time means connection is broken
- fail = False
- fail_path = 0
- if self.incoming_time > 0 and time.time() - self.incoming_time >= DISCONNECT_TIMEOUT:
- fail = True
- fail_path = 1
- elif self.outgoing_time_1st > 0 and self.incoming_time == 0 and time.time() - self.outgoing_time_1st >= DISCONNECT_TIMEOUT:
- fail = True
- fail_path = 2
-
- if fail:
- self._logger.debug(f'run: setting status=RECONNECTING because of long inactivity, fail_path={fail_path}')
- self._notify_cs(ConnectionStatus.RECONNECTING)
-
- # establishing a connection
- if status in (ConnectionStatus.RECONNECTING, ConnectionStatus.NOT_CONNECTED):
- if status == ConnectionStatus.RECONNECTING and self.reconnect_tries >= 3:
- self._notify_cs(ConnectionStatus.DISCONNECTED)
- continue
-
- self._reset_outseq()
- self._prepare_keys()
-
- # shake the imaginary kettle's hand
- wrapped = WrappedMessage(HandshakeMessage(),
- handler=self._handshake_callback,
- validator=lambda m: isinstance(m, (AckMessage, HandshakeResponseMessage)))
- self.enqueue_message(wrapped, prepend=True)
- self._notify_cs(ConnectionStatus.CONNECTING)
-
- # pick next (wrapped) message to send
- wm = self._get_next_message() # wm means "wrapped message"
- if wm:
- one_shot = isinstance(wm.message, (AckMessage, NakMessage))
-
- if not isinstance(wm.message, (AckMessage, NakMessage)):
- old_seq = wm.seq
- wm.seq = self.outseq
- self._set_response_handler(wm, old_seq=old_seq)
- elif wm.seq is None:
- # ack/nak is a response to some incoming message (and it must have the same seqno that incoming
- # message had)
- raise RuntimeError(f'run: seq must be set for {wm.__class__.__name__}')
-
- self._logger.debug(f'run: sending message: {wm.message}, one_shot={one_shot}, phase={wm.phase}')
- encrypted = False
- try:
- wm.message.encrypt(outkey=self.encoutkey, inkey=self.encinkey,
- token=self.device_token, pubkey=self.pubkey)
- encrypted = True
- except ValueError as exc:
- # handle "ValueError: Invalid padding bytes."
- self._logger.error('run: failed to encrypt the message.')
- self._logger.exception(exc)
-
- if encrypted:
- buf = wm.message.frame.pack()
- # self._logger.debug(f'run: raw data to be sent: {buf.hex()}')
-
- # sending the first time
- if wm.phase == MessagePhase.WAITING:
- sock.sendto(buf, self.get_address())
- # resending
- elif wm.phase == MessagePhase.SENT:
- left = RESEND_ATTEMPTS
- while left > 0:
- sock.sendto(buf, self.get_address())
- left -= 1
- if left > 0:
- time.sleep(0.05)
-
- if one_shot or wm.phase == MessagePhase.SENT:
- wm.phase = MessagePhase.DONE
- else:
- wm.phase = MessagePhase.SENT
-
- now = time.time()
- self.outgoing_time = now
- if not self.outgoing_time_1st:
- self.outgoing_time_1st = now
-
- # receiving data
- try:
- data = sock.recv(4096)
- self._handle_incoming(data)
- except (TimeoutError, socket.timeout):
- pass
-
- self._logger.info('bye...')
-
- def _get_next_message(self) -> Optional[WrappedMessage]:
- message = None
- lpfx = '_get_next_message:'
- remove_list = []
- for wm in self.outgoing_queue:
- if wm.phase == MessagePhase.DONE:
- if isinstance(wm.message, (AckMessage, NakMessage, PingMessage)) or time.time() - wm.phase_update_time >= MESSAGE_QUEUE_REMOVE_DELAY:
- remove_list.append(wm)
- continue
- message = wm
- break
-
- for wm in remove_list:
- self._logger.debug(f'{lpfx} rm path: removing id={wm.id} seq={wm.seq}')
-
- # clear message handler
- if wm.seq in self.response_handlers:
- self.response_handlers[wm.seq].call(
- False, error_message=f'{lpfx} rm path: error while calling callback for seq={wm.seq}')
- del self.response_handlers[wm.seq]
-
- # remove from queue
- try:
- self.outgoing_queue.remove(wm)
- except ValueError as exc:
- self._logger.error(f'{lpfx} rm path: removing from outgoing_queue raised an exception: {str(exc)}')
-
- # ping pong
- if not message and self.outgoing_time_1st != 0 and self.status == ConnectionStatus.CONNECTED:
- now = time.time()
- out_delta = now - self.outgoing_time
- in_delta = now - self.incoming_time
- if max(out_delta, in_delta) > PING_FREQUENCY:
- self._logger.debug(f'{lpfx} no activity: in for {in_delta:.2f}s, out for {out_delta:.2f}s, time to ping the damn thing')
- message = WrappedMessage(PingMessage(), ack=True)
- # add it to outgoing_queue in order to be aggressively resent in future (if needed)
- self.outgoing_queue.insert(0, message)
-
- return message
-
- def _handle_incoming(self, buf: bytes):
- try:
- incoming_message = Message.from_encrypted(buf, inkey=self.encinkey, outkey=self.encoutkey)
- except ValueError as exc:
- # handle "ValueError: Invalid padding bytes."
- self._logger.error('_handle_incoming: failed to decrypt incoming frame:')
- self._logger.exception(exc)
- return
-
- self.incoming_time = time.time()
- seq = incoming_message.seq
-
- lpfx = f'handle_incoming({incoming_message.id}):'
- self._logger.debug(f'{lpfx} received: {incoming_message}')
-
- if isinstance(incoming_message, (AckMessage, NakMessage)):
- seq_max = self.outseq
- seq_name = 'outseq'
- else:
- seq_max = self.inseq
- seq_name = 'inseq'
- self.inseq = seq
-
- if seq < seq_max < 0xfd:
- self._logger.debug(f'{lpfx} dropping: seq={seq}, {seq_name}={seq_max}')
- return
-
- if seq not in self.response_handlers:
- self._handle_incoming_cmd(incoming_message)
- return
-
- callback_value = None # None means don't call a callback
- handler = self.response_handlers[seq]
-
- if handler.validate(incoming_message):
- self._logger.debug(f'{lpfx} response OK')
- handler.phase = MessagePhase.DONE
- callback_value = incoming_message
- self._incr_outseq()
- else:
- self._logger.warning(f'{lpfx} response is INVALID')
-
- # It seems that we've received an incoming CmdMessage or PingMessage with the same seqno that our outgoing
- # message had. Bad, but what can I say, this is quick-and-dirty made UDP based protocol and this sort of
- # shit just happens.
-
- # (To be fair, maybe my implementation is not perfect either. But hey, what did you expect from a
- # reverse-engineered re-implementation of custom UDP-based protocol that some noname vendor uses for their
- # cheap IoT devices? I think _that_ is _the_ definition of shit. At least my implementation is FOSS, which
- # is more than you'll ever be able to say about them.)
-
- # All this crapload of code below might not be needed at all, 'cause the protocol uses separate frame seq
- # numbers for IN and OUT frames and this situation is not highly likely, as Theresa May could argue.
- # After a handshake, a kettle sends us 10 or so CmdMessages, and then either we continuously ping it every
- # 3 seconds, or kettle pings us. This in any case widens the gap between inseq and outseq.
-
- # But! the seqno is only 1 byte in size and once it reaches 0xff, it circles back to zero. And that (plus,
- # perhaps, some bad luck) gives a chance for a collision.
-
- if handler.phase == MessagePhase.DONE or isinstance(handler.message, HandshakeMessage):
- # no more attempts left, returning error back to user
- # as to handshake, it cannot fail.
- callback_value = False
-
- # else:
- # # try resending the message
- # handler.phase_reset()
- # max_seq = self.outseq
- # wait_remap = {}
- # for m in self.outgoing_queue:
- # if m.seq in self.waiting_for_response:
- # wait_remap[m.seq] = (m.seq+1) % 256
- # m.set_seq((m.seq+1) % 256)
- # if m.seq > max_seq:
- # max_seq = m.seq
- # if max_seq > self.outseq:
- # self.outseq = max_seq % 256
- # if wait_remap:
- # waiting_new = {}
- # for old_seq, new_seq in wait_remap.items():
- # waiting_new[new_seq] = self.waiting_for_response[old_seq]
- # self.waiting_for_response = waiting_new
-
- if isinstance(incoming_message, (PingMessage, CmdIncomingMessage)):
- # handle incoming message as usual, as we need to ack/nak it anyway
- self._handle_incoming_cmd(incoming_message)
-
- if callback_value is not None:
- handler.call(callback_value,
- error_message=f'{lpfx} error while calling callback for msg id={handler.message.id} seq={seq}')
- del self.response_handlers[seq]
-
- def _handle_incoming_cmd(self, incoming_message: Message):
- if isinstance(incoming_message, (AckMessage, NakMessage)):
- self._logger.debug(f'_handle_incoming_cmd({incoming_message.id}, seq={incoming_message.seq}): it\'s {incoming_message.__class__.__name__}, ignoring')
- return
-
- replied = False
- with self._iml_lock:
- for f in self.inc_listeners:
- retval = safe_callback_call(f.incoming_message, incoming_message,
- logger=self._logger,
- error_message=f'_handle_incoming_cmd({incoming_message.id}, seq={incoming_message.seq}): error while calling message listener')
- if isinstance(retval, Message):
- if isinstance(retval, (AckMessage, NakMessage)):
- retval.seq = incoming_message.seq
- self.enqueue_message(WrappedMessage(retval), prepend=True)
- replied = True
- break
- else:
- raise RuntimeError('are you sure your response is correct? only ack/nak are allowed')
-
- if not replied:
- self.enqueue_message(WrappedMessage(AckMessage(incoming_message.seq)), prepend=True)
-
- def enqueue_message(self, wrapped: WrappedMessage, prepend=False):
- self._logger.debug(f'enqueue_message: {wrapped.message}')
- if not prepend:
- self.outgoing_queue.append(wrapped)
- else:
- self.outgoing_queue.insert(0, wrapped)
-
- def _set_response_handler(self, wm: WrappedMessage, old_seq=None):
- if old_seq in self.response_handlers:
- del self.response_handlers[old_seq]
-
- seq = wm.seq
- assert seq is not None, 'seq is not set'
-
- if seq in self.response_handlers:
- self._logger.debug(f'_set_response_handler(seq={seq}): handler is already set, cancelling it')
- self.response_handlers[seq].call(False,
- error_message=f'_set_response_handler({seq}): error while calling old callback')
- self.response_handlers[seq] = wm
-
- def _incr_outseq(self) -> None:
- self.outseq = (self.outseq + 1) % 256
-
- def _reset_outseq(self):
- self.outseq = 0
- self._logger.debug(f'_reset_outseq: set 0')
-
-
-MessageResponse = Union[Message, bool]