summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/polaris/__init__.py5
-rw-r--r--src/polaris/kettle.py497
-rw-r--r--src/polaris/protocol.py412
-rwxr-xr-xsrc/polaris_kettle_util.py60
4 files changed, 503 insertions, 471 deletions
diff --git a/src/polaris/__init__.py b/src/polaris/__init__.py
index dd212a6..aa077ce 100644
--- a/src/polaris/__init__.py
+++ b/src/polaris/__init__.py
@@ -1 +1,4 @@
-from .kettle import Kettle \ No newline at end of file
+# SPDX-License-Identifier: BSD-3-Clause
+
+from .kettle import Kettle
+from .protocol import Message, FrameType, PowerType \ No newline at end of file
diff --git a/src/polaris/kettle.py b/src/polaris/kettle.py
index ea9a1f4..37f6813 100644
--- a/src/polaris/kettle.py
+++ b/src/polaris/kettle.py
@@ -1,391 +1,53 @@
+# SPDX-License-Identifier: BSD-3-Clause
+
from __future__ import annotations
import logging
import zeroconf
-import socket
-import random
-import struct
-
-from enum import Enum
-from ipaddress import ip_address, IPv4Address, IPv6Address
-from typing import Union, Optional, Any
-import cryptography
import cryptography.hazmat.primitives._serialization
-from cryptography.hazmat.primitives.asymmetric.ec import SECP192R1, SECP256R1, SECP384R1
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
-from cryptography.hazmat.primitives import hashes, ciphers, padding
-from cryptography.hazmat.primitives.ciphers import algorithms, modes
+from cryptography.hazmat.primitives import hashes
+
+from functools import partial
+from abc import ABC
+from ipaddress import ip_address
+from typing import Optional
+
+from .protocol import (
+ Connection,
+ ModeMessage,
+ HandshakeMessage,
+ TargetTemperatureMessage,
+ Message,
+ PowerType
+)
_logger = logging.getLogger(__name__)
-PubkeyType = Union[Any, X25519PublicKey, bytes]
-PrivkeyType = Union[Any, X25519PrivateKey, bytes]
-
-
-# com.syncleoiot.iottransport.utils.crypto.EllipticCurveCoder
-class CurveType(Enum):
- secp192r1 = 19
- secp256r1 = 23
- secp384r1 = 24
- x25519 = 29
-
-
-def key_to_bytes(key: Union[str, bytes, X25519PrivateKey, X25519PublicKey], reverse=False) -> bytes:
- val = None
-
- if isinstance(key, str):
- val = bytes.fromhex(key)
-
- if isinstance(key, bytes):
- # logger.warning('key_to_bytes: key is bytes already')
- val = key
-
- raw_kwargs = dict(encoding=cryptography.hazmat.primitives._serialization.Encoding.Raw,
- format=cryptography.hazmat.primitives._serialization.PublicFormat.Raw)
-
- if isinstance(key, X25519PublicKey):
- val = key.public_bytes(**raw_kwargs)
-
- elif isinstance(key, X25519PrivateKey):
- val = key.private_bytes(**raw_kwargs)
-
- assert type(val) is bytes
-
- if reverse:
- val = bytes(reversed(val))
-
- return val
-
-
-def key_to_hex(key: Union[str, bytes, X25519PrivateKey, X25519PublicKey]) -> str:
- return key_to_bytes(key).hex()
-
-
-def arraycopy(src, src_pos, dest, dest_pos, length):
- for i in range(length):
- dest[i + dest_pos] = src[i + src_pos]
-
-
-def pack(fmt, *args):
- # enforce little endian
- return struct.pack(f'<{fmt}', *args)
-
-
-def unpack(fmt, *args):
- # enforce little endian
- return struct.unpack(f'<{fmt}', *args)
-
-
-class FrameType(Enum):
- ACK = 0
- CMD = 1
- AUX = 2
- NAK = 3
-
-
-class FrameHead:
- seq: int # u8
- type: FrameType # u8
- length: int # u16
-
- @staticmethod
- def from_bytes(buf: bytes) -> FrameHead:
- seq, ft, length = unpack('BBH', buf)
- return FrameHead(seq, FrameType(ft), length)
-
- def __init__(self, seq: int, frame_type: FrameType, length: Optional[int] = None):
- self.seq = seq
- self.type = frame_type
- self.length = length or 0
-
- def pack(self) -> bytes:
- assert self.length != 0, "FrameHead.length has not been set"
- return pack('BBH', self.seq, self.type.value, self.length)
-
-
-class FrameItem:
- head: FrameHead
- payload: bytes
-
- def __init__(self, head: FrameHead, payload: Optional[bytes] = None):
- self.head = head
- self.payload = payload
-
- def setpayload(self, payload: Union[bytes, bytearray]):
- if isinstance(payload, bytearray):
- payload = bytes(payload)
- self.payload = payload
- self.head.length = len(payload)
-
- def pack(self) -> bytes:
- ba = bytearray(self.head.pack())
- ba.extend(self.payload)
- return bytes(ba)
-
-
-class Message:
- frame: Optional[FrameItem]
-
- def __init__(self):
- self.frame = None
-
- @staticmethod
- def from_encrypted(buf: bytes, inkey: bytes, outkey: bytes) -> Message:
- _logger.debug('[from_encrypted] buf='+buf.hex())
- # print(f'buf len={len(buf)}')
- assert len(buf) >= 4, 'invalid size'
- head = FrameHead.from_bytes(buf[:4])
-
- assert len(buf) == head.length + 4, f'invalid buf size ({len(buf)} != {head.length})'
-
- payload = buf[4:]
-
- # byte b = paramFrameHead.seq;
- b = head.seq
-
- # TODO check if protocol is 2, otherwise raise an exception
-
- j = b & 0xF
- k = b >> 4 & 0xF
-
- # arrayOfByte1 = this.encryptionInKey;
- key = bytearray(len(inkey))
- arraycopy(inkey, j, key, 0, len(inkey) - j)
- arraycopy(inkey, 0, key, len(inkey) - j, j)
-
- # arrayOfByte1 = this.encryptionOutKey;
- iv = bytearray(len(outkey))
- arraycopy(outkey, k, iv, 0, len(outkey) - k)
- arraycopy(outkey, 0, iv, len(outkey) - k, k)
-
- cipher = ciphers.Cipher(algorithms.AES(key), modes.CBC(iv))
- decryptor = cipher.decryptor()
- decrypted_data = decryptor.update(payload) + decryptor.finalize()
-
- # print(f'head.length={head.length} len(decr)={len(decrypted_data)}')
- if len(decrypted_data) > head.length:
- unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
- decrypted_data = unpadder.update(decrypted_data)
- try:
- decrypted_data += unpadder.finalize()
- except ValueError as exc:
- _logger.exception(exc)
- pass
-
- _logger.debug('decrypted data:', decrypted_data.hex())
-
- assert len(decrypted_data) != 0, 'decrypted data is null'
- assert head.seq == decrypted_data[0], f'decrypted seq mismatch {head.seq} != {decrypted_data[0]}'
-
- if head.type == FrameType.ACK:
- return AckMessage(head.seq)
-
- elif head.type == FrameType.NAK:
- return NakMessage(head.seq)
-
- else:
- cmd = decrypted_data[0]
- data = decrypted_data[2:]
- return CmdMessage(head.seq, cmd, data)
-
- def encrypt(self):
- raise RuntimeError('this method is abstract')
-
- @property
- def data(self) -> bytes:
- raise RuntimeError('this method is abstract')
-
- def _encrypt(self,
- outkey: bytes,
- inkey: bytes,
- token: bytes,
- pubkey: bytes):
-
- assert self.frame is not None
-
- data = self.data
- assert data is not None
-
- # print('data: '+data.hex())
-
- b = self.frame.head.seq
- i = b & 0xf
- j = b >> 4 & 0xf
-
- # byte[] arrayOfByte1 = this.encryptionOutKey;
- outkey = bytearray(outkey)
-
- # arrayOfByte = new byte[arrayOfByte1.length];
- l = len(outkey)
- key = bytearray(l)
-
- # System.arraycopy(arrayOfByte1, i, arrayOfByte, 0, arrayOfByte1.length - i);
- arraycopy(outkey, i, key, 0, l-i)
-
- # arrayOfByte1 = this.encryptionOutKey;
- # System.arraycopy(arrayOfByte1, 0, arrayOfByte, arrayOfByte1.length - i, i);
- arraycopy(outkey, 0, key, l-i, i)
-
- # byte[] arrayOfByte2 = this.encryptionInKey;
- inkey = bytearray(inkey)
-
- # arrayOfByte1 = new byte[arrayOfByte2.length];
- l = len(inkey)
- iv = bytearray(l)
-
- # System.arraycopy(arrayOfByte2, j, arrayOfByte1, 0, arrayOfByte2.length - j);
- arraycopy(inkey, j, iv, 0, l-j)
- # arrayOfByte2 = this.encryptionInKey;
- # System.arraycopy(arrayOfByte2, 0, arrayOfByte1, arrayOfByte2.length - j, j);
- arraycopy(inkey, 0, iv, l-j, j)
-
- # Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
- # SecretKeySpec secretKeySpec = new SecretKeySpec();
- # this(arrayOfByte, "AES");
- # IvParameterSpec ivParameterSpec = new IvParameterSpec();
- # this(arrayOfByte1);
- # cipher.init(1, secretKeySpec, ivParameterSpec);
-
- cipher = ciphers.Cipher(algorithms.AES(key), modes.CBC(iv))
- encryptor = cipher.encryptor()
-
- # arrayOfByte = new byte[paramArrayOfbyte.length + 1];
- # arrayOfByte[0] = b;
- # System.arraycopy(paramArrayOfbyte, 0, arrayOfByte, 1, paramArrayOfbyte.length);
- # data = bytearray(data)
-
- newdata = bytearray(len(data)+1)
- newdata[0] = b
-
- # data = bytearray(len(payload)+1)
- # data[0] = b
- arraycopy(data, 0, newdata, 1, len(data))
-
- newdata = bytes(newdata)
- _logger.debug('payload to be sent:' + newdata.hex())
-
- # arrayOfByte = ByteUtils.concatArrays(cipher.update(arrayOfByte), cipher.doFinal());
- encdata = bytearray()
- padder = padding.PKCS7(algorithms.AES.block_size).padder()
- encdata.extend(encryptor.update(padder.update(newdata) + padder.finalize()))
- encdata.extend(encryptor.finalize())
-
- self.frame.setpayload(encdata)
-
- def construct(self) -> FrameItem:
- raise RuntimeError('this is an abstract method')
-
-
-class AckMessage(Message):
- def __init__(self, seq: int):
- super().__init__()
- self.frame = FrameItem(FrameHead(seq, FrameType.ACK, 0))
-
-
-class NakMessage(Message):
- def __init__(self, seq: int):
- super().__init__()
- self.frame = FrameItem(FrameHead(seq, FrameType.NAK, 0))
-
-
-class CmdMessage(Message):
- cmd: Optional[int]
- cmd_data: Optional[Union[bytes, str]]
-
- def __init__(self,
- seq: Optional[int] = None,
- cmd: Optional[int] = None,
- cmd_data: Optional[bytes] = None):
- super().__init__()
-
- if (seq is not None) and (cmd is not None) and (cmd_data is not None):
- self.frame = FrameItem(FrameHead(seq, FrameType.CMD))
- # self.frame.setpayload(data)
- self.cmd = cmd
- self.cmd_data = cmd_data
- else:
- self.cmd = None
- self.cmd_data = None
-
- @property
- def data(self) -> bytes:
- buf = bytearray()
- buf.append(self.cmd)
- buf.extend(self.cmd_data)
- # print(buf)
- return bytes(buf)
-
-
-class ModeMessage(CmdMessage):
- def __init__(self, seq: int, on: bool):
- super().__init__(seq, 1, b'\x01' if on else b'\x00')
-
-
-class TargetTemperatureMessage(CmdMessage):
- def __init__(self, seq: int, temp: int):
- super().__init__(seq, 2, bytes(bytearray([temp, 0])))
-
-
-class HandshakeMessage(CmdMessage):
- def _encrypt(self,
- outkey: bytes,
- inkey: bytes,
- token: bytes,
- pubkey: bytes):
- cipher = ciphers.Cipher(algorithms.AES(outkey), modes.CBC(inkey))
- encryptor = cipher.encryptor()
-
- encr_data = bytearray()
- encr_data.extend(encryptor.update(token))
- encr_data.extend(encryptor.finalize())
-
- payload = bytearray()
-
- # const/4 v7, 0x0
- # aput-byte v7, v5, v7
- payload.append(0)
-
- payload.extend(pubkey)
- payload.extend(encr_data)
-
- self.frame = FrameItem(FrameHead(0, FrameType.CMD))
- self.frame.setpayload(payload)
# Polaris PWK 1725CGLD IoT kettle
-class Kettle(zeroconf.ServiceListener):
+class Kettle(zeroconf.ServiceListener, ABC):
macaddr: str
- token: str
+ device_token: str
sb: Optional[zeroconf.ServiceBrowser]
found_device: Optional[zeroconf.ServiceInfo]
- privkey: Optional[Union[Any, X25519PrivateKey]]
- pubkey: Optional[bytes]
- sharedkey: Optional[bytes]
- sharedsha256: Optional[bytes]
- encinkey: Optional[bytes]
- encoutkey: Optional[bytes]
- seqno: int
+ conn: Optional[Connection]
- def __init__(self, mac: str, token: str):
+ def __init__(self, mac: str, device_token: str):
super().__init__()
self.zeroconf = zeroconf.Zeroconf()
self.sb = None
self.macaddr = mac
- self.token = token
+ self.device_token = device_token
self.found_device = None
- self.privkey = None
- self.pubkey = None
- self.sharedkey = None
- self.sharedsha256 = None
- self.encinkey = None
- self.encoutkey = None
- self.sourceport = random.randint(1024, 65535)
- self.seqno = 0
+ self.conn = None
- def find(self):
+ def find(self) -> zeroconf.ServiceInfo:
self.sb = zeroconf.ServiceBrowser(self.zeroconf, "_syncleo._udp.local.", self)
self.sb.join()
- # return self.found_device
+
+ return self.found_device
# zeroconf.ServiceListener implementation
def add_service(self,
@@ -401,96 +63,35 @@ class Kettle(zeroconf.ServiceListener):
self.zeroconf.close()
self.found_device = info
- # def update_service(self, zc: Zeroconf, type_: str, name: str) -> None:
- # print(f"Service {name} updated")
- #
- # def remove_service(self, zc: Zeroconf, type_: str, name: str) -> None:
- # print(f"Service {name} removed")
-
- @property
- def device_pubkey(self) -> str:
- return self.found_device.properties[b'public'].decode()
+ assert self.device_curve == 29, f'curve type {self.device_curve} is not implemented'
- @property
- def device_addresses(self) -> list[Union[IPv4Address, IPv6Address]]:
- return list(map(ip_address, self.found_device.addresses))
+ def start_server(self, callback: callable):
+ addresses = list(map(ip_address, self.found_device.addresses))
+ self.conn = Connection(addr=addresses[0],
+ port=int(self.found_device.port),
+ device_pubkey=self.device_pubkey,
+ device_token=bytes.fromhex(self.device_token))
- @property
- def device_port(self) -> int:
- return int(self.found_device.port)
+ # shake the kettle's hand
+ self._pass_message(HandshakeMessage(), callback)
+ self.conn.start()
- # @property
- # def device_pubkey_bytes(self) -> bytes:
- # return bytes.fromhex(self.device_pubkey)
+ def stop_server(self):
+ self.conn.interrupted = True
@property
- def curve_type(self) -> CurveType:
- return CurveType(int(self.found_device.properties[b'curve'].decode()))
-
- def genkeys(self):
- # based on decompiled EllipticCurveCoder.java
-
- if self.curve_type in (CurveType.secp192r1, CurveType.secp256r1, CurveType.secp384r1):
- if self.curve_type == CurveType.secp192r1:
- curve = SECP192R1()
- elif self.curve_type == CurveType.secp256r1:
- curve = SECP256R1()
- elif self.curve_type == CurveType.secp384r1:
- curve = SECP384R1()
- else:
- raise TypeError(f'unexpected curve type: {self.curve_type}')
-
- self.privkey = cryptography.hazmat.primitives.asymmetric.ec.generate_private_key(curve)
-
- elif self.curve_type == CurveType.x25519:
- self.privkey = X25519PrivateKey.generate()
-
- self.pubkey = key_to_bytes(self.privkey.public_key(), reverse=True)
-
- def genshared(self):
- self.sharedkey = bytes(reversed(
- self.privkey.exchange(X25519PublicKey.from_public_bytes(
- key_to_bytes(self.device_pubkey, reverse=True))
- )
- ))
+ def device_pubkey(self) -> bytes:
+ return bytes.fromhex(self.found_device.properties[b'public'].decode())
- digest = hashes.Hash(hashes.SHA256())
- digest.update(self.sharedkey)
- self.sharedsha256 = digest.finalize()
-
- self.encinkey = self.sharedsha256[:16]
- self.encoutkey = self.sharedsha256[16:]
-
- def next_seqno(self) -> int:
- self.seqno += 1
- return self.seqno
-
- def setpower(self, on: bool):
- message = ModeMessage(self.next_seqno(), on)
- print(self.do_send(message))
-
- def settemperature(self, temp: int):
- message = TargetTemperatureMessage(self.next_seqno(), temp)
- print(self.do_send(message))
-
- def handshake(self):
- message = HandshakeMessage()
- response = self.do_send(message)
- assert response.frame.head.type == FrameType.ACK, 'ACK expected'
-
- def do_send(self, message: Message) -> Message:
- message._encrypt(pubkey=self.pubkey,
- outkey=self.encoutkey,
- inkey=self.encinkey,
- token=bytes.fromhex(self.token))
+ @property
+ def device_curve(self) -> int:
+ return int(self.found_device.properties[b'curve'].decode())
- dst_addr = str(self.device_addresses[0])
- dst_port = self.device_port
+ def set_power(self, power_type: PowerType, callback: callable):
+ self._pass_message(ModeMessage(power_type), callback)
- sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
- sock.bind(('0.0.0.0', self.sourceport))
- sock.sendto(message.frame.pack(), (dst_addr, dst_port))
- _logger.debug('data has been sent, waiting for incoming data....')
+ def set_target_temperature(self, temp: int, callback: callable):
+ self._pass_message(TargetTemperatureMessage(temp), callback)
- data = sock.recv(4096)
- return Message.from_encrypted(data, inkey=self.encinkey, outkey=self.encoutkey) \ No newline at end of file
+ def _pass_message(self, message: Message, callback: callable):
+ self.conn.send_message(message, partial(callback, self))
diff --git a/src/polaris/protocol.py b/src/polaris/protocol.py
new file mode 100644
index 0000000..cc4e36a
--- /dev/null
+++ b/src/polaris/protocol.py
@@ -0,0 +1,412 @@
+# SPDX-License-Identifier: BSD-3-Clause
+
+from __future__ import annotations
+import logging
+import socket
+import random
+import struct
+import threading
+import queue
+
+from enum import Enum
+from typing import Union, Optional, Any
+from ipaddress import IPv4Address, IPv6Address
+
+import cryptography.hazmat.primitives._serialization
+
+from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
+from cryptography.hazmat.primitives import ciphers, padding, hashes
+from cryptography.hazmat.primitives.ciphers import algorithms, modes
+
+
+_logger = logging.getLogger(__name__)
+
+
+# drop-in replacement for Java API
+# TODO: rewrite
+def arraycopy(src, src_pos, dest, dest_pos, length):
+ for i in range(length):
+ dest[i + dest_pos] = src[i + src_pos]
+
+
+class FrameType(Enum):
+ ACK = 0
+ CMD = 1
+ AUX = 2
+ NAK = 3
+
+
+class PowerType(Enum):
+ OFF = 0 # turn off
+ ON = 1 # turn on, set target temperature to 100
+ CUSTOM = 3 # turn on, allows custom target temperature
+ # MYSTERY_MODE = 2 # don't know what 2 means, needs testing
+ # update: if I set it to '2', it just resets to '0'
+
+
+class FrameHead:
+ seq: int # u8
+ type: FrameType # u8
+ length: int # u16
+
+ @staticmethod
+ def from_bytes(buf: bytes) -> FrameHead:
+ seq, ft, length = struct.unpack('<BBH', buf)
+ return FrameHead(seq, FrameType(ft), length)
+
+ def __init__(self, seq: int, frame_type: FrameType, length: Optional[int] = None):
+ self.seq = seq
+ self.type = frame_type
+ self.length = length or 0
+
+ def pack(self) -> bytes:
+ assert self.length != 0, "FrameHead.length has not been set"
+ return struct.pack('<BBH', self.seq, self.type.value, self.length)
+
+
+class FrameItem:
+ head: FrameHead
+ payload: bytes
+
+ def __init__(self, head: FrameHead, payload: Optional[bytes] = None):
+ self.head = head
+ self.payload = payload
+
+ def setpayload(self, payload: Union[bytes, bytearray]):
+ if isinstance(payload, bytearray):
+ payload = bytes(payload)
+ self.payload = payload
+ self.head.length = len(payload)
+
+ def pack(self) -> bytes:
+ ba = bytearray(self.head.pack())
+ ba.extend(self.payload)
+ return bytes(ba)
+
+
+class Message:
+ frame: Optional[FrameItem]
+
+ def __init__(self):
+ self.frame = None
+
+ @staticmethod
+ def from_encrypted(buf: bytes,
+ inkey: bytes,
+ outkey: bytes) -> Message:
+ # _logger.debug('[from_encrypted] buf='+buf.hex())
+ # print(f'buf len={len(buf)}')
+ assert len(buf) >= 4, 'invalid size'
+ head = FrameHead.from_bytes(buf[:4])
+
+ assert len(buf) == head.length + 4, f'invalid buf size ({len(buf)} != {head.length})'
+
+ payload = buf[4:]
+
+ # byte b = paramFrameHead.seq;
+ b = head.seq
+
+ # TODO check if protocol is 2, otherwise raise an exception
+
+ j = b & 0xF
+ k = b >> 4 & 0xF
+
+ key = bytearray(len(inkey))
+ arraycopy(inkey, j, key, 0, len(inkey) - j)
+ arraycopy(inkey, 0, key, len(inkey) - j, j)
+
+ iv = bytearray(len(outkey))
+ arraycopy(outkey, k, iv, 0, len(outkey) - k)
+ arraycopy(outkey, 0, iv, len(outkey) - k, k)
+
+ cipher = ciphers.Cipher(algorithms.AES(key), modes.CBC(iv))
+ decryptor = cipher.decryptor()
+ decrypted_data = decryptor.update(payload) + decryptor.finalize()
+
+ # print(f'head.length={head.length} len(decr)={len(decrypted_data)}')
+ # if len(decrypted_data) > head.length:
+ unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
+ decrypted_data = unpadder.update(decrypted_data)
+ # try:
+ decrypted_data += unpadder.finalize()
+ # except ValueError as exc:
+ # _logger.exception(exc)
+ # pass
+
+ assert len(decrypted_data) != 0, 'decrypted data is null'
+ assert head.seq == decrypted_data[0], f'decrypted seq mismatch {head.seq} != {decrypted_data[0]}'
+
+ # _logger.debug('Message.from_encrypted: plaintext: '+decrypted_data.hex())
+
+ if head.type == FrameType.ACK:
+ return AckMessage(head.seq)
+
+ elif head.type == FrameType.NAK:
+ return NakMessage(head.seq)
+
+ elif head.type == FrameType.AUX:
+ raise NotImplementedError('FrameType AUX is not yet implemented')
+
+ elif head.type == FrameType.CMD:
+ cmd = decrypted_data[0]
+ data = decrypted_data[2:]
+ return CmdMessage(head.seq, cmd, data)
+
+ else:
+ raise NotImplementedError(f'Unexpected frame type: {head.type}')
+
+ @property
+ def data(self) -> bytes:
+ return b''
+
+ def encrypt(self,
+ outkey: bytes,
+ inkey: bytes,
+ token: bytes,
+ pubkey: bytes):
+
+ assert self.frame is not None
+
+ data = self.data
+ assert data is not None
+
+ b = self.frame.head.seq
+ i = b & 0xf
+ j = b >> 4 & 0xf
+
+ outkey = bytearray(outkey)
+
+ l = len(outkey)
+ key = bytearray(l)
+
+ arraycopy(outkey, i, key, 0, l-i)
+ arraycopy(outkey, 0, key, l-i, i)
+
+ inkey = bytearray(inkey)
+
+ l = len(inkey)
+ iv = bytearray(l)
+
+ arraycopy(inkey, j, iv, 0, l-j)
+ arraycopy(inkey, 0, iv, l-j, j)
+
+ cipher = ciphers.Cipher(algorithms.AES(key), modes.CBC(iv))
+ encryptor = cipher.encryptor()
+
+ newdata = bytearray(len(data)+1)
+ newdata[0] = b
+
+ arraycopy(data, 0, newdata, 1, len(data))
+
+ newdata = bytes(newdata)
+ _logger.debug('payload to be sent: ' + newdata.hex())
+
+ padder = padding.PKCS7(algorithms.AES.block_size).padder()
+ ciphertext = bytearray()
+ ciphertext.extend(encryptor.update(padder.update(newdata) + padder.finalize()))
+ ciphertext.extend(encryptor.finalize())
+
+ self.frame.setpayload(ciphertext)
+
+ def set_seq(self, seq: int):
+ self.frame.head.seq = seq
+
+ def __repr__(self):
+ return f'<{self.__class__.__name__} seq={self.frame.head.seq}>'
+
+
+class AckMessage(Message):
+ def __init__(self, seq: int = 0):
+ super().__init__()
+ self.frame = FrameItem(FrameHead(seq, FrameType.ACK, 0))
+
+
+class NakMessage(Message):
+ def __init__(self, seq: int = 0):
+ super().__init__()
+ self.frame = FrameItem(FrameHead(seq, FrameType.NAK, 0))
+
+
+class CmdMessage(Message):
+ _type: Optional[int]
+ _data: bytes
+
+ def __init__(self, seq=0,
+ type: Optional[int] = None,
+ data: bytes = b''):
+ super().__init__()
+ self._data = data
+ if type is not None:
+ self.frame = FrameItem(FrameHead(seq, FrameType.CMD))
+ self._type = type
+ else:
+ self._type = None
+
+ @property
+ def data(self) -> bytes:
+ buf = bytearray()
+ buf.append(self._type)
+ buf.extend(self._data)
+ return bytes(buf)
+
+ def __repr__(self):
+ params = [
+ __name__+'.'+self.__class__.__name__,
+ f'seq={self.frame.head.seq}',
+ # f'type={self.frame.head.type}',
+ f'cmd={self._type}'
+ ]
+ if self._data:
+ params.append(f'data={self._data.hex()}')
+ return '<'+' '.join(params)+'>'
+
+
+class ModeMessage(CmdMessage):
+ def __init__(self, power_type: PowerType):
+ super().__init__(type=1,
+ data=(power_type.value).to_bytes(1, byteorder='little'))
+
+
+class TargetTemperatureMessage(CmdMessage):
+ def __init__(self, temp: int):
+ super().__init__(type=2,
+ data=bytes(bytearray([temp, 0])))
+
+
+class HandshakeMessage(CmdMessage):
+ def __init__(self):
+ super().__init__(type=0)
+
+ def encrypt(self,
+ outkey: bytes,
+ inkey: bytes,
+ token: bytes,
+ pubkey: bytes):
+ cipher = ciphers.Cipher(algorithms.AES(outkey), modes.CBC(inkey))
+ encryptor = cipher.encryptor()
+
+ ciphertext = bytearray()
+ ciphertext.extend(encryptor.update(token))
+ ciphertext.extend(encryptor.finalize())
+
+ pld = bytearray()
+ pld.append(0)
+ pld.extend(pubkey)
+ pld.extend(ciphertext)
+
+ self.frame.setpayload(pld)
+
+
+# TODO
+# implement resending UDP messages if no answer has been received in a second
+# try at least 5 times, then give up
+class Connection(threading.Thread):
+ seq_no: int
+ source_port: int
+ device_addr: str
+ device_port: int
+ device_token: bytes
+ interrupted: bool
+ waiting_for_response: dict[int, callable]
+ pubkey: Optional[bytes]
+ encinkey: Optional[bytes]
+ encoutkey: Optional[bytes]
+
+ def __init__(self,
+ addr: Union[IPv4Address, IPv6Address],
+ port: int,
+ device_pubkey: bytes,
+ device_token: bytes):
+ super().__init__()
+ self.logger = logging.getLogger(__name__+'.'+self.__class__.__name__)
+ self.setName(self.__class__.__name__)
+ # self.daemon = True
+
+ self.seq_no = -1
+ self.source_port = random.randint(1024, 65535)
+ self.device_addr = str(addr)
+ self.device_port = port
+ self.device_token = device_token
+ self.lock = threading.Lock()
+ self.outgoing_queue = queue.SimpleQueue()
+ self.waiting_for_response = {}
+ self.interrupted = False
+
+ self.pubkey = None
+ self.encinkey = None
+ self.encoutkey = None
+
+ self.prepare_keys(device_pubkey)
+
+ def prepare_keys(self, device_pubkey: bytes):
+ # generate key pair
+ privkey = X25519PrivateKey.generate()
+
+ self.pubkey = bytes(reversed(privkey.public_key().public_bytes(encoding=cryptography.hazmat.primitives._serialization.Encoding.Raw,
+ format=cryptography.hazmat.primitives._serialization.PublicFormat.Raw)))
+
+ # generate shared key
+ device_pubkey = X25519PublicKey.from_public_bytes(
+ bytes(reversed(device_pubkey))
+ )
+ shared_key = bytes(reversed(
+ privkey.exchange(device_pubkey)
+ ))
+
+ # in/out encryption keys
+ digest = hashes.Hash(hashes.SHA256())
+ digest.update(shared_key)
+
+ shared_sha256 = digest.finalize()
+
+ self.encinkey = shared_sha256[:16]
+ self.encoutkey = shared_sha256[16:]
+
+ def run(self):
+ sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ sock.bind(('0.0.0.0', self.source_port))
+ sock.settimeout(1)
+
+ while not self.interrupted:
+ if not self.outgoing_queue.empty():
+ message = self.outgoing_queue.get()
+ message.encrypt(outkey=self.encoutkey,
+ inkey=self.encinkey,
+ token=self.device_token,
+ pubkey=self.pubkey)
+ buf = message.frame.pack()
+ self.logger.debug('send: '+buf.hex())
+ self.logger.debug(f'sendto: {self.device_addr}:{self.device_port}')
+ sock.sendto(buf, (self.device_addr, self.device_port))
+ try:
+ data = sock.recv(4096)
+ self.handle_incoming(data)
+ except TimeoutError:
+ pass
+
+ def handle_incoming(self, buf: bytes):
+ self.logger.debug('handle_incoming: '+buf.hex())
+ message = Message.from_encrypted(buf, inkey=self.encinkey, outkey=self.encoutkey)
+ if message.frame.head.seq in self.waiting_for_response:
+ self.logger.info(f'received awaited message: {message}')
+ try:
+ f = self.waiting_for_response[message.frame.head.seq]
+ f(message)
+ except Exception as exc:
+ self.logger.exception(exc)
+ finally:
+ del self.waiting_for_response[message.frame.head.seq]
+ else:
+ self.logger.info(f'received message (not awaited): {message}')
+
+ def send_message(self, message: Message, callback: callable):
+ seq = self.next_seqno()
+ message.set_seq(seq)
+ self.outgoing_queue.put(message)
+ self.waiting_for_response[seq] = callback
+
+ def next_seqno(self) -> int:
+ with self.lock:
+ self.seq_no += 1
+ self.logger.debug(f'next_seqno: set to {self.seq_no}')
+ return self.seq_no
diff --git a/src/polaris_kettle_util.py b/src/polaris_kettle_util.py
index 7f5c7c2..419739b 100755
--- a/src/polaris_kettle_util.py
+++ b/src/polaris_kettle_util.py
@@ -1,18 +1,21 @@
#!/usr/bin/env python3
+# SPDX-License-Identifier: BSD-3-Clause
+
import logging
-# import os
import sys
+import time
import paho.mqtt.client as mqtt
# from datetime import datetime
# from html import escape
from argparse import ArgumentParser
+from queue import SimpleQueue
# from home.bot import Wrapper, Context
# from home.api.types import BotType
# from home.util import parse_addr
from home.mqtt import MQTTBase
from home.config import config
-from polaris import Kettle
+from polaris import Kettle, Message, FrameType, PowerType
# from telegram.error import TelegramError
# from telegram import ReplyKeyboardMarkup, InlineKeyboardMarkup, InlineKeyboardButton
@@ -24,6 +27,7 @@ from polaris import Kettle
logger = logging.getLogger(__name__)
+control_tasks = SimpleQueue()
# bot: Optional[Wrapper] = None
# RenderedContent = tuple[str, Optional[InlineKeyboardMarkup]]
@@ -88,6 +92,24 @@ class MQTTServer(MQTTBase):
# ]
# return ReplyKeyboardMarkup(buttons, one_time_keyboard=False)
+def kettle_connection_established(k: Kettle, response: Message):
+ try:
+ assert response.frame.head.type == FrameType.ACK, f'ACK expected, but received: {response}'
+ except AssertionError:
+ k.stop_server()
+ return
+
+ def next_task(k, response):
+ if not control_tasks.empty():
+ task = control_tasks.get()
+ f, args = task(k)
+ args.append(next_task)
+ f(*args)
+ else:
+ k.stop_server()
+
+ next_task(k, response)
+
def main():
tempmin = 30
@@ -98,7 +120,8 @@ def main():
parser.add_argument('-m', dest='mode', required=True, type=str, choices=('mqtt', 'control'))
parser.add_argument('--on', action='store_true')
parser.add_argument('--off', action='store_true')
- parser.add_argument('-t', '--temperature', dest='temp', type=int, choices=range(tempmin, tempmax+tempstep, tempstep))
+ parser.add_argument('-t', '--temperature', dest='temp', type=int, default=tempmax,
+ choices=range(tempmin, tempmax+tempstep, tempstep))
arg = config.load('polaris_kettle_bot', use_cli=True, parser=parser)
@@ -113,27 +136,20 @@ def main():
if arg.on and arg.off:
raise RuntimeError('--on and --off are mutually exclusive')
- k = Kettle(mac='40f52018dec1', token='3a5865f015950cae82cd120e76a80d28')
- k.find()
- print('device found')
-
- k.genkeys()
- k.genshared()
- k.handshake()
-
- if arg.on:
- k.setpower(True)
- elif arg.off:
- k.setpower(False)
- elif arg.temp:
- k.settemperature(arg.temp)
-
- # k.sendfirst()
+ if arg.off:
+ control_tasks.put(lambda k: (k.set_power, [PowerType.OFF]))
+ else:
+ if arg.temp == tempmax:
+ control_tasks.put(lambda k: (k.set_power, [PowerType.ON]))
+ else:
+ control_tasks.put(lambda k: (k.set_target_temperature, [arg.temp]))
+ control_tasks.put(lambda k: (k.set_power, [PowerType.CUSTOM]))
- # print('shared key:', key_to_hex(k.sharedkey))
- # print('shared hash:', key_to_hex(k.sharedsha256))
+ k = Kettle(mac='40f52018dec1', device_token='3a5865f015950cae82cd120e76a80d28')
+ info = k.find()
+ print('found service:', info)
- # print(len(k.sharedsha256))
+ k.start_server(kettle_connection_established)
return 0