diff options
author | Evgeny Zinoviev <me@ch1p.io> | 2022-06-28 03:22:30 +0300 |
---|---|---|
committer | Evgeny Zinoviev <me@ch1p.io> | 2022-06-30 03:47:49 +0300 |
commit | 8f20c9b825cabab7a3f0f5dd2cfe000cc7f72c28 (patch) | |
tree | b5d7446e7b2fcfd42b1e5029aeef33ecb5f9715f /src | |
parent | ee09bc98aedfc6a65a5026432b399345a30a39c8 (diff) |
polaris pwk 1725cgld full support
- significant improvements, correctnesses and stability fixes in
protocol implementation
- correct handling of device appearances and disappearances
- flawlessly functioning telegram bot that re-renders kettle's state
(temperature and other) in real time
Diffstat (limited to 'src')
-rw-r--r-- | src/home/api/errors/api_response_error.py | 4 | ||||
-rw-r--r-- | src/home/api/types/types.py | 1 | ||||
-rw-r--r-- | src/home/api/web_api_client.py | 20 | ||||
-rw-r--r-- | src/home/audio/amixer.py | 4 | ||||
-rw-r--r-- | src/home/bot/__init__.py | 2 | ||||
-rw-r--r-- | src/home/bot/lang.py | 13 | ||||
-rw-r--r-- | src/home/bot/wrapper.py | 34 | ||||
-rw-r--r-- | src/home/camera/util.py | 3 | ||||
-rw-r--r-- | src/home/database/bots.py | 10 | ||||
-rw-r--r-- | src/home/media/node_client.py | 6 | ||||
-rw-r--r-- | src/home/media/record.py | 6 | ||||
-rw-r--r-- | src/home/media/record_client.py | 18 | ||||
-rw-r--r-- | src/home/media/storage.py | 4 | ||||
-rw-r--r-- | src/home/telegram/telegram.py | 3 | ||||
-rw-r--r-- | src/home/util.py | 6 | ||||
-rwxr-xr-x | src/ipcam_server.py | 8 | ||||
-rwxr-xr-x | src/openwrt_logger.py | 5 | ||||
-rw-r--r-- | src/polaris/__init__.py | 14 | ||||
-rw-r--r-- | src/polaris/kettle.py | 271 | ||||
-rw-r--r-- | src/polaris/protocol.py | 1015 | ||||
-rw-r--r-- | src/polaris_kettle_bot.py | 684 | ||||
-rwxr-xr-x | src/polaris_kettle_util.py | 104 | ||||
-rwxr-xr-x | src/sound_bot.py | 17 | ||||
-rwxr-xr-x | src/sound_sensor_server.py | 8 |
24 files changed, 1907 insertions, 353 deletions
diff --git a/src/home/api/errors/api_response_error.py b/src/home/api/errors/api_response_error.py index 6910b2d..85d788b 100644 --- a/src/home/api/errors/api_response_error.py +++ b/src/home/api/errors/api_response_error.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, List class ApiResponseError(Exception): @@ -6,7 +6,7 @@ class ApiResponseError(Exception): status_code: int, error_type: str, error_message: str, - error_stacktrace: Optional[list[str]] = None): + error_stacktrace: Optional[List[str]] = None): super().__init__() self.status_code = status_code self.error_message = error_message diff --git a/src/home/api/types/types.py b/src/home/api/types/types.py index b6233e6..4d8b4ff 100644 --- a/src/home/api/types/types.py +++ b/src/home/api/types/types.py @@ -7,6 +7,7 @@ class BotType(Enum): SENSORS = auto() ADMIN = auto() SOUND = auto() + POLARIS_KETTLE = auto() class TemperatureSensorLocation(Enum): diff --git a/src/home/api/web_api_client.py b/src/home/api/web_api_client.py index 34d080c..d6c9dc7 100644 --- a/src/home/api/web_api_client.py +++ b/src/home/api/web_api_client.py @@ -6,7 +6,7 @@ import logging from collections import namedtuple from datetime import datetime from enum import Enum, auto -from typing import Optional, Callable, Union +from typing import Optional, Callable, Union, List, Tuple, Dict from requests.auth import HTTPBasicAuth from .errors import ApiResponseError @@ -28,13 +28,13 @@ class HTTPMethod(Enum): class WebAPIClient: token: str - timeout: Union[float, tuple[float, float]] + timeout: Union[float, Tuple[float, float]] basic_auth: Optional[HTTPBasicAuth] do_async: bool async_error_handler: Optional[Callable] async_success_handler: Optional[Callable] - def __init__(self, timeout: Union[float, tuple[float, float]] = 5): + def __init__(self, timeout: Union[float, Tuple[float, float]] = 5): self.token = config['api']['token'] self.timeout = timeout self.basic_auth = None @@ -66,7 +66,7 @@ class WebAPIClient: }) def log_openwrt(self, - lines: list[tuple[int, str]]): + lines: List[Tuple[int, str]]): return self._post('logs/openwrt', { 'logs': stringify(lines) }) @@ -81,14 +81,14 @@ class WebAPIClient: return [(datetime.fromtimestamp(date), temp, hum) for date, temp, hum in data] def add_sound_sensor_hits(self, - hits: list[tuple[str, int]]): + hits: List[Tuple[str, int]]): return self._post('sound_sensors/hits/', { 'hits': stringify(hits) }) def get_sound_sensor_hits(self, location: SoundSensorLocation, - after: datetime) -> list[dict]: + after: datetime) -> List[dict]: return self._process_sound_sensor_hits_data(self._get('sound_sensors/hits/', { 'after': int(after.timestamp()), 'location': location.value @@ -100,13 +100,13 @@ class WebAPIClient: 'location': location.value })) - def recordings_list(self, extended=False, as_objects=False) -> Union[list[str], list[dict], list[RecordFile]]: + def recordings_list(self, extended=False, as_objects=False) -> Union[List[str], List[dict], List[RecordFile]]: files = self._get('recordings/list/', {'extended': int(extended)})['data'] if as_objects: return MediaNodeClient.record_list_from_serialized(files) return files - def _process_sound_sensor_hits_data(self, data: list[dict]) -> list[dict]: + def _process_sound_sensor_hits_data(self, data: List[dict]) -> List[dict]: for item in data: item['time'] = datetime.fromtimestamp(item['time']) return data @@ -124,7 +124,7 @@ class WebAPIClient: name: str, params: dict, method: HTTPMethod, - files: Optional[dict[str, str]] = None): + files: Optional[Dict[str, str]] = None): if not self.do_async: return self._make_request(name, params, method, files) else: @@ -136,7 +136,7 @@ class WebAPIClient: name: str, params: dict, method: HTTPMethod = HTTPMethod.GET, - files: Optional[dict[str, str]] = None) -> Optional[any]: + files: Optional[Dict[str, str]] = None) -> Optional[any]: domain = config['api']['host'] kwargs = {} diff --git a/src/home/audio/amixer.py b/src/home/audio/amixer.py index 0ab2c64..53e6bce 100644 --- a/src/home/audio/amixer.py +++ b/src/home/audio/amixer.py @@ -2,7 +2,7 @@ import subprocess from ..config import config from threading import Lock -from typing import Union +from typing import Union, List _lock = Lock() @@ -16,7 +16,7 @@ def has_control(s: str) -> bool: return False -def get_caps(s: str) -> list[str]: +def get_caps(s: str) -> List[str]: for control in config['amixer']['controls']: if control['name'] == s: return control['caps'] diff --git a/src/home/bot/__init__.py b/src/home/bot/__init__.py index 5e68af7..0d93af3 100644 --- a/src/home/bot/__init__.py +++ b/src/home/bot/__init__.py @@ -1,6 +1,6 @@ from .reporting import ReportingHelper from .lang import LangPack -from .wrapper import Wrapper, Context, text_filter +from .wrapper import Wrapper, Context, text_filter, handlermethod from .store import Store from .errors import * from .util import command_usage, user_any_name
\ No newline at end of file diff --git a/src/home/bot/lang.py b/src/home/bot/lang.py index 2f10358..624c748 100644 --- a/src/home/bot/lang.py +++ b/src/home/bot/lang.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import logging -from typing import Union, Optional +from typing import Union, Optional, List, Dict logger = logging.getLogger(__name__) @@ -24,7 +26,7 @@ class LangStrings(dict): class LangPack: - strings: dict[str, LangStrings[str, str]] + strings: Dict[str, LangStrings[str, str]] default_lang: str def __init__(self): @@ -57,11 +59,14 @@ class LangPack: return result @property - def languages(self) -> list[str]: + def languages(self) -> List[str]: return list(self.strings.keys()) def get(self, key: str, lang: str, *args) -> str: - return self.strings[lang][key] % args + if args: + return self.strings[lang][key] % args + else: + return self.strings[lang][key] def __call__(self, *args, **kwargs): return self.strings[self.default_lang][args[0]] diff --git a/src/home/bot/wrapper.py b/src/home/bot/wrapper.py index 8ebde4f..5f399ce 100644 --- a/src/home/bot/wrapper.py +++ b/src/home/bot/wrapper.py @@ -8,6 +8,7 @@ from telegram import ( ReplyKeyboardMarkup, CallbackQuery, User, + Message, ) from telegram.ext import ( Updater, @@ -22,7 +23,7 @@ from telegram.ext import ( ) from telegram.error import TimedOut from ..config import config -from typing import Optional, Union +from typing import Optional, Union, List, Tuple from .store import Store from .lang import LangPack from ..api.types import BotType @@ -110,7 +111,7 @@ class Context: kwargs = dict(parse_mode=ParseMode.HTML) if not isinstance(markup, IgnoreMarkup): kwargs['reply_markup'] = markup - self._update.message.reply_text(text, **kwargs) + return self._update.message.reply_text(text, **kwargs) def reply_exc(self, e: Exception) -> None: self.reply(exc2text(e)) @@ -133,7 +134,7 @@ class Context: return self._update.callback_query @property - def args(self) -> Optional[list[str]]: + def args(self) -> Optional[List[str]]: return self._callback_context.args @property @@ -157,6 +158,25 @@ class Context: return self._update.callback_query and self._update.callback_query.data and self._update.callback_query.data != '' +def handlermethod(f: callable): + def _handler(self, update: Update, context: CallbackContext, *args, **kwargs): + ctx = Context(update, + callback_context=context, + markup_getter=self.markup, + lang=self.lang, + store=self.store) + try: + return f(self, ctx, *args, **kwargs) + except Exception as e: + if not self.exception_handler(e, ctx) and not isinstance(e, TimedOut): + logger.exception(e) + if not ctx.is_callback_context(): + ctx.reply_exc(e) + else: + self.notify_user(ctx.user_id, exc2text(e)) + return _handler + + class Wrapper: store: Optional[Store] updater: Updater @@ -252,7 +272,7 @@ class Wrapper: def exception_handler(self, e: Exception, ctx: Context) -> Optional[bool]: pass - def notify_all(self, text_getter: callable, exclude: tuple[int] = ()) -> None: + def notify_all(self, text_getter: callable, exclude: Tuple[int] = ()) -> None: if 'notify_users' not in config['bot']: logger.error('notify_all() called but no notify_users directive found in the config') return @@ -280,6 +300,12 @@ class Wrapper: def send_file(self, user_id, **kwargs): self.updater.bot.send_document(chat_id=user_id, **kwargs) + def edit_message_text(self, user_id, message_id, *args, **kwargs): + self.updater.bot.edit_message_text(chat_id=user_id, message_id=message_id, parse_mode='HTML', *args, **kwargs) + + def delete_message(self, user_id, message_id): + self.updater.bot.delete_message(chat_id=user_id, message_id=message_id) + # # Language Selection # diff --git a/src/home/camera/util.py b/src/home/camera/util.py index 39bfcd3..97f35aa 100644 --- a/src/home/camera/util.py +++ b/src/home/camera/util.py @@ -3,6 +3,7 @@ import os.path import logging import psutil +from typing import List, Tuple from ..util import chunks from ..config import config @@ -62,7 +63,7 @@ async def ffmpeg_cut(input: str, _logger.info(f'ffmpeg_cut({input}): OK') -def dvr_scan_timecodes(timecodes: str) -> list[tuple[int, int]]: +def dvr_scan_timecodes(timecodes: str) -> List[Tuple[int, int]]: tc_backup = timecodes timecodes = timecodes.split(',') diff --git a/src/home/database/bots.py b/src/home/database/bots.py index bc490e1..99befc0 100644 --- a/src/home/database/bots.py +++ b/src/home/database/bots.py @@ -5,7 +5,7 @@ from ..api.types import ( BotType, SoundSensorLocation ) -from typing import Optional +from typing import Optional, List, Tuple from datetime import datetime from html import escape @@ -37,7 +37,7 @@ class BotsDatabase(MySQLDatabase): self.commit() def add_openwrt_logs(self, - lines: list[tuple[datetime, str]]): + lines: List[Tuple[datetime, str]]): now = datetime.now() with self.cursor() as cursor: for line in lines: @@ -47,7 +47,7 @@ class BotsDatabase(MySQLDatabase): self.commit() def add_sound_hits(self, - hits: list[tuple[SoundSensorLocation, int]], + hits: List[Tuple[SoundSensorLocation, int]], time: datetime): with self.cursor() as cursor: for loc, count in hits: @@ -58,7 +58,7 @@ class BotsDatabase(MySQLDatabase): def get_sound_hits(self, location: SoundSensorLocation, after: Optional[datetime] = None, - last: Optional[int] = None) -> list[dict]: + last: Optional[int] = None) -> List[dict]: with self.cursor(dictionary=True) as cursor: sql = "SELECT `time`, hits FROM sound_hits WHERE location=%s" args = [location.name.lower()] @@ -84,7 +84,7 @@ class BotsDatabase(MySQLDatabase): def get_openwrt_logs(self, filter_text: str, min_id: int, - limit: int = None) -> list[OpenwrtLogRecord]: + limit: int = None) -> List[OpenwrtLogRecord]: tz = pytz.timezone('Europe/Moscow') with self.cursor(dictionary=True) as cursor: sql = "SELECT * FROM openwrt WHERE text LIKE %s AND id > %s" diff --git a/src/home/media/node_client.py b/src/home/media/node_client.py index 4430962..eb39898 100644 --- a/src/home/media/node_client.py +++ b/src/home/media/node_client.py @@ -2,7 +2,7 @@ import requests import shutil import logging -from typing import Optional, Union +from typing import Optional, Union, List from .storage import RecordFile from ..util import Addr from ..api.errors import ApiResponseError @@ -25,7 +25,7 @@ class MediaNodeClient: def record_download(self, record_id: int, output: str): return self._call(f'record/download/{record_id}/', save_to=output) - def storage_list(self, extended=False, as_objects=False) -> Union[list[str], list[dict], list[RecordFile]]: + def storage_list(self, extended=False, as_objects=False) -> Union[List[str], List[dict], List[RecordFile]]: r = self._call('storage/list/', params={'extended': int(extended)}) files = r['files'] if as_objects: @@ -33,7 +33,7 @@ class MediaNodeClient: return files @staticmethod - def record_list_from_serialized(files: Union[list[str], list[dict]]): + def record_list_from_serialized(files: Union[List[str], List[dict]]): new_files = [] for f in files: kwargs = {'remote': True} diff --git a/src/home/media/record.py b/src/home/media/record.py index fdb8382..cd7447a 100644 --- a/src/home/media/record.py +++ b/src/home/media/record.py @@ -5,7 +5,7 @@ import time import subprocess import signal -from typing import Optional +from typing import Optional, List, Dict from ..util import find_child_processes, Addr from ..config import config from .storage import RecordFile, RecordStorage @@ -22,7 +22,7 @@ class RecordHistoryItem: request_time: float start_time: float stop_time: float - relations: list[int] + relations: List[int] status: RecordStatus error: Optional[Exception] file: Optional[RecordFile] @@ -76,7 +76,7 @@ class RecordingNotFoundError(Exception): class RecordHistory: - history: dict[int, RecordHistoryItem] + history: Dict[int, RecordHistoryItem] def __init__(self): self.history = {} diff --git a/src/home/media/record_client.py b/src/home/media/record_client.py index f264155..322495c 100644 --- a/src/home/media/record_client.py +++ b/src/home/media/record_client.py @@ -7,7 +7,7 @@ from tempfile import gettempdir from .record import RecordStatus from .node_client import SoundNodeClient, MediaNodeClient, CameraNodeClient from ..util import Addr -from typing import Optional, Callable +from typing import Optional, Callable, Dict class RecordClient: @@ -15,14 +15,14 @@ class RecordClient: interrupted: bool logger: logging.Logger - clients: dict[str, MediaNodeClient] - awaiting: dict[str, dict[int, Optional[dict]]] + clients: Dict[str, MediaNodeClient] + awaiting: Dict[str, Dict[int, Optional[dict]]] error_handler: Optional[Callable] finished_handler: Optional[Callable] download_on_finish: bool def __init__(self, - nodes: dict[str, Addr], + nodes: Dict[str, Addr], error_handler: Optional[Callable] = None, finished_handler: Optional[Callable] = None, download_on_finish=False): @@ -50,7 +50,7 @@ class RecordClient: self.stop() self.logger.exception(exc) - def make_clients(self, nodes: dict[str, Addr]): + def make_clients(self, nodes: Dict[str, Addr]): pass def stop(self): @@ -148,9 +148,9 @@ class RecordClient: class SoundRecordClient(RecordClient): DOWNLOAD_EXTENSION = 'mp3' - # clients: dict[str, SoundNodeClient] + # clients: Dict[str, SoundNodeClient] - def make_clients(self, nodes: dict[str, Addr]): + def make_clients(self, nodes: Dict[str, Addr]): for node, addr in nodes.items(): self.clients[node] = SoundNodeClient(addr) self.awaiting[node] = {} @@ -158,9 +158,9 @@ class SoundRecordClient(RecordClient): class CameraRecordClient(RecordClient): DOWNLOAD_EXTENSION = 'mp4' - # clients: dict[str, CameraNodeClient] + # clients: Dict[str, CameraNodeClient] - def make_clients(self, nodes: dict[str, Addr]): + def make_clients(self, nodes: Dict[str, Addr]): for node, addr in nodes.items(): self.clients[node] = CameraNodeClient(addr) self.awaiting[node] = {}
\ No newline at end of file diff --git a/src/home/media/storage.py b/src/home/media/storage.py index 08ba06a..dd74ff8 100644 --- a/src/home/media/storage.py +++ b/src/home/media/storage.py @@ -3,7 +3,7 @@ import re import shutil import logging -from typing import Optional, Union +from typing import Optional, Union, List from datetime import datetime from ..util import strgen @@ -149,7 +149,7 @@ class RecordStorage: self.root = root - def getfiles(self, as_objects=False) -> Union[list[str], list[RecordFile]]: + def getfiles(self, as_objects=False) -> Union[List[str], List[RecordFile]]: files = [] for name in os.listdir(self.root): path = os.path.join(self.root, name) diff --git a/src/home/telegram/telegram.py b/src/home/telegram/telegram.py index 9c7ea73..2f94f93 100644 --- a/src/home/telegram/telegram.py +++ b/src/home/telegram/telegram.py @@ -1,6 +1,7 @@ import requests import logging +from typing import Tuple from ..config import config @@ -29,7 +30,7 @@ def send_photo(filename: str): def _send_telegram_data(text: str, parse_mode: str = None, - disable_web_page_preview: bool = False) -> tuple[dict, str]: + disable_web_page_preview: bool = False) -> Tuple[dict, str]: data = { 'chat_id': config['telegram']['chat_id'], 'text': text diff --git a/src/home/util.py b/src/home/util.py index 9dd84f6..5050ebb 100644 --- a/src/home/util.py +++ b/src/home/util.py @@ -9,7 +9,7 @@ import random from enum import Enum from datetime import datetime -from typing import Tuple, Optional +from typing import Tuple, Optional, List Addr = Tuple[str, int] # network address type (host, port) @@ -96,7 +96,7 @@ def send_datagram(message: str, addr: Addr) -> None: sock.sendto(message.encode(), addr) -def format_tb(exc) -> Optional[list[str]]: +def format_tb(exc) -> Optional[List[str]]: tb = traceback.format_tb(exc.__traceback__) if not tb: return None @@ -120,7 +120,7 @@ class ChildProcessInfo: self.cmd = cmd -def find_child_processes(ppid: int) -> list[ChildProcessInfo]: +def find_child_processes(ppid: int) -> List[ChildProcessInfo]: p = subprocess.run(['pgrep', '-P', str(ppid), '--list-full'], capture_output=True) if p.returncode != 0: raise OSError(f'pgrep returned {p.returncode}') diff --git a/src/ipcam_server.py b/src/ipcam_server.py index 9e72e68..47e7156 100755 --- a/src/ipcam_server.py +++ b/src/ipcam_server.py @@ -14,7 +14,7 @@ from home.database.sqlite import SQLiteBase from home.camera import util as camutil from enum import Enum -from typing import Optional, Union +from typing import Optional, Union, List from datetime import datetime, timedelta @@ -273,7 +273,7 @@ def get_motion_path(cam: int) -> str: def get_recordings_files(cam: int, - time_filter_type: Optional[TimeFilterType] = None) -> list[dict]: + time_filter_type: Optional[TimeFilterType] = None) -> List[dict]: from_time = 0 to_time = int(time.time()) @@ -305,7 +305,7 @@ def get_recordings_files(cam: int, async def process_fragments(camera: int, filename: str, - fragments: list[tuple[int, int]]) -> None: + fragments: List[Tuple[int, int]]) -> None: time = filename_to_datetime(filename) rec_dir = get_recordings_path(camera) @@ -338,7 +338,7 @@ async def process_fragments(camera: int, async def motion_notify_tg(camera: int, filename: str, - fragments: list[tuple[int, int]]): + fragments: List[Tuple[int, int]]): dt_file = filename_to_datetime(filename) fmt = '%H:%M:%S' diff --git a/src/openwrt_logger.py b/src/openwrt_logger.py index 4d3b310..05fedfe 100755 --- a/src/openwrt_logger.py +++ b/src/openwrt_logger.py @@ -5,6 +5,7 @@ from datetime import datetime from home.config import config from home.database import SimpleState from home.api import WebAPIClient +from typing import Tuple log_file = '/var/log/openwrt.log' @@ -24,7 +25,7 @@ $UDPServerRun 514 """ -def parse_line(line: str) -> tuple[int, str]: +def parse_line(line: str) -> Tuple[int, str]: space_pos = line.index(' ') date = line[:space_pos] @@ -58,7 +59,7 @@ if __name__ == '__main__': state['seek'] = f.tell() state['size'] = fsize - lines: list[tuple[int, str]] = [] + lines: List[Tuple[int, str]] = [] if content != '': for line in content.strip().split('\n'): diff --git a/src/polaris/__init__.py b/src/polaris/__init__.py index aa077ce..f1a7c1d 100644 --- a/src/polaris/__init__.py +++ b/src/polaris/__init__.py @@ -1,4 +1,12 @@ -# SPDX-License-Identifier: BSD-3-Clause +# Polaris PWK 1725CGLD "smart" kettle python library +# -------------------------------------------------- +# Copyright (C) Evgeny Zinoviev, 2022 +# License: BSD-3c -from .kettle import Kettle -from .protocol import Message, FrameType, PowerType
\ No newline at end of file +from .kettle import Kettle, DeviceListener +from .protocol import ( + PowerType, + IncomingMessageListener, + ConnectionStatusListener, + ConnectionStatus +)
\ No newline at end of file diff --git a/src/polaris/kettle.py b/src/polaris/kettle.py index 37f6813..1b995fd 100644 --- a/src/polaris/kettle.py +++ b/src/polaris/kettle.py @@ -1,97 +1,238 @@ -# SPDX-License-Identifier: BSD-3-Clause +# Polaris PWK 1725CGLD smart kettle python library +# ------------------------------------------------ +# Copyright (C) Evgeny Zinoviev, 2022 +# License: BSD-3c from __future__ import annotations +import threading import logging import zeroconf -import cryptography.hazmat.primitives._serialization -from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey -from cryptography.hazmat.primitives import hashes - -from functools import partial -from abc import ABC -from ipaddress import ip_address -from typing import Optional +from abc import abstractmethod +from ipaddress import ip_address, IPv4Address, IPv6Address +from typing import Optional, List, Union from .protocol import ( - Connection, + UDPConnection, ModeMessage, - HandshakeMessage, TargetTemperatureMessage, - Message, - PowerType + PowerType, + ConnectionStatus, + ConnectionStatusListener, + WrappedMessage ) -_logger = logging.getLogger(__name__) - -# Polaris PWK 1725CGLD IoT kettle -class Kettle(zeroconf.ServiceListener, ABC): - macaddr: str - device_token: str - sb: Optional[zeroconf.ServiceBrowser] - found_device: Optional[zeroconf.ServiceInfo] - conn: Optional[Connection] +class DeviceDiscover(threading.Thread, zeroconf.ServiceListener): + si: Optional[zeroconf.ServiceInfo] + _mac: str + _sb: Optional[zeroconf.ServiceBrowser] + _zc: Optional[zeroconf.Zeroconf] + _listeners: List[DeviceListener] + _valid_addresses: List[Union[IPv4Address, IPv6Address]] + _only_ipv4: bool - def __init__(self, mac: str, device_token: str): + def __init__(self, mac: str, + listener: Optional[DeviceListener] = None, + only_ipv4=True): super().__init__() - self.zeroconf = zeroconf.Zeroconf() - self.sb = None - self.macaddr = mac - self.device_token = device_token - self.found_device = None - self.conn = None + self.si = None + self._mac = mac + self._zc = None + self._sb = None + self._only_ipv4 = only_ipv4 + self._valid_addresses = [] + self._listeners = [] + if isinstance(listener, DeviceListener): + self._listeners.append(listener) + self._logger = logging.getLogger(f'{__name__}.{self.__class__.__name__}') + + def add_listener(self, listener: DeviceListener): + if listener not in self._listeners: + self._listeners.append(listener) + else: + self._logger.warning(f'add_listener: listener {listener} already in the listeners list') + + def set_info(self, info: zeroconf.ServiceInfo): + valid_addresses = self._get_valid_addresses(info) + if not valid_addresses: + raise ValueError('no valid addresses') + self._valid_addresses = valid_addresses + self.si = info + for f in self._listeners: + try: + f.device_updated() + except Exception as exc: + self._logger.error(f'set_info: error while calling device_updated on {f}') + self._logger.exception(exc) - def find(self) -> zeroconf.ServiceInfo: - self.sb = zeroconf.ServiceBrowser(self.zeroconf, "_syncleo._udp.local.", self) - self.sb.join() + def add_service(self, zc: zeroconf.Zeroconf, type_: str, name: str) -> None: + self._add_update_service('add_service', zc, type_, name) - return self.found_device + def update_service(self, zc: zeroconf.Zeroconf, type_: str, name: str) -> None: + self._add_update_service('update_service', zc, type_, name) - # zeroconf.ServiceListener implementation - def add_service(self, - zc: zeroconf.Zeroconf, - type_: str, - name: str) -> None: - if name.startswith(f'{self.macaddr}.'): - info = zc.get_service_info(type_, name) + def _add_update_service(self, method: str, zc: zeroconf.Zeroconf, type_: str, name: str) -> None: + info = zc.get_service_info(type_, name) + if name.startswith(f'{self._mac}.'): + self._logger.info(f'{method}: type={type_} name={name}') + try: + self.set_info(info) + except ValueError as exc: + self._logger.error(f'{method}: rejected: {str(exc)}') + else: + self._logger.debug(f'{method}: mac not matched: {info}') + + def remove_service(self, zc: zeroconf.Zeroconf, type_: str, name: str) -> None: + if name.startswith(f'{self._mac}.'): + self._logger.info(f'remove_service: type={type_} name={name}') + # TODO what to do here?! + + def run(self): + self._logger.info('starting zeroconf service browser') + ip_version = zeroconf.IPVersion.V4Only if self._only_ipv4 else zeroconf.IPVersion.All + self._zc = zeroconf.Zeroconf(ip_version=ip_version) + self._sb = zeroconf.ServiceBrowser(self._zc, "_syncleo._udp.local.", self) + self._sb.join() + + def stop(self): + if self._sb: try: - self.sb.cancel() + self._sb.cancel() except RuntimeError: pass - self.zeroconf.close() - self.found_device = info + self._sb = None + self._zc.close() + self._zc = None + + def _get_valid_addresses(self, si: zeroconf.ServiceInfo) -> List[Union[IPv4Address, IPv6Address]]: + valid = [] + for addr in map(ip_address, si.addresses): + if self._only_ipv4 and not isinstance(addr, IPv4Address): + continue + if isinstance(addr, IPv4Address) and str(addr).startswith('169.254.'): + continue + valid.append(addr) + return valid - assert self.device_curve == 29, f'curve type {self.device_curve} is not implemented' - - def start_server(self, callback: callable): - addresses = list(map(ip_address, self.found_device.addresses)) - self.conn = Connection(addr=addresses[0], - port=int(self.found_device.port), - device_pubkey=self.device_pubkey, - device_token=bytes.fromhex(self.device_token)) + @property + def pubkey(self) -> bytes: + return bytes.fromhex(self.si.properties[b'public'].decode()) - # shake the kettle's hand - self._pass_message(HandshakeMessage(), callback) - self.conn.start() + @property + def curve(self) -> int: + return int(self.si.properties[b'curve'].decode()) - def stop_server(self): - self.conn.interrupted = True + @property + def addr(self) -> Union[IPv4Address, IPv6Address]: + return self._valid_addresses[0] @property - def device_pubkey(self) -> bytes: - return bytes.fromhex(self.found_device.properties[b'public'].decode()) + def port(self) -> int: + return int(self.si.port) @property - def device_curve(self) -> int: - return int(self.found_device.properties[b'curve'].decode()) + def protocol(self) -> int: + return int(self.si.properties[b'protocol'].decode()) + + +class DeviceListener: + @abstractmethod + def device_updated(self): + pass + + +class Kettle(DeviceListener, ConnectionStatusListener): + mac: str + device: Optional[DeviceDiscover] + device_token: str + conn: Optional[UDPConnection] + conn_status: Optional[ConnectionStatus] + _logger: logging.Logger + _find_evt: threading.Event + + def __init__(self, mac: str, device_token: str): + super().__init__() + self.mac = mac + self.device = None + self.device_token = device_token + self.conn = None + self.conn_status = None + self._find_evt = threading.Event() + self._logger = logging.getLogger(f'{__name__}.{self.__class__.__name__}') + + def device_updated(self): + self._find_evt.set() + self._logger.info(f'device updated, service info: {self.device.si}') + + def connection_status_updated(self, status: ConnectionStatus): + self.conn_status = status + + def discover(self, wait=True, timeout=None, listener=None) -> Optional[zeroconf.ServiceInfo]: + do_start = False + if not self.device: + self.device = DeviceDiscover(self.mac, listener=self, only_ipv4=True) + do_start = True + self._logger.debug('discover: started device discovery') + else: + self._logger.warning('discover: already started') + + if listener is not None: + self.device.add_listener(listener) + + if do_start: + self.device.start() + + if wait: + self._find_evt.clear() + try: + self._find_evt.wait(timeout=timeout) + except KeyboardInterrupt: + self.device.stop() + return None + return self.device.si + + def start_server_if_needed(self, + incoming_message_listener=None, + connection_status_listener=None): + if self.conn: + self._logger.warning('start_server_if_needed: server is already started!') + self.conn.set_address(self.device.addr, self.device.port) + self.conn.set_device_pubkey(self.device.pubkey) + return + + assert self.device.curve == 29, f'curve type {self.device.curve} is not implemented' + assert self.device.protocol == 2, f'protocol {self.device.protocol} is not supported' + + self.conn = UDPConnection(addr=self.device.addr, + port=self.device.port, + device_pubkey=self.device.pubkey, + device_token=bytes.fromhex(self.device_token)) + if incoming_message_listener: + self.conn.add_incoming_message_listener(incoming_message_listener) + + self.conn.add_connection_status_listener(self) + if connection_status_listener: + self.conn.add_connection_status_listener(connection_status_listener) + + self.conn.start() + + def stop_all(self): + # when we stop server, we should also stop device discovering service + if self.conn: + self.conn.interrupted = True + self.conn = None + self.device.stop() + self.device = None + + def is_connected(self) -> bool: + return self.conn is not None and self.conn_status == ConnectionStatus.CONNECTED def set_power(self, power_type: PowerType, callback: callable): - self._pass_message(ModeMessage(power_type), callback) + message = ModeMessage(power_type) + self.conn.enqueue_message(WrappedMessage(message, handler=callback, ack=True)) def set_target_temperature(self, temp: int, callback: callable): - self._pass_message(TargetTemperatureMessage(temp), callback) - - def _pass_message(self, message: Message, callback: callable): - self.conn.send_message(message, partial(callback, self)) + message = TargetTemperatureMessage(temp) + self.conn.enqueue_message(WrappedMessage(message, handler=callback, ack=True)) diff --git a/src/polaris/protocol.py b/src/polaris/protocol.py index cc4e36a..5d7390f 100644 --- a/src/polaris/protocol.py +++ b/src/polaris/protocol.py @@ -1,39 +1,61 @@ -# SPDX-License-Identifier: BSD-3-Clause +# Polaris PWK 1725CGLD "smart" kettle python library +# -------------------------------------------------- +# Copyright (C) Evgeny Zinoviev, 2022 +# License: BSD-3c from __future__ import annotations + import logging import socket import random import struct import threading -import queue +import time -from enum import Enum -from typing import Union, Optional, Any +from abc import abstractmethod, ABC +from enum import Enum, auto +from typing import Union, Optional, Dict, Tuple, List from ipaddress import IPv4Address, IPv6Address -import cryptography.hazmat.primitives._serialization +import cryptography.hazmat.primitives._serialization as srlz from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey from cryptography.hazmat.primitives import ciphers, padding, hashes from cryptography.hazmat.primitives.ciphers import algorithms, modes - +ReprDict = Dict[str, Union[str, int, float, bool]] _logger = logging.getLogger(__name__) +PING_FREQUENCY = 3 +RESEND_ATTEMPTS = 5 +READ_TIMEOUT = 1 +ERROR_TIMEOUT = 15 +MESSAGE_QUEUE_REMOVE_DELAY = 13 # after what time to delete (and pass False to handlers, if needed) messages with phase=DONE from queue +DISCONNECT_TIMEOUT = 15 + -# drop-in replacement for Java API +def safe_callback_call(f: callable, + *args, + logger: logging.Logger = None, + error_message: str = None): + try: + return f(*args) + except Exception as exc: + logger.error(f'{error_message}, see exception below:') + logger.exception(exc) + return None + + +# drop-in replacement for java.lang.System.arraycopy # TODO: rewrite def arraycopy(src, src_pos, dest, dest_pos, length): for i in range(length): dest[i + dest_pos] = src[i + src_pos] -class FrameType(Enum): - ACK = 0 - CMD = 1 - AUX = 2 - NAK = 3 +# "convert" unsigned byte to signed +def u8_to_s8(b: int) -> int: + return struct.unpack('b', bytes([b]))[0] class PowerType(Enum): @@ -44,23 +66,37 @@ class PowerType(Enum): # update: if I set it to '2', it just resets to '0' +# low-level protocol structures +# ----------------------------- + +class FrameType(Enum): + ACK = 0 + CMD = 1 + AUX = 2 + NAK = 3 + + class FrameHead: - seq: int # u8 + seq: Optional[int] # u8 type: FrameType # u8 - length: int # u16 + length: int # u16. This is the length of FrameItem's payload @staticmethod def from_bytes(buf: bytes) -> FrameHead: seq, ft, length = struct.unpack('<BBH', buf) return FrameHead(seq, FrameType(ft), length) - def __init__(self, seq: int, frame_type: FrameType, length: Optional[int] = None): + def __init__(self, + seq: Optional[int], + frame_type: FrameType, + length: Optional[int] = None): self.seq = seq self.type = frame_type self.length = length or 0 def pack(self) -> bytes: assert self.length != 0, "FrameHead.length has not been set" + assert self.seq is not None, "FrameHead.seq has not been set" return struct.pack('<BBH', self.seq, self.type.value, self.length) @@ -84,30 +120,47 @@ class FrameItem: return bytes(ba) +# high-level wrappers around FrameItem +# ------------------------------------ + +class MessagePhase(Enum): + WAITING = 0 + SENT = 1 + DONE = 2 + + class Message: frame: Optional[FrameItem] + id: int + + _global_id = 0 def __init__(self): self.frame = None + # global internal message id, only useful for debugging purposes + self.id = self.next_id() + + def __repr__(self): + return f'<{self.__class__.__name__} id={self.id} seq={self.frame.head.seq}>' + @staticmethod - def from_encrypted(buf: bytes, - inkey: bytes, - outkey: bytes) -> Message: - # _logger.debug('[from_encrypted] buf='+buf.hex()) - # print(f'buf len={len(buf)}') + def next_id(): + _id = Message._global_id + Message._global_id += 1 + return _id + + @staticmethod + def from_encrypted(buf: bytes, inkey: bytes, outkey: bytes) -> Message: + _logger.debug(f'Message:from_encrypted: buf={buf.hex()}') + assert len(buf) >= 4, 'invalid size' head = FrameHead.from_bytes(buf[:4]) assert len(buf) == head.length + 4, f'invalid buf size ({len(buf)} != {head.length})' - payload = buf[4:] - - # byte b = paramFrameHead.seq; b = head.seq - # TODO check if protocol is 2, otherwise raise an exception - j = b & 0xF k = b >> 4 & 0xF @@ -123,15 +176,9 @@ class Message: decryptor = cipher.decryptor() decrypted_data = decryptor.update(payload) + decryptor.finalize() - # print(f'head.length={head.length} len(decr)={len(decrypted_data)}') - # if len(decrypted_data) > head.length: unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder() decrypted_data = unpadder.update(decrypted_data) - # try: decrypted_data += unpadder.finalize() - # except ValueError as exc: - # _logger.exception(exc) - # pass assert len(decrypted_data) != 0, 'decrypted data is null' assert head.seq == decrypted_data[0], f'decrypted seq mismatch {head.seq} != {decrypted_data[0]}' @@ -145,29 +192,54 @@ class Message: return NakMessage(head.seq) elif head.type == FrameType.AUX: + # TODO implement AUX raise NotImplementedError('FrameType AUX is not yet implemented') elif head.type == FrameType.CMD: - cmd = decrypted_data[0] + type = decrypted_data[1] data = decrypted_data[2:] - return CmdMessage(head.seq, cmd, data) + + cl = UnknownMessage + + subclasses = [cl for cl in CmdIncomingMessage.__subclasses__() if cl is not SimpleBooleanMessage] + subclasses.extend(SimpleBooleanMessage.__subclasses__()) + + for _cl in subclasses: + # `UnknownMessage` is a special class that holds a packed command that we don't recognize. + # It will be used anyway if we don't find a match, so skip it here + if _cl == UnknownMessage: + continue + + if _cl.TYPE == type: + cl = _cl + break + + m = cl.from_packed_data(data, seq=head.seq) + if isinstance(m, UnknownMessage): + m.set_type(type) + return m else: raise NotImplementedError(f'Unexpected frame type: {head.type}') - @property - def data(self) -> bytes: + def pack_data(self) -> bytes: return b'' - def encrypt(self, - outkey: bytes, - inkey: bytes, - token: bytes, - pubkey: bytes): + @property + def seq(self) -> Union[int, None]: + try: + return self.frame.head.seq + except: + return None + + @seq.setter + def seq(self, seq: int): + self.frame.head.seq = seq + def encrypt(self, outkey: bytes, inkey: bytes, token: bytes, pubkey: bytes): assert self.frame is not None - data = self.data + data = self._get_data_to_encrypt() assert data is not None b = self.frame.head.seq @@ -199,7 +271,7 @@ class Message: arraycopy(data, 0, newdata, 1, len(data)) newdata = bytes(newdata) - _logger.debug('payload to be sent: ' + newdata.hex()) + _logger.debug('frame payload to be encrypted: ' + newdata.hex()) padder = padding.PKCS7(algorithms.AES.block_size).padder() ciphertext = bytearray() @@ -208,74 +280,143 @@ class Message: self.frame.setpayload(ciphertext) - def set_seq(self, seq: int): - self.frame.head.seq = seq - - def __repr__(self): - return f'<{self.__class__.__name__} seq={self.frame.head.seq}>' + def _get_data_to_encrypt(self) -> bytes: + return self.pack_data() -class AckMessage(Message): - def __init__(self, seq: int = 0): +class AckMessage(Message, ABC): + def __init__(self, seq: Optional[int] = None): super().__init__() - self.frame = FrameItem(FrameHead(seq, FrameType.ACK, 0)) + self.frame = FrameItem(FrameHead(seq, FrameType.ACK, None)) -class NakMessage(Message): - def __init__(self, seq: int = 0): +class NakMessage(Message, ABC): + def __init__(self, seq: Optional[int] = None): super().__init__() - self.frame = FrameItem(FrameHead(seq, FrameType.NAK, 0)) + self.frame = FrameItem(FrameHead(seq, FrameType.NAK, None)) class CmdMessage(Message): - _type: Optional[int] - _data: bytes + type: Optional[int] + data: bytes - def __init__(self, seq=0, - type: Optional[int] = None, - data: bytes = b''): - super().__init__() - self._data = data - if type is not None: - self.frame = FrameItem(FrameHead(seq, FrameType.CMD)) - self._type = type - else: - self._type = None + TYPE = None - @property - def data(self) -> bytes: + def _get_data_to_encrypt(self) -> bytes: buf = bytearray() - buf.append(self._type) - buf.extend(self._data) + buf.append(self.get_type()) + buf.extend(self.pack_data()) return bytes(buf) + def __init__(self, seq: Optional[int] = None): + super().__init__() + self.frame = FrameItem(FrameHead(seq, FrameType.CMD)) + self.data = b'' + + def _repr_fields(self) -> ReprDict: + return { + 'cmd': self.get_type() + } + def __repr__(self): params = [ __name__+'.'+self.__class__.__name__, - f'seq={self.frame.head.seq}', - # f'type={self.frame.head.type}', - f'cmd={self._type}' + f'id={self.id}', + f'seq={self.seq}' ] - if self._data: - params.append(f'data={self._data.hex()}') + fields = self._repr_fields() + if fields: + for k, v in fields.items(): + params.append(f'{k}={v}') + elif self.data: + params.append(f'data={self.data.hex()}') return '<'+' '.join(params)+'>' + def get_type(self) -> int: + return self.__class__.TYPE + + +class CmdIncomingMessage(CmdMessage): + @staticmethod + @abstractmethod + def from_packed_data(cls, data: bytes, seq: Optional[int] = None): + pass + + @abstractmethod + def _repr_fields(self) -> ReprDict: + pass + + +class CmdOutgoingMessage(CmdMessage): + @abstractmethod + def pack_data(self) -> bytes: + return b'' + + +class ModeMessage(CmdOutgoingMessage, CmdIncomingMessage): + TYPE = 1 + + pt: PowerType + + def __init__(self, power_type: PowerType, seq: Optional[int] = None): + super().__init__(seq) + self.pt = power_type + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> ModeMessage: + assert len(data) == 1, 'data size expected to be 1' + mode, = struct.unpack('B', data) + return ModeMessage(PowerType(mode), seq=seq) + + def pack_data(self) -> bytes: + return self.pt.value.to_bytes(1, byteorder='little') -class ModeMessage(CmdMessage): - def __init__(self, power_type: PowerType): - super().__init__(type=1, - data=(power_type.value).to_bytes(1, byteorder='little')) + def _repr_fields(self) -> ReprDict: + return {'mode': self.pt.name} -class TargetTemperatureMessage(CmdMessage): - def __init__(self, temp: int): - super().__init__(type=2, - data=bytes(bytearray([temp, 0]))) +class TargetTemperatureMessage(CmdOutgoingMessage, CmdIncomingMessage): + temperature: int + TYPE = 2 + def __init__(self, temp: int, seq: Optional[int] = None): + super().__init__(seq) + self.temperature = temp + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> TargetTemperatureMessage: + assert len(data) == 2, 'data size expected to be 2' + nat, frac = struct.unpack('BB', data) + temp = int(nat + (frac / 100)) + return TargetTemperatureMessage(temp, seq=seq) + + def pack_data(self) -> bytes: + return bytes([self.temperature, 0]) + + def _repr_fields(self) -> ReprDict: + return {'temperature': self.temperature} + + +class PingMessage(CmdIncomingMessage, CmdOutgoingMessage): + TYPE = 255 + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> PingMessage: + assert len(data) == 0, 'no data expected' + return PingMessage(seq=seq) + + def pack_data(self) -> bytes: + return b'' + + def _repr_fields(self) -> ReprDict: + return {} + + +# This is the first protocol message. Sent by a client. +# Kettle usually ACKs this, but sometimes i don't get any ACK and the very next message is HandshakeResponseMessage. class HandshakeMessage(CmdMessage): - def __init__(self): - super().__init__(type=0) + TYPE = 0 def encrypt(self, outkey: bytes, @@ -297,20 +438,312 @@ class HandshakeMessage(CmdMessage): self.frame.setpayload(pld) -# TODO -# implement resending UDP messages if no answer has been received in a second -# try at least 5 times, then give up -class Connection(threading.Thread): - seq_no: int +# Kettle either sends this right after the handshake, of first it ACKs the handshake then sends this. +class HandshakeResponseMessage(CmdIncomingMessage): + TYPE = 0 + + protocol: int + fw_major: int + fw_minor: int + mode: int + token: bytes + + def __init__(self, + protocol: int, + fw_major: int, + fw_minor: int, + mode: int, + token: bytes, + seq: Optional[int] = None): + super().__init__(seq) + self.protocol = protocol + self.fw_major = fw_major + self.fw_minor = fw_minor + self.mode = mode + self.token = token + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> HandshakeResponseMessage: + protocol, fw_major, fw_minor, mode = struct.unpack('<HBBB', data[:5]) + return HandshakeResponseMessage(protocol, fw_major, fw_minor, mode, token=data[5:], seq=seq) + + def _repr_fields(self) -> ReprDict: + return { + 'protocol': self.protocol, + 'fw': f'{self.fw_major}.{self.fw_minor}', + 'mode': self.mode, + 'token': self.token.hex() + } + + +# Apparently, some hardware info. +# On the other hand, if you look at com.syncleiot.iottransport.commands.CmdHardware, its mqtt topic is "mcu_firmware". +# My device returns 1.1.1. The thing uses on ESP8266 MCU under the hood (or, more precisely, under a piece of cheap +# plastic), so maybe 1.1.1 is the MCU fw revision. +class DeviceHardwareMessage(CmdIncomingMessage): + TYPE = 143 # -113 + + hw: List[int] + + def __init__(self, hw: List[int], seq: Optional[int] = None): + super().__init__(seq) + self.hw = hw + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> DeviceHardwareMessage: + assert len(data) == 3, 'invalid data size, expected 3' + hw = list(struct.unpack('<BBB', data)) + return DeviceHardwareMessage(hw, seq=seq) + + def _repr_fields(self) -> ReprDict: + return {'device_hardware': '.'.join(map(str, self.hw))} + + +# This message is sent by kettle right after the HandshakeMessageResponse. +# The diagnostic data is supposed to be sent to vendor, which we, obviously, not going to do. +# So just ACK and skip it. +class DeviceDiagnosticMessage(CmdIncomingMessage): + TYPE = 145 # -111 + + diag_data: bytes + + def __init__(self, diag_data: bytes, seq: Optional[int] = None): + super().__init__(seq) + self.diag_data = diag_data + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> DeviceDiagnosticMessage: + return DeviceDiagnosticMessage(diag_data=data, seq=seq) + + def _repr_fields(self) -> ReprDict: + return {'diag_data': self.diag_data.hex()} + + +class SimpleBooleanMessage(ABC, CmdIncomingMessage): + value: bool + + def __init__(self, value: bool, seq: Optional[int] = None): + super().__init__(seq) + self.value = value + + @classmethod + def from_packed_data(cls, data: bytes, seq: Optional[int] = None): + assert len(data) == 1, 'invalid data size, expected 1' + enabled, = struct.unpack('<B', data) + return cls(value=enabled == 1, seq=seq) + + @abstractmethod + def _repr_fields(self) -> ReprDict: + pass + + +class AccessControlMessage(SimpleBooleanMessage): + TYPE = 133 # -123 + + def _repr_fields(self) -> ReprDict: + return {'acl_enabled': self.value} + + +class ErrorMessage(SimpleBooleanMessage): + TYPE = 7 + + def _repr_fields(self) -> ReprDict: + return {'error': self.value} + + +class ChildLockMessage(SimpleBooleanMessage): + TYPE = 30 + + def _repr_fields(self) -> ReprDict: + return {'child_lock': self.value} + + +class VolumeMessage(SimpleBooleanMessage): + TYPE = 9 + + def _repr_fields(self) -> ReprDict: + return {'volume': self.value} + + +class BacklightMessage(SimpleBooleanMessage): + TYPE = 28 + + def _repr_fields(self) -> ReprDict: + return {'backlight': self.value} + + +class CurrentTemperatureMessage(CmdIncomingMessage): + TYPE = 20 + + current_temperature: int + + def __init__(self, temp: int, seq: Optional[int] = None): + super().__init__(seq) + self.current_temperature = temp + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> CurrentTemperatureMessage: + assert len(data) == 2, 'data size expected to be 2' + nat, frac = struct.unpack('BB', data) + temp = int(nat + (frac / 100)) + return CurrentTemperatureMessage(temp, seq=seq) + + def pack_data(self) -> bytes: + return bytes([self.current_temperature, 0]) + + def _repr_fields(self) -> ReprDict: + return {'current_temperature': self.current_temperature} + + +class UnknownMessage(CmdIncomingMessage): + type: Optional[int] + data: bytes + + def __init__(self, data: bytes, **kwargs): + super().__init__(**kwargs) + self.type = None + self.data = data + + @classmethod + def from_packed_data(cls, data: bytes, seq=0) -> UnknownMessage: + return UnknownMessage(data, seq=seq) + + def set_type(self, type: int): + self.type = type + + def get_type(self) -> int: + return self.type + + def _repr_fields(self) -> ReprDict: + return { + 'type': self.type, + 'data': self.data.hex() + } + + +class WrappedMessage: + _message: Message + _handler: Optional[callable] + _validator: Optional[callable] + _logger: Optional[logging.Logger] + _phase: MessagePhase + _phase_update_time: float + + def __init__(self, + message: Message, + handler: Optional[callable] = None, + validator: Optional[callable] = None, + ack=False): + self._message = message + self._handler = handler + self._validator = validator + self._logger = None + self._phase = MessagePhase.WAITING + self._phase_update_time = 0 + if not validator and ack: + self._validator = lambda m: isinstance(m, AckMessage) + + def setlogger(self, logger: logging.Logger): + self._logger = logger + + def validate(self, message: Message): + if not self._validator: + return True + return self._validator(message) + + def call(self, *args, error_message: str = None) -> None: + if not self._handler: + return + try: + self._handler(*args) + except Exception as exc: + logger = self._logger or logging.getLogger(self.__class__.__name__) + logger.error(f'{error_message}, see exception below:') + logger.exception(exc) + + @property + def phase(self) -> MessagePhase: + return self._phase + + @phase.setter + def phase(self, phase: MessagePhase): + self._phase = phase + self._phase_update_time = 0 if phase == MessagePhase.WAITING else time.time() + + @property + def phase_update_time(self) -> float: + return self._phase_update_time + + @property + def message(self) -> Message: + return self._message + + @property + def id(self) -> int: + return self._message.id + + @property + def seq(self) -> int: + return self._message.seq + + @seq.setter + def seq(self, seq: int): + self._message.seq = seq + + def __repr__(self): + return f'<{__name__}.{self.__class__.__name__} message={self._message.__repr__()}>' + + +# Connection stuff +# Well, strictly speaking, as it's UDP, there's no connection, but who cares. +# --------------------------------------------------------------------------- + +class IncomingMessageListener: + @abstractmethod + def incoming_message(self, message: Message) -> Optional[Message]: + pass + + +class ConnectionStatus(Enum): + NOT_CONNECTED = auto() + CONNECTING = auto() + CONNECTED = auto() + RECONNECTING = auto() + DISCONNECTED = auto() + + +class ConnectionStatusListener: + @abstractmethod + def connection_status_updated(self, status: ConnectionStatus): + pass + + +class UDPConnection(threading.Thread, ConnectionStatusListener): + inseq: int + outseq: int source_port: int device_addr: str device_port: int device_token: bytes + device_pubkey: bytes interrupted: bool - waiting_for_response: dict[int, callable] + response_handlers: Dict[int, WrappedMessage] + outgoing_queue: List[WrappedMessage] pubkey: Optional[bytes] encinkey: Optional[bytes] encoutkey: Optional[bytes] + inc_listeners: List[IncomingMessageListener] + conn_listeners: List[ConnectionStatusListener] + outgoing_time: float + outgoing_time_1st: float + incoming_time: float + status: ConnectionStatus + reconnect_tries: int + + _addr_lock: threading.Lock + _iml_lock: threading.Lock + _csl_lock: threading.Lock + _st_lock: threading.Lock def __init__(self, addr: Union[IPv4Address, IPv6Address], @@ -318,36 +751,105 @@ class Connection(threading.Thread): device_pubkey: bytes, device_token: bytes): super().__init__() - self.logger = logging.getLogger(__name__+'.'+self.__class__.__name__) + self._logger = logging.getLogger(f'{__name__}.{self.__class__.__name__} <{hex(id(self))}>') self.setName(self.__class__.__name__) - # self.daemon = True - self.seq_no = -1 + self.inseq = 0 + self.outseq = 0 self.source_port = random.randint(1024, 65535) self.device_addr = str(addr) self.device_port = port self.device_token = device_token - self.lock = threading.Lock() - self.outgoing_queue = queue.SimpleQueue() - self.waiting_for_response = {} + self.device_pubkey = device_pubkey + self.outgoing_queue = [] + self.response_handlers = {} self.interrupted = False + self.outgoing_time = 0 + self.outgoing_time_1st = 0 + self.incoming_time = 0 + self.inc_listeners = [] + self.conn_listeners = [self] + self.status = ConnectionStatus.NOT_CONNECTED + self.reconnect_tries = 0 + + self._iml_lock = threading.Lock() + self._csl_lock = threading.Lock() + self._addr_lock = threading.Lock() + self._st_lock = threading.Lock() self.pubkey = None self.encinkey = None self.encoutkey = None - self.prepare_keys(device_pubkey) - - def prepare_keys(self, device_pubkey: bytes): + def connection_status_updated(self, status: ConnectionStatus): + # self._logger.info(f'connection_status_updated: status = {status}') + with self._st_lock: + # self._logger.debug(f'connection_status_updated: lock acquired') + self.status = status + if status == ConnectionStatus.RECONNECTING: + self.reconnect_tries += 1 + if status in (ConnectionStatus.CONNECTED, ConnectionStatus.NOT_CONNECTED, ConnectionStatus.DISCONNECTED): + self.reconnect_tries = 0 + + def _cleanup(self): + # erase outgoing queue + for wm in self.outgoing_queue: + wm.call(False, + error_message=f'_cleanup: exception while calling cb(False) on message {wm.message}') + self.outgoing_queue = [] + self.response_handlers = {} + + # reset timestamps + self.incoming_time = 0 + self.outgoing_time = 0 + self.outgoing_time_1st = 0 + + self._logger.info('_cleanup: done') + + def set_address(self, addr: Union[IPv4Address, IPv6Address], port: int): + with self._addr_lock: + if self.device_addr != str(addr) or self.device_port != port: + self.device_addr = str(addr) + self.device_port = port + self._logger.info(f'updated device network address: {self.device_addr}:{self.device_port}') + + def set_device_pubkey(self, pubkey: bytes): + if self.device_pubkey.hex() != pubkey.hex(): + self._logger.info(f'device pubkey has changed (old={self.device_pubkey.hex()}, new={pubkey.hex()})') + self.device_pubkey = pubkey + self._notify_cs(ConnectionStatus.RECONNECTING) + + def get_address(self) -> Tuple[str, int]: + with self._addr_lock: + return self.device_addr, self.device_port + + def add_incoming_message_listener(self, listener: IncomingMessageListener): + with self._iml_lock: + if listener not in self.inc_listeners: + self.inc_listeners.append(listener) + + def add_connection_status_listener(self, listener: ConnectionStatusListener): + with self._csl_lock: + if listener not in self.conn_listeners: + self.conn_listeners.append(listener) + + def _notify_cs(self, status: ConnectionStatus): + # self._logger.debug(f'_notify_cs: status={status}') + with self._csl_lock: + for obj in self.conn_listeners: + # self._logger.debug(f'_notify_cs: notifying {obj}') + obj.connection_status_updated(status) + + def _prepare_keys(self): # generate key pair privkey = X25519PrivateKey.generate() - self.pubkey = bytes(reversed(privkey.public_key().public_bytes(encoding=cryptography.hazmat.primitives._serialization.Encoding.Raw, - format=cryptography.hazmat.primitives._serialization.PublicFormat.Raw))) + self.pubkey = bytes(reversed(privkey.public_key().public_bytes(encoding=srlz.Encoding.Raw, + format=srlz.PublicFormat.Raw))) # generate shared key device_pubkey = X25519PublicKey.from_public_bytes( - bytes(reversed(device_pubkey)) + bytes(reversed(self.device_pubkey)) ) shared_key = bytes(reversed( privkey.exchange(device_pubkey) @@ -362,51 +864,286 @@ class Connection(threading.Thread): self.encinkey = shared_sha256[:16] self.encoutkey = shared_sha256[16:] + self._logger.info('encryption keys have been created') + + def _handshake_callback(self, r: MessageResponse): + # if got error for our HandshakeMessage, reset everything and try again + if r is False: + # self._logger.debug('_handshake_callback: set status=RECONNETING') + self._notify_cs(ConnectionStatus.RECONNECTING) + else: + # self._logger.debug('_handshake_callback: set status=CONNECTED') + self._notify_cs(ConnectionStatus.CONNECTED) + def run(self): + self._logger.info('starting server loop') + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.bind(('0.0.0.0', self.source_port)) - sock.settimeout(1) + sock.settimeout(READ_TIMEOUT) while not self.interrupted: - if not self.outgoing_queue.empty(): - message = self.outgoing_queue.get() - message.encrypt(outkey=self.encoutkey, - inkey=self.encinkey, - token=self.device_token, - pubkey=self.pubkey) - buf = message.frame.pack() - self.logger.debug('send: '+buf.hex()) - self.logger.debug(f'sendto: {self.device_addr}:{self.device_port}') - sock.sendto(buf, (self.device_addr, self.device_port)) + with self._st_lock: + status = self.status + + if status in (ConnectionStatus.DISCONNECTED, ConnectionStatus.RECONNECTING): + self._cleanup() + if status == ConnectionStatus.DISCONNECTED: + break + + # no activity for some time means connection is broken + fail = False + fail_path = 0 + if self.incoming_time > 0 and time.time() - self.incoming_time >= DISCONNECT_TIMEOUT: + fail = True + fail_path = 1 + elif self.outgoing_time_1st > 0 and self.incoming_time == 0 and time.time() - self.outgoing_time_1st >= DISCONNECT_TIMEOUT: + fail = True + fail_path = 2 + + if fail: + self._logger.debug(f'run: setting status=RECONNECTING because of long inactivity, fail_path={fail_path}') + self._notify_cs(ConnectionStatus.RECONNECTING) + + # establishing a connection + if status in (ConnectionStatus.RECONNECTING, ConnectionStatus.NOT_CONNECTED): + if status == ConnectionStatus.RECONNECTING and self.reconnect_tries >= 3: + self._notify_cs(ConnectionStatus.DISCONNECTED) + continue + + self._reset_outseq() + self._prepare_keys() + + # shake the imaginary kettle's hand + wrapped = WrappedMessage(HandshakeMessage(), + handler=self._handshake_callback, + validator=lambda m: isinstance(m, (AckMessage, HandshakeResponseMessage))) + self.enqueue_message(wrapped, prepend=True) + self._notify_cs(ConnectionStatus.CONNECTING) + + # pick next (wrapped) message to send + wm = self._get_next_message() # wm means "wrapped message" + if wm: + if not isinstance(wm.message, (AckMessage, NakMessage)): + old_seq = wm.seq + wm.seq = self.outseq + self._set_response_handler(wm, old_seq=old_seq) + elif wm.seq is None: + # ack/nak is a response to some incoming message (and it must have the same seqno that incoming + # message had) + raise RuntimeError(f'run: seq must be set for {wm.__class__.__name__}') + + self._logger.debug(f'run: sending message: {wm.message}') + wm.message.encrypt(outkey=self.encoutkey, inkey=self.encinkey, + token=self.device_token, pubkey=self.pubkey) + buf = wm.message.frame.pack() + one_shot = isinstance(wm.message, (AckMessage, NakMessage, PingMessage)) + # self._logger.debug(f'run: raw data to be sent: {buf.hex()}') + + # sending the first time + if wm.phase == MessagePhase.WAITING: + sock.sendto(buf, self.get_address()) + # resending + elif wm.phase == MessagePhase.SENT: + left = RESEND_ATTEMPTS + while left > 0: + sock.sendto(buf, self.get_address()) + left -= 1 + if left > 0: + time.sleep(0.05) + + if one_shot or wm.phase == MessagePhase.SENT: + wm.phase = MessagePhase.DONE + else: + wm.phase = MessagePhase.SENT + + now = time.time() + self.outgoing_time = now + if not self.outgoing_time_1st: + self.outgoing_time_1st = now + + # receiving data try: data = sock.recv(4096) - self.handle_incoming(data) - except TimeoutError: + self._handle_incoming(data) + except (TimeoutError, socket.timeout): pass - def handle_incoming(self, buf: bytes): - self.logger.debug('handle_incoming: '+buf.hex()) - message = Message.from_encrypted(buf, inkey=self.encinkey, outkey=self.encoutkey) - if message.frame.head.seq in self.waiting_for_response: - self.logger.info(f'received awaited message: {message}') + self._logger.info('bye...') + + def _get_next_message(self) -> Optional[WrappedMessage]: + message = None + lpfx = '_get_next_message:' + remove_list = [] + for wm in self.outgoing_queue: + if wm.phase == MessagePhase.DONE: + if isinstance(wm.message, (AckMessage, NakMessage)) or time.time() - wm.phase_update_time >= MESSAGE_QUEUE_REMOVE_DELAY: + remove_list.append(wm) + continue + message = wm + break + + for wm in remove_list: + self._logger.debug(f'{lpfx} rm path: removing id={wm.id} seq={wm.seq}') + + # clear message handler + if wm.seq in self.response_handlers: + self.response_handlers[wm.seq].call( + False, error_message=f'{lpfx} rm path: error while calling callback for seq={wm.seq}') + del self.response_handlers[wm.seq] + + # remove from queue try: - f = self.waiting_for_response[message.frame.head.seq] - f(message) - except Exception as exc: - self.logger.exception(exc) - finally: - del self.waiting_for_response[message.frame.head.seq] + self.outgoing_queue.remove(wm) + except ValueError as exc: + self._logger.error(f'{lpfx} rm path: removing from outgoing_queue raised an exception: {str(exc)}') + + # ping pong + if self.outgoing_time_1st != 0 and self.status == ConnectionStatus.CONNECTED: + now = time.time() + out_delta = now - self.outgoing_time + in_delta = now - self.incoming_time + if not message and max(out_delta, in_delta) > PING_FREQUENCY: + self._logger.debug(f'{lpfx} no activity: in for {in_delta:.2f}s, out for {out_delta:.2f}s, time to ping the damn thing') + message = WrappedMessage(PingMessage(), ack=True) + + return message + + def _handle_incoming(self, buf: bytes): + self.incoming_time = time.time() + + incoming_message = Message.from_encrypted(buf, inkey=self.encinkey, outkey=self.encoutkey) + seq = incoming_message.seq + + lpfx = f'handle_incoming({incoming_message.id}):' + self._logger.debug(f'{lpfx} received: {incoming_message}') + + if isinstance(incoming_message, (AckMessage, NakMessage)): + seq_max = self.outseq + seq_name = 'outseq' + else: + seq_max = self.inseq + seq_name = 'inseq' + self.inseq = seq + + if seq < seq_max < 0xfd: + self._logger.warning(f'{lpfx} dropping: seq={seq}, {seq_name}={seq_name}') + return + + if seq not in self.response_handlers: + self._handle_incoming_cmd(incoming_message) + return + + callback_value = None # None means don't call a callback + handler = self.response_handlers[seq] + + if handler.validate(incoming_message): + self._logger.info(f'{lpfx} response OK') + handler.phase = MessagePhase.DONE + callback_value = incoming_message + self._incr_outseq() else: - self.logger.info(f'received message (not awaited): {message}') - - def send_message(self, message: Message, callback: callable): - seq = self.next_seqno() - message.set_seq(seq) - self.outgoing_queue.put(message) - self.waiting_for_response[seq] = callback - - def next_seqno(self) -> int: - with self.lock: - self.seq_no += 1 - self.logger.debug(f'next_seqno: set to {self.seq_no}') - return self.seq_no + self._logger.info(f'{lpfx} response is INVALID') + + # It seems that we've received an incoming CmdMessage or PingMessage with the same seqno that our outgoing + # message had. Bad, but what can I say, this is quick-and-dirty made UDP based protocol and this sort of + # shit just happens. + + # (To be fair, maybe my implementation is not perfect either. But hey, what did you expect from a + # reverse-engineered re-implementation of custom UDP-based protocol that some noname vendor uses for their + # cheap IoT devices? I think _that_ is _the_ definition of shit. At least my implementation is FOSS, which + # is more than you'll ever be able to say about them.) + + # All this crapload of code below might not be needed at all, 'cause the protocol uses separate frame seq + # numbers for IN and OUT frames and this situation is not highly likely, as Theresa May could argue. + # After a handshake, a kettle sends us 10 or so CmdMessages, and then either we continuously ping it every + # 3 seconds, or kettle pings us. This in any case widens the gap between inseq and outseq. + + # But! the seqno is only 1 byte in size and once it reaches 0xff, it circles back to zero. And that (plus, + # perhaps, some bad luck) gives a chance for a collision. + + if handler.phase == MessagePhase.DONE or isinstance(handler.message, HandshakeMessage): + # no more attempts left, returning error back to user + # as to handshake, it cannot fail. + callback_value = False + + # else: + # # try resending the message + # handler.phase_reset() + # max_seq = self.outseq + # wait_remap = {} + # for m in self.outgoing_queue: + # if m.seq in self.waiting_for_response: + # wait_remap[m.seq] = (m.seq+1) % 256 + # m.set_seq((m.seq+1) % 256) + # if m.seq > max_seq: + # max_seq = m.seq + # if max_seq > self.outseq: + # self.outseq = max_seq % 256 + # if wait_remap: + # waiting_new = {} + # for old_seq, new_seq in wait_remap.items(): + # waiting_new[new_seq] = self.waiting_for_response[old_seq] + # self.waiting_for_response = waiting_new + + if isinstance(incoming_message, (PingMessage, CmdIncomingMessage)): + # handle incoming message as usual, as we need to ack/nak it anyway + self._handle_incoming_cmd(incoming_message) + + if callback_value is not None: + handler.call(callback_value, + error_message=f'{lpfx} error while calling callback for msg id={handler.message.id} seq={seq}') + del self.response_handlers[seq] + + def _handle_incoming_cmd(self, incoming_message: Message): + if isinstance(incoming_message, (AckMessage, NakMessage)): + self._logger.debug(f'_handle_incoming_cmd({incoming_message.id}, seq={incoming_message.seq}): it\'s {incoming_message.__class__.__name__}, ignoring') + return + + replied = False + with self._iml_lock: + for f in self.inc_listeners: + retval = safe_callback_call(f.incoming_message, incoming_message, + logger=self._logger, + error_message=f'_handle_incoming_cmd({incoming_message.id}, seq={incoming_message.seq}): error while calling message listener') + if isinstance(retval, Message): + if isinstance(retval, (AckMessage, NakMessage)): + retval.seq = incoming_message.seq + self.enqueue_message(WrappedMessage(retval), prepend=True) + replied = True + break + else: + raise RuntimeError('are you sure your response is correct? only ack/nak are allowed') + + if not replied: + self.enqueue_message(WrappedMessage(AckMessage(incoming_message.seq)), prepend=True) + + def enqueue_message(self, wrapped: WrappedMessage, prepend=False): + self._logger.debug(f'enqueue_message: {wrapped.message}') + if not prepend: + self.outgoing_queue.append(wrapped) + else: + self.outgoing_queue.insert(0, wrapped) + + def _set_response_handler(self, wm: WrappedMessage, old_seq=None): + if old_seq in self.response_handlers: + del self.response_handlers[old_seq] + + seq = wm.seq + assert seq is not None, 'seq is not set' + + if seq in self.response_handlers: + self._logger.warning(f'_set_response_handler(seq={seq}): handler is already set, cancelling it') + self.response_handlers[seq].call(False, + error_message=f'_set_response_handler({seq}): error while calling old callback') + self.response_handlers[seq] = wm + + def _incr_outseq(self) -> None: + self.outseq = (self.outseq + 1) % 256 + + def _reset_outseq(self): + self.outseq = 0 + self._logger.debug(f'_reset_outseq: set 0') + + +MessageResponse = Union[Message, bool] diff --git a/src/polaris_kettle_bot.py b/src/polaris_kettle_bot.py new file mode 100644 index 0000000..bad1b3a --- /dev/null +++ b/src/polaris_kettle_bot.py @@ -0,0 +1,684 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import logging +import locale +import queue +import time +import threading +import paho.mqtt.client as mqtt + +from home.bot import Wrapper, Context, text_filter, handlermethod +from home.api.types import BotType +from home.mqtt import MQTTBase +from home.config import config +from polaris import ( + Kettle, + PowerType, + DeviceListener, + IncomingMessageListener, + ConnectionStatusListener, + ConnectionStatus +) +import polaris.protocol as kettle_proto +from typing import Optional, Tuple, List +from collections import namedtuple +from functools import partial +from datetime import datetime +from abc import abstractmethod +from telegram.error import TelegramError +from telegram import ( + ReplyKeyboardMarkup, + InlineKeyboardMarkup, + InlineKeyboardButton, + Message +) +from telegram.ext import ( + CallbackQueryHandler, + MessageHandler, + CommandHandler +) + +logger = logging.getLogger(__name__) +kc: Optional[KettleController] = None +bot: Optional[Wrapper] = None +RenderedContent = Tuple[str, Optional[InlineKeyboardMarkup]] +tasks_lock = threading.Lock() + + +class KettleInfoListener: + @abstractmethod + def info_updated(self, field: str): + pass + + +# class that holds data coming from the kettle over mqtt +class KettleInfo: + update_time: int + _mode: Optional[PowerType] + _temperature: Optional[int] + _target_temperature: Optional[int] + _update_listener: KettleInfoListener + + def __init__(self, update_listener: KettleInfoListener): + self.update_time = 0 + self._mode = None + self._temperature = None + self._target_temperature = None + self._update_listener = update_listener + + def _update(self, field: str): + self.update_time = int(time.time()) + if self._update_listener: + self._update_listener.info_updated(field) + + @property + def temperature(self) -> int: + return self._temperature + + @temperature.setter + def temperature(self, value: int): + self._temperature = value + self._update('temperature') + + @property + def mode(self) -> PowerType: + return self._mode + + @mode.setter + def mode(self, value: PowerType): + self._mode = value + self._update('mode') + + @property + def target_temperature(self) -> int: + return self._target_temperature + + @target_temperature.setter + def target_temperature(self, value: int): + self._target_temperature = value + self._update('target_temperature') + + +class KettleController(threading.Thread, + MQTTBase, + DeviceListener, + IncomingMessageListener, + KettleInfoListener, + ConnectionStatusListener): + kettle: Kettle + info: KettleInfo + + _logger: logging.Logger + _stopped: bool + _restart_server_at: int + _lock: threading.Lock + _info_lock: threading.Lock + _accumulated_updates: dict + _info_flushed_time: float + _mqtt_root_topic: str + _muts: List[MessageUpdatingTarget] + + def __init__(self): + # basic setup + MQTTBase.__init__(self, clean_session=False) + threading.Thread.__init__(self) + + self._logger = logging.getLogger(self.__class__.__name__) + + self.kettle = Kettle(mac=config['kettle']['mac'], + device_token=config['kettle']['token']) + self.kettle_reconnect() + + # info + self.info = KettleInfo(update_listener=self) + self._accumulated_updates = {} + self._info_flushed_time = 0 + + # mqtt + self._mqtt_root_topic = '/polaris/6/'+config['kettle']['token']+'/#' + self.connect_and_loop(loop_forever=False) + + # thread loop related + self._stopped = False + # self._lock = threading.Lock() + self._info_lock = threading.Lock() + self._restart_server_at = 0 + + # bot + self._muts = [] + self._muts_lock = threading.Lock() + + self.start() + + def kettle_reconnect(self): + self.kettle.discover(wait=False, listener=self) + + def stop_all(self): + self.kettle.stop_all() + self._stopped = True + + def add_updating_message(self, mut: MessageUpdatingTarget): + with self._muts_lock: + for m in self._muts: + if m.user_id == m.user_id and m.user_did_turn_on() or m.user_did_turn_on() != mut.user_did_turn_on(): + m.delete() + self._muts.append(mut) + + # --------------------- + # threading.Thread impl + + def run(self): + while not self._stopped: + # do_restart_srv = False + # + # with self._lock: + # if self._restart_server_at != 0 and time.time() - self._restart_server_at: + # self._restart_server_at = 0 + # do_restart_srv = True + # + # if do_restart_srv: + # self.kettle_connect() + + updates = [] + deletions = [] + with self._muts_lock and self._info_lock: + # self._logger.debug('muts size: '+str(len(self._muts))) + if self._muts and self._accumulated_updates and (self._info_flushed_time == 0 or time.time() - self._info_flushed_time >= 1): + forget = [] + deletions = [] + + for mut in self._muts: + upd = mut.update( + mode=self.info.mode, + current_temp=self.info.temperature, + target_temp=self.info.target_temperature) + + if upd.finished or upd.delete: + forget.append(mut) + + if upd.delete: + deletions.append(upd) + elif upd.changed: + updates.append(upd) + + if forget: + for mut in forget: + self._logger.debug(f'loop: removing mut {mut}') + self._muts.remove(mut) + + self._info_flushed_time = time.time() + self._accumulated_updates = {} + + for upd in updates: + self._logger.debug(f'loop: got update: {upd}') + try: + bot.edit_message_text(upd.user_id, upd.message_id, + text=upd.html, + reply_markup=upd.markup) + except TelegramError as exc: + self._logger.error(f'loop: edit_message_text failed for update: {upd}') + self._logger.exception(exc) + + for upd in deletions: + self._logger.debug(f'loop: got deletion: {upd}') + try: + bot.delete_message(upd.user_id, upd.message_id) + except TelegramError as exc: + self._logger.error(f'loop: delete_message failed for update: {upd}') + self._logger.exception(exc) + + time.sleep(0.5) + + # ------------------- + # DeviceListener impl + + def device_updated(self): + self._logger.info(f'device updated: {self.kettle.device.si}') + self.kettle.start_server_if_needed(incoming_message_listener=self, + connection_status_listener=self) + + # ----------------------- + # KettleInfoListener impl + + def info_updated(self, field: str): + with self._info_lock: + newval = getattr(self.info, field) + self._logger.debug(f'info_updated: updated {field}, new value is {newval}') + self._accumulated_updates[field] = newval + + # ---------------------------- + # IncomingMessageListener impl + + def incoming_message(self, message: kettle_proto.Message) -> Optional[kettle_proto.Message]: + self._logger.info(f'incoming message: {message}') + + if isinstance(message, kettle_proto.ModeMessage): + self.info.mode = message.pt + elif isinstance(message, kettle_proto.CurrentTemperatureMessage): + self.info.temperature = message.current_temperature + elif isinstance(message, kettle_proto.TargetTemperatureMessage): + self.info.target_temperature = message.temperature + + return kettle_proto.AckMessage() + + # ----------------------------- + # ConnectionStatusListener impl + + def connection_status_updated(self, status: ConnectionStatus): + self._logger.info(f'connection status updated: {status}') + if status == ConnectionStatus.DISCONNECTED: + self.kettle.stop_all() + self.kettle_reconnect() + + # ------------- + # MQTTBase impl + + def on_connect(self, client: mqtt.Client, userdata, flags, rc): + super().on_connect(client, userdata, flags, rc) + client.subscribe(self._mqtt_root_topic, qos=1) + self._logger.info(f'subscribed to {self._mqtt_root_topic}') + + def on_message(self, client: mqtt.Client, userdata, msg): + try: + topic = msg.topic[len(self._mqtt_root_topic)-2:] + pld = msg.payload.decode() + + self._logger.debug(f'mqtt: on message: topic={topic} pld={pld}') + + if topic == 'state/sensor/temperature': + self.info.temperature = int(float(pld)) + elif topic == 'state/mode': + self.info.mode = PowerType(int(pld)) + elif topic == 'state/temperature': + self.info.target_temperature = int(float(pld)) + + except Exception as e: + self._logger.exception(str(e)) + + +class Renderer: + @classmethod + def index(cls, ctx: Context) -> RenderedContent: + html = f'<b>{ctx.lang("settings")}</b>\n\n' + html += ctx.lang('select_place') + return html, None + + @classmethod + def status(cls, ctx: Context, + connected: bool, + mode: PowerType, + current_temp: int, + target_temp: int, + update_time: int) -> RenderedContent: + if not connected: + return cls.not_connected(ctx) + else: + # power status + if mode != PowerType.OFF: + html = ctx.lang('status_on', target_temp) + else: + html = ctx.lang('status_off') + + # current temperature + html += '\n' + html += ctx.lang('status_current_temp', current_temp) + + # updated on + html += '\n' + html += cls.updated(ctx, update_time) + + return html, None + + @classmethod + def turned_on(cls, ctx: Context, + target_temp: int, + current_temp: int, + mode: PowerType, + update_time: Optional[int] = None, + reached=False, + no_keyboard=False) -> RenderedContent: + if mode == PowerType.OFF and not reached: + html = ctx.lang('enabling') + else: + if not reached: + emoji = '♨️' if current_temp <= 90 else '🔥' + html = ctx.lang('enabled', emoji, target_temp) + + # current temperature + html += '\n' + html += ctx.lang('status_current_temp', current_temp) + else: + html = ctx.lang('enabled_reached', current_temp) + + # updated on + if not reached and update_time is not None: + html += '\n' + html += cls.updated(ctx, update_time) + + return html, None if no_keyboard else cls.wait_buttons(ctx) + + @classmethod + def turned_off(cls, ctx: Context, + mode: PowerType, + update_time: Optional[int] = None, + reached=False, + no_keyboard=False) -> RenderedContent: + if mode != PowerType.OFF: + html = ctx.lang('disabling') + else: + html = ctx.lang('disabled') + + # updated on + if not reached and update_time is not None: + html += '\n' + html += cls.updated(ctx, update_time) + + return html, None if no_keyboard else cls.wait_buttons(ctx) + + @classmethod + def not_connected(cls, ctx: Context) -> RenderedContent: + return ctx.lang('status_not_connected'), None + + @classmethod + def smth_went_wrong(cls, ctx: Context) -> RenderedContent: + html = ctx.lang('smth_went_wrong') + return html, None + + @classmethod + def updated(cls, ctx: Context, update_time: int): + locale_bak = locale.getlocale(locale.LC_TIME) + locale.setlocale(locale.LC_TIME, 'ru_RU.UTF-8' if ctx.user_lang == 'ru' else 'en_US.UTF-8') + dt = datetime.fromtimestamp(update_time) + html = ctx.lang('status_update_time', dt.strftime(ctx.lang('status_update_time_fmt'))) + locale.setlocale(locale.LC_TIME, locale_bak) + return html + + @classmethod + def wait_buttons(cls, ctx: Context): + return InlineKeyboardMarkup([ + [ + InlineKeyboardButton(ctx.lang('please_wait'), callback_data='wait') + ] + ]) + + +def run_tasks(tasks: queue.SimpleQueue, done: callable): + def next_task(r: Optional[kettle_proto.MessageResponse]): + if r is not None: + try: + assert r is not False, 'server error' + except AssertionError as exc: + logger.exception(exc) + tasks_lock.release() + return done(False) + + if not tasks.empty(): + task = tasks.get() + args = task[1:] + args.append(next_task) + f = getattr(kc.kettle, task[0]) + f(*args) + else: + tasks_lock.release() + return done(True) + + tasks_lock.acquire() + next_task(None) + + +MUTUpdate = namedtuple('MUTUpdate', 'message_id, user_id, finished, changed, delete, html, markup') + + +class MessageUpdatingTarget: + ctx: Context + message: Message + user_target_temp: Optional[int] + user_enabled_power_mode: PowerType + initial_power_mode: PowerType + need_to_delete: bool + rendered_content: Optional[RenderedContent] + + def __init__(self, + ctx: Context, + message: Message, + user_enabled_power_mode: PowerType, + initial_power_mode: PowerType, + user_target_temp: Optional[int] = None): + self.ctx = ctx + self.message = message + self.initial_power_mode = initial_power_mode + self.user_enabled_power_mode = user_enabled_power_mode + self.ignore_pm = initial_power_mode is PowerType.OFF and self.user_did_turn_on() + self.user_target_temp = user_target_temp + self.need_to_delete = False + self.rendered_content = None + self.last_reported_temp = None + + def set_rendered_content(self, content: RenderedContent): + self.rendered_content = content + + def rendered_content_changed(self, content: RenderedContent) -> bool: + return content != self.rendered_content + + def update(self, + mode: PowerType, + current_temp: int, + target_temp: int) -> MUTUpdate: + + # determine whether status updating is finished + finished = False + reached = False + if self.ignore_pm: + if mode != PowerType.OFF: + self.ignore_pm = False + elif mode == PowerType.OFF: + reached = True + if self.user_did_turn_on(): + # when target is 100 degrees, this kettle sometimes turns off at 91, sometimes at 95, sometimes at 98. + # it's totally unpredictable, so in this case, we keep updating the message until it reaches at least 97 + # degrees, or if temperature started dropping. + if self.user_target_temp < 100 \ + or current_temp >= self.user_target_temp - 3 \ + or current_temp < self.last_reported_temp: + finished = True + else: + finished = True + + self.last_reported_temp = current_temp + + # render message + if self.user_did_turn_on(): + rc = Renderer.turned_on(self.ctx, + target_temp=target_temp, + current_temp=current_temp, + mode=mode, + reached=reached, + no_keyboard=finished) + else: + rc = Renderer.turned_off(self.ctx, + mode=mode, + reached=reached, + no_keyboard=finished) + + changed = self.rendered_content_changed(rc) + update = MUTUpdate(message_id=self.message.message_id, + user_id=self.ctx.user_id, + finished=finished, + changed=changed, + delete=self.need_to_delete, + html=rc[0], + markup=rc[1]) + if changed: + self.set_rendered_content(rc) + return update + + def user_did_turn_on(self) -> bool: + return self.user_enabled_power_mode in (PowerType.ON, PowerType.CUSTOM) + + def delete(self): + self.need_to_delete = True + + @property + def user_id(self) -> int: + return self.ctx.user_id + + +class KettleBot(Wrapper): + def __init__(self): + super().__init__() + + self.lang.ru( + start_message="Выберите команду на клавиатуре", + unknown_command="Неизвестная команда", + unexpected_callback_data="Ошибка: неверные данные", + enable_70="♨️ 70 °C", + enable_80="♨️ 80 °C", + enable_90="♨️ 90 °C", + enable_100="🔥 100 °C", + disable="❌ Выключить", + server_error="Ошибка сервера", + + # /status + status_not_connected="😟 Связь с чайником не установлена", + status_on="✅ Чайник <b>включён</b> (до <b>%d °C</b>)", + status_off="❌ Чайник <b>выключен</b>", + status_current_temp="Сейчас: <b>%d °C</b>", + status_update_time="<i>Обновлено %s</i>", + status_update_time_fmt="%d %b в %H:%M:%S", + + # enable + enabling="💤 Чайник включается...", + disabling="💤 Чайник выключается...", + enabled="%s Чайник <b>включён</b>.\nЦель: <b>%d °C</b>", + enabled_reached="✅ <b>Готово!</b> Чайник вскипел, температура <b>%d °C</b>.", + disabled="✅ Чайник <b>выключен</b>.", + please_wait="⏳ Ожидайте..." + ) + + self.lang.en( + start_message="Select command on the keyboard", + unknown_command="Unknown command", + unexpected_callback_data="Unexpected callback data", + enable_70="♨️ 70 °C", + enable_80="♨️ 80 °C", + enable_90="♨️ 90 °C", + enable_100="🔥 100 °C", + disable="❌ Turn OFF", + server_error="Server error", + + # /status + not_connected="😟 Connection has not been established", + status_on="✅ Turned <b>ON</b>! Target: <b>%d °C</b>", + status_off="❌ Turned <b>OFF</b>", + status_current_temp="Now: <b>%d °C</b>", + status_update_time="<i>Updated on %s</i>", + status_update_time_fmt="%b %d, %Y at %H:%M:%S", + + # enable + enabling="💤 Turning on...", + disabling="💤 Turning off...", + enabled="%s The kettle is <b>turned ON</b>.\nTarget: <b>%d °C</b>", + enabled_reached="✅ It's <b>done</b>! The kettle has boiled, the temperature is <b>%d °C</b>.", + disabled="✅ The kettle is <b>turned OFF</b>.", + please_wait="⏳ Please wait..." + ) + + # commands + self.add_handler(CommandHandler('status', self.status)) + + # messages + for temp in (70, 80, 90, 100): + self.add_handler(MessageHandler(text_filter(self.lang.all(f'enable_{temp}')), self.wrap(partial(self.on, temp)))) + + self.add_handler(MessageHandler(text_filter(self.lang.all('disable')), self.off)) + + def markup(self, ctx: Optional[Context]) -> Optional[ReplyKeyboardMarkup]: + buttons = [ + [ctx.lang(f'enable_{x}') for x in (70, 80, 90, 100)], + [ctx.lang('disable')] + ] + return ReplyKeyboardMarkup(buttons, one_time_keyboard=False) + + def on(self, temp: int, ctx: Context) -> None: + if not kc.kettle.is_connected(): + text, markup = Renderer.not_connected(ctx) + ctx.reply(text, markup=markup) + return + + tasks = queue.SimpleQueue() + if temp == 100: + power_mode = PowerType.ON + else: + power_mode = PowerType.CUSTOM + tasks.put(['set_target_temperature', temp]) + tasks.put(['set_power', power_mode]) + + def done(ok: bool): + if not ok: + html, markup = Renderer.smth_went_wrong(ctx) + else: + html, markup = Renderer.turned_on(ctx, + target_temp=temp, + current_temp=kc.info.temperature, + mode=kc.info.mode) + message = ctx.reply(html, markup=markup) + logger.info(f'ctx.reply returned message: {message}') + + mut = MessageUpdatingTarget(ctx, message, + initial_power_mode=kc.info.mode, + user_enabled_power_mode=power_mode, + user_target_temp=temp) + mut.set_rendered_content((html, markup)) + kc.add_updating_message(mut) + + run_tasks(tasks, done) + + @handlermethod + def off(self, ctx: Context) -> None: + if not kc.kettle.is_connected(): + text, markup = Renderer.not_connected(ctx) + ctx.reply(text, markup=markup) + return + + def done(ok: bool): + if not ok: + html, markup = Renderer.smth_went_wrong(ctx) + else: + html, markup = Renderer.turned_off(ctx, mode=kc.info.mode) + message = ctx.reply(html, markup=markup) + logger.info(f'ctx.reply returned message: {message}') + + mut = MessageUpdatingTarget(ctx, message, + initial_power_mode=kc.info.mode, + user_enabled_power_mode=PowerType.OFF) + mut.set_rendered_content((html, markup)) + kc.add_updating_message(mut) + + tasks = queue.SimpleQueue() + tasks.put(['set_power', PowerType.OFF]) + run_tasks(tasks, done) + + @handlermethod + def status(self, ctx: Context): + text, markup = Renderer.status(ctx, + connected=kc.kettle.is_connected(), + mode=kc.info.mode, + current_temp=kc.info.temperature, + target_temp=kc.info.target_temperature, + update_time=kc.info.update_time) + return ctx.reply(text, markup=markup) + + +if __name__ == '__main__': + config.load('polaris_kettle_bot') + + kc = KettleController() + + bot = KettleBot() + if 'api' in config: + bot.enable_logging(BotType.POLARIS_KETTLE) + bot.run() + + # bot library handles signals, so when sigterm or something like that happens, we should stop all other threads here + kc.stop_all() diff --git a/src/polaris_kettle_util.py b/src/polaris_kettle_util.py index 419739b..82b2588 100755 --- a/src/polaris_kettle_util.py +++ b/src/polaris_kettle_util.py @@ -3,35 +3,23 @@ import logging import sys -import time import paho.mqtt.client as mqtt -# from datetime import datetime -# from html import escape +from typing import Optional from argparse import ArgumentParser from queue import SimpleQueue -# from home.bot import Wrapper, Context -# from home.api.types import BotType -# from home.util import parse_addr from home.mqtt import MQTTBase from home.config import config -from polaris import Kettle, Message, FrameType, PowerType - -# from telegram.error import TelegramError -# from telegram import ReplyKeyboardMarkup, InlineKeyboardMarkup, InlineKeyboardButton -# from telegram.ext import ( -# CallbackQueryHandler, -# MessageHandler, -# CommandHandler -# ) - +from polaris import ( + Kettle, + PowerType, + protocol as kettle_proto +) +k: Optional[Kettle] = None logger = logging.getLogger(__name__) control_tasks = SimpleQueue() -# bot: Optional[Wrapper] = None -# RenderedContent = tuple[str, Optional[InlineKeyboardMarkup]] - class MQTTServer(MQTTBase): def __init__(self): @@ -50,65 +38,29 @@ class MQTTServer(MQTTBase): logger.exception(str(e)) -# class Renderer: -# @classmethod -# def index(cls, ctx: Context) -> RenderedContent: -# html = f'<b>{ctx.lang("settings")}</b>\n\n' -# html += ctx.lang('select_place') -# return html, None - - -# status handler -# -------------- - -# def status(ctx: Context): -# text, markup = Renderer.index(ctx) -# return ctx.reply(text, markup=markup) - - -# class SoundBot(Wrapper): -# def __init__(self): -# super().__init__() -# -# self.lang.ru( -# start_message="Выберите команду на клавиатуре", -# unknown_command="Неизвестная команда", -# unexpected_callback_data="Ошибка: неверные данные", -# status="Статус", -# ) -# -# self.lang.en( -# start_message="Select command on the keyboard", -# unknown_command="Unknown command", -# unexpected_callback_data="Unexpected callback data", -# status="Status", -# ) -# -# self.add_handler(CommandHandler('status', self.wrap(status))) -# -# def markup(self, ctx: Optional[Context]) -> Optional[ReplyKeyboardMarkup]: -# buttons = [ -# [ctx.lang('status')] -# ] -# return ReplyKeyboardMarkup(buttons, one_time_keyboard=False) - -def kettle_connection_established(k: Kettle, response: Message): +def kettle_connection_established(response: kettle_proto.MessageResponse): try: - assert response.frame.head.type == FrameType.ACK, f'ACK expected, but received: {response}' + assert isinstance(response, kettle_proto.AckMessage), f'ACK expected, but received: {response}' except AssertionError: - k.stop_server() + k.stop_all() return - def next_task(k, response): + def next_task(response: kettle_proto.MessageResponse): + try: + assert response is not False, 'server error' + except AssertionError: + k.stop_all() + return + if not control_tasks.empty(): task = control_tasks.get() f, args = task(k) args.append(next_task) f(*args) else: - k.stop_server() + k.stop_all() - next_task(k, response) + next_task(response) def main(): @@ -123,7 +75,7 @@ def main(): parser.add_argument('-t', '--temperature', dest='temp', type=int, default=tempmax, choices=range(tempmin, tempmax+tempstep, tempstep)) - arg = config.load('polaris_kettle_bot', use_cli=True, parser=parser) + arg = config.load('polaris_kettle_util', use_cli=True, parser=parser) if arg.mode == 'mqtt': server = MQTTServer() @@ -145,19 +97,17 @@ def main(): control_tasks.put(lambda k: (k.set_target_temperature, [arg.temp])) control_tasks.put(lambda k: (k.set_power, [PowerType.CUSTOM])) - k = Kettle(mac='40f52018dec1', device_token='3a5865f015950cae82cd120e76a80d28') - info = k.find() - print('found service:', info) + k = Kettle(mac=config['kettle']['mac'], device_token=config['kettle']['token']) + info = k.discover() + if not info: + print('no device found.') + return 1 - k.start_server(kettle_connection_established) + print('found service:', info) + k.start_server_if_needed(kettle_connection_established) return 0 if __name__ == '__main__': sys.exit(main()) - - # bot = SoundBot() - # if 'api' in config: - # bot.enable_logging(BotType.POLARIS_KETTLE) - # bot.run() diff --git a/src/sound_bot.py b/src/sound_bot.py index b515ae7..91e51f0 100755 --- a/src/sound_bot.py +++ b/src/sound_bot.py @@ -6,7 +6,7 @@ import tempfile from enum import Enum from datetime import datetime, timedelta from html import escape -from typing import Optional +from typing import Optional, List, Dict, Tuple from home.config import config from home.bot import Wrapper, Context, text_filter, user_any_name @@ -27,11 +27,11 @@ from telegram.ext import ( from PIL import Image logger = logging.getLogger(__name__) -RenderedContent = tuple[str, Optional[InlineKeyboardMarkup]] +RenderedContent = Tuple[str, Optional[InlineKeyboardMarkup]] record_client: Optional[SoundRecordClient] = None bot: Optional[Wrapper] = None -node_client_links: dict[str, SoundNodeClient] = {} -cam_client_links: dict[str, CameraNodeClient] = {} +node_client_links: Dict[str, SoundNodeClient] = {} +cam_client_links: Dict[str, CameraNodeClient] = {} def node_client(node: str) -> SoundNodeClient: @@ -73,7 +73,7 @@ def interval_defined(interval: int) -> bool: return interval in config['bot']['record_intervals'] -def callback_unpack(ctx: Context) -> list[str]: +def callback_unpack(ctx: Context) -> List[str]: return ctx.callback_query.data[3:].split('/') @@ -115,7 +115,7 @@ class SettingsRenderer(Renderer): @classmethod def node(cls, ctx: Context, - controls: list[dict]) -> RenderedContent: + controls: List[dict]) -> RenderedContent: node, = callback_unpack(ctx) html = [] @@ -169,7 +169,7 @@ class RecordRenderer(Renderer): return html, cls.places_markup(ctx, callback_prefix='r0') @classmethod - def node(cls, ctx: Context, durations: list[int]) -> RenderedContent: + def node(cls, ctx: Context, durations: List[int]) -> RenderedContent: node, = callback_unpack(ctx) html = ctx.lang('select_interval') @@ -241,7 +241,7 @@ class FilesRenderer(Renderer): return html, cls.places_markup(ctx, callback_prefix='f0') @classmethod - def filelist(cls, ctx: Context, files: list[SoundRecordFile]) -> RenderedContent: + def filelist(cls, ctx: Context, files: List[SoundRecordFile]) -> RenderedContent: node, = callback_unpack(ctx) html_files = map(lambda file: cls.file(ctx, file, node), files) @@ -936,7 +936,6 @@ class SoundBot(Wrapper): # cheese self.add_handler(CallbackQueryHandler(self.wrap(camera_capture), pattern=r'^c1/.*')) - def markup(self, ctx: Optional[Context]) -> Optional[ReplyKeyboardMarkup]: buttons = [ [ctx.lang('record'), ctx.lang('settings')], diff --git a/src/sound_sensor_server.py b/src/sound_sensor_server.py index 0303d6d..9495678 100755 --- a/src/sound_sensor_server.py +++ b/src/sound_sensor_server.py @@ -3,7 +3,7 @@ import logging import threading from time import sleep -from typing import Optional +from typing import Optional, List, Dict, Tuple from functools import partial from home.config import config from home.util import parse_addr @@ -18,7 +18,7 @@ server: SoundSensorServer def get_related_nodes(node_type: MediaNodeType, - sensor_name: str) -> list[str]: + sensor_name: str) -> List[str]: if sensor_name not in config[f'sensor_to_{node_type.name.lower()}_nodes_relations']: raise ValueError(f'unexpected sensor name {sensor_name}') return config[f'sensor_to_{node_type.name.lower()}_nodes_relations'][sensor_name] @@ -52,7 +52,7 @@ class HitCounter: with self.lock: self.sensors[name] += hits - def get_all(self) -> list[tuple[str, int]]: + def get_all(self) -> List[Tuple[str, int]]: vals = [] with self.lock: for name, hits in self.sensors.items(): @@ -119,7 +119,7 @@ def hits_sender(): api: Optional[WebAPIClient] = None hc: Optional[HitCounter] = None -record_clients: dict[MediaNodeType, RecordClient] = {} +record_clients: Dict[MediaNodeType, RecordClient] = {} # record callbacks |