summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorEvgeny Zinoviev <me@ch1p.io>2022-06-28 03:22:30 +0300
committerEvgeny Zinoviev <me@ch1p.io>2022-06-30 03:47:49 +0300
commit8f20c9b825cabab7a3f0f5dd2cfe000cc7f72c28 (patch)
treeb5d7446e7b2fcfd42b1e5029aeef33ecb5f9715f /src
parentee09bc98aedfc6a65a5026432b399345a30a39c8 (diff)
polaris pwk 1725cgld full support
- significant improvements, correctnesses and stability fixes in protocol implementation - correct handling of device appearances and disappearances - flawlessly functioning telegram bot that re-renders kettle's state (temperature and other) in real time
Diffstat (limited to 'src')
-rw-r--r--src/home/api/errors/api_response_error.py4
-rw-r--r--src/home/api/types/types.py1
-rw-r--r--src/home/api/web_api_client.py20
-rw-r--r--src/home/audio/amixer.py4
-rw-r--r--src/home/bot/__init__.py2
-rw-r--r--src/home/bot/lang.py13
-rw-r--r--src/home/bot/wrapper.py34
-rw-r--r--src/home/camera/util.py3
-rw-r--r--src/home/database/bots.py10
-rw-r--r--src/home/media/node_client.py6
-rw-r--r--src/home/media/record.py6
-rw-r--r--src/home/media/record_client.py18
-rw-r--r--src/home/media/storage.py4
-rw-r--r--src/home/telegram/telegram.py3
-rw-r--r--src/home/util.py6
-rwxr-xr-xsrc/ipcam_server.py8
-rwxr-xr-xsrc/openwrt_logger.py5
-rw-r--r--src/polaris/__init__.py14
-rw-r--r--src/polaris/kettle.py271
-rw-r--r--src/polaris/protocol.py1015
-rw-r--r--src/polaris_kettle_bot.py684
-rwxr-xr-xsrc/polaris_kettle_util.py104
-rwxr-xr-xsrc/sound_bot.py17
-rwxr-xr-xsrc/sound_sensor_server.py8
24 files changed, 1907 insertions, 353 deletions
diff --git a/src/home/api/errors/api_response_error.py b/src/home/api/errors/api_response_error.py
index 6910b2d..85d788b 100644
--- a/src/home/api/errors/api_response_error.py
+++ b/src/home/api/errors/api_response_error.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, List
class ApiResponseError(Exception):
@@ -6,7 +6,7 @@ class ApiResponseError(Exception):
status_code: int,
error_type: str,
error_message: str,
- error_stacktrace: Optional[list[str]] = None):
+ error_stacktrace: Optional[List[str]] = None):
super().__init__()
self.status_code = status_code
self.error_message = error_message
diff --git a/src/home/api/types/types.py b/src/home/api/types/types.py
index b6233e6..4d8b4ff 100644
--- a/src/home/api/types/types.py
+++ b/src/home/api/types/types.py
@@ -7,6 +7,7 @@ class BotType(Enum):
SENSORS = auto()
ADMIN = auto()
SOUND = auto()
+ POLARIS_KETTLE = auto()
class TemperatureSensorLocation(Enum):
diff --git a/src/home/api/web_api_client.py b/src/home/api/web_api_client.py
index 34d080c..d6c9dc7 100644
--- a/src/home/api/web_api_client.py
+++ b/src/home/api/web_api_client.py
@@ -6,7 +6,7 @@ import logging
from collections import namedtuple
from datetime import datetime
from enum import Enum, auto
-from typing import Optional, Callable, Union
+from typing import Optional, Callable, Union, List, Tuple, Dict
from requests.auth import HTTPBasicAuth
from .errors import ApiResponseError
@@ -28,13 +28,13 @@ class HTTPMethod(Enum):
class WebAPIClient:
token: str
- timeout: Union[float, tuple[float, float]]
+ timeout: Union[float, Tuple[float, float]]
basic_auth: Optional[HTTPBasicAuth]
do_async: bool
async_error_handler: Optional[Callable]
async_success_handler: Optional[Callable]
- def __init__(self, timeout: Union[float, tuple[float, float]] = 5):
+ def __init__(self, timeout: Union[float, Tuple[float, float]] = 5):
self.token = config['api']['token']
self.timeout = timeout
self.basic_auth = None
@@ -66,7 +66,7 @@ class WebAPIClient:
})
def log_openwrt(self,
- lines: list[tuple[int, str]]):
+ lines: List[Tuple[int, str]]):
return self._post('logs/openwrt', {
'logs': stringify(lines)
})
@@ -81,14 +81,14 @@ class WebAPIClient:
return [(datetime.fromtimestamp(date), temp, hum) for date, temp, hum in data]
def add_sound_sensor_hits(self,
- hits: list[tuple[str, int]]):
+ hits: List[Tuple[str, int]]):
return self._post('sound_sensors/hits/', {
'hits': stringify(hits)
})
def get_sound_sensor_hits(self,
location: SoundSensorLocation,
- after: datetime) -> list[dict]:
+ after: datetime) -> List[dict]:
return self._process_sound_sensor_hits_data(self._get('sound_sensors/hits/', {
'after': int(after.timestamp()),
'location': location.value
@@ -100,13 +100,13 @@ class WebAPIClient:
'location': location.value
}))
- def recordings_list(self, extended=False, as_objects=False) -> Union[list[str], list[dict], list[RecordFile]]:
+ def recordings_list(self, extended=False, as_objects=False) -> Union[List[str], List[dict], List[RecordFile]]:
files = self._get('recordings/list/', {'extended': int(extended)})['data']
if as_objects:
return MediaNodeClient.record_list_from_serialized(files)
return files
- def _process_sound_sensor_hits_data(self, data: list[dict]) -> list[dict]:
+ def _process_sound_sensor_hits_data(self, data: List[dict]) -> List[dict]:
for item in data:
item['time'] = datetime.fromtimestamp(item['time'])
return data
@@ -124,7 +124,7 @@ class WebAPIClient:
name: str,
params: dict,
method: HTTPMethod,
- files: Optional[dict[str, str]] = None):
+ files: Optional[Dict[str, str]] = None):
if not self.do_async:
return self._make_request(name, params, method, files)
else:
@@ -136,7 +136,7 @@ class WebAPIClient:
name: str,
params: dict,
method: HTTPMethod = HTTPMethod.GET,
- files: Optional[dict[str, str]] = None) -> Optional[any]:
+ files: Optional[Dict[str, str]] = None) -> Optional[any]:
domain = config['api']['host']
kwargs = {}
diff --git a/src/home/audio/amixer.py b/src/home/audio/amixer.py
index 0ab2c64..53e6bce 100644
--- a/src/home/audio/amixer.py
+++ b/src/home/audio/amixer.py
@@ -2,7 +2,7 @@ import subprocess
from ..config import config
from threading import Lock
-from typing import Union
+from typing import Union, List
_lock = Lock()
@@ -16,7 +16,7 @@ def has_control(s: str) -> bool:
return False
-def get_caps(s: str) -> list[str]:
+def get_caps(s: str) -> List[str]:
for control in config['amixer']['controls']:
if control['name'] == s:
return control['caps']
diff --git a/src/home/bot/__init__.py b/src/home/bot/__init__.py
index 5e68af7..0d93af3 100644
--- a/src/home/bot/__init__.py
+++ b/src/home/bot/__init__.py
@@ -1,6 +1,6 @@
from .reporting import ReportingHelper
from .lang import LangPack
-from .wrapper import Wrapper, Context, text_filter
+from .wrapper import Wrapper, Context, text_filter, handlermethod
from .store import Store
from .errors import *
from .util import command_usage, user_any_name \ No newline at end of file
diff --git a/src/home/bot/lang.py b/src/home/bot/lang.py
index 2f10358..624c748 100644
--- a/src/home/bot/lang.py
+++ b/src/home/bot/lang.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
import logging
-from typing import Union, Optional
+from typing import Union, Optional, List, Dict
logger = logging.getLogger(__name__)
@@ -24,7 +26,7 @@ class LangStrings(dict):
class LangPack:
- strings: dict[str, LangStrings[str, str]]
+ strings: Dict[str, LangStrings[str, str]]
default_lang: str
def __init__(self):
@@ -57,11 +59,14 @@ class LangPack:
return result
@property
- def languages(self) -> list[str]:
+ def languages(self) -> List[str]:
return list(self.strings.keys())
def get(self, key: str, lang: str, *args) -> str:
- return self.strings[lang][key] % args
+ if args:
+ return self.strings[lang][key] % args
+ else:
+ return self.strings[lang][key]
def __call__(self, *args, **kwargs):
return self.strings[self.default_lang][args[0]]
diff --git a/src/home/bot/wrapper.py b/src/home/bot/wrapper.py
index 8ebde4f..5f399ce 100644
--- a/src/home/bot/wrapper.py
+++ b/src/home/bot/wrapper.py
@@ -8,6 +8,7 @@ from telegram import (
ReplyKeyboardMarkup,
CallbackQuery,
User,
+ Message,
)
from telegram.ext import (
Updater,
@@ -22,7 +23,7 @@ from telegram.ext import (
)
from telegram.error import TimedOut
from ..config import config
-from typing import Optional, Union
+from typing import Optional, Union, List, Tuple
from .store import Store
from .lang import LangPack
from ..api.types import BotType
@@ -110,7 +111,7 @@ class Context:
kwargs = dict(parse_mode=ParseMode.HTML)
if not isinstance(markup, IgnoreMarkup):
kwargs['reply_markup'] = markup
- self._update.message.reply_text(text, **kwargs)
+ return self._update.message.reply_text(text, **kwargs)
def reply_exc(self, e: Exception) -> None:
self.reply(exc2text(e))
@@ -133,7 +134,7 @@ class Context:
return self._update.callback_query
@property
- def args(self) -> Optional[list[str]]:
+ def args(self) -> Optional[List[str]]:
return self._callback_context.args
@property
@@ -157,6 +158,25 @@ class Context:
return self._update.callback_query and self._update.callback_query.data and self._update.callback_query.data != ''
+def handlermethod(f: callable):
+ def _handler(self, update: Update, context: CallbackContext, *args, **kwargs):
+ ctx = Context(update,
+ callback_context=context,
+ markup_getter=self.markup,
+ lang=self.lang,
+ store=self.store)
+ try:
+ return f(self, ctx, *args, **kwargs)
+ except Exception as e:
+ if not self.exception_handler(e, ctx) and not isinstance(e, TimedOut):
+ logger.exception(e)
+ if not ctx.is_callback_context():
+ ctx.reply_exc(e)
+ else:
+ self.notify_user(ctx.user_id, exc2text(e))
+ return _handler
+
+
class Wrapper:
store: Optional[Store]
updater: Updater
@@ -252,7 +272,7 @@ class Wrapper:
def exception_handler(self, e: Exception, ctx: Context) -> Optional[bool]:
pass
- def notify_all(self, text_getter: callable, exclude: tuple[int] = ()) -> None:
+ def notify_all(self, text_getter: callable, exclude: Tuple[int] = ()) -> None:
if 'notify_users' not in config['bot']:
logger.error('notify_all() called but no notify_users directive found in the config')
return
@@ -280,6 +300,12 @@ class Wrapper:
def send_file(self, user_id, **kwargs):
self.updater.bot.send_document(chat_id=user_id, **kwargs)
+ def edit_message_text(self, user_id, message_id, *args, **kwargs):
+ self.updater.bot.edit_message_text(chat_id=user_id, message_id=message_id, parse_mode='HTML', *args, **kwargs)
+
+ def delete_message(self, user_id, message_id):
+ self.updater.bot.delete_message(chat_id=user_id, message_id=message_id)
+
#
# Language Selection
#
diff --git a/src/home/camera/util.py b/src/home/camera/util.py
index 39bfcd3..97f35aa 100644
--- a/src/home/camera/util.py
+++ b/src/home/camera/util.py
@@ -3,6 +3,7 @@ import os.path
import logging
import psutil
+from typing import List, Tuple
from ..util import chunks
from ..config import config
@@ -62,7 +63,7 @@ async def ffmpeg_cut(input: str,
_logger.info(f'ffmpeg_cut({input}): OK')
-def dvr_scan_timecodes(timecodes: str) -> list[tuple[int, int]]:
+def dvr_scan_timecodes(timecodes: str) -> List[Tuple[int, int]]:
tc_backup = timecodes
timecodes = timecodes.split(',')
diff --git a/src/home/database/bots.py b/src/home/database/bots.py
index bc490e1..99befc0 100644
--- a/src/home/database/bots.py
+++ b/src/home/database/bots.py
@@ -5,7 +5,7 @@ from ..api.types import (
BotType,
SoundSensorLocation
)
-from typing import Optional
+from typing import Optional, List, Tuple
from datetime import datetime
from html import escape
@@ -37,7 +37,7 @@ class BotsDatabase(MySQLDatabase):
self.commit()
def add_openwrt_logs(self,
- lines: list[tuple[datetime, str]]):
+ lines: List[Tuple[datetime, str]]):
now = datetime.now()
with self.cursor() as cursor:
for line in lines:
@@ -47,7 +47,7 @@ class BotsDatabase(MySQLDatabase):
self.commit()
def add_sound_hits(self,
- hits: list[tuple[SoundSensorLocation, int]],
+ hits: List[Tuple[SoundSensorLocation, int]],
time: datetime):
with self.cursor() as cursor:
for loc, count in hits:
@@ -58,7 +58,7 @@ class BotsDatabase(MySQLDatabase):
def get_sound_hits(self,
location: SoundSensorLocation,
after: Optional[datetime] = None,
- last: Optional[int] = None) -> list[dict]:
+ last: Optional[int] = None) -> List[dict]:
with self.cursor(dictionary=True) as cursor:
sql = "SELECT `time`, hits FROM sound_hits WHERE location=%s"
args = [location.name.lower()]
@@ -84,7 +84,7 @@ class BotsDatabase(MySQLDatabase):
def get_openwrt_logs(self,
filter_text: str,
min_id: int,
- limit: int = None) -> list[OpenwrtLogRecord]:
+ limit: int = None) -> List[OpenwrtLogRecord]:
tz = pytz.timezone('Europe/Moscow')
with self.cursor(dictionary=True) as cursor:
sql = "SELECT * FROM openwrt WHERE text LIKE %s AND id > %s"
diff --git a/src/home/media/node_client.py b/src/home/media/node_client.py
index 4430962..eb39898 100644
--- a/src/home/media/node_client.py
+++ b/src/home/media/node_client.py
@@ -2,7 +2,7 @@ import requests
import shutil
import logging
-from typing import Optional, Union
+from typing import Optional, Union, List
from .storage import RecordFile
from ..util import Addr
from ..api.errors import ApiResponseError
@@ -25,7 +25,7 @@ class MediaNodeClient:
def record_download(self, record_id: int, output: str):
return self._call(f'record/download/{record_id}/', save_to=output)
- def storage_list(self, extended=False, as_objects=False) -> Union[list[str], list[dict], list[RecordFile]]:
+ def storage_list(self, extended=False, as_objects=False) -> Union[List[str], List[dict], List[RecordFile]]:
r = self._call('storage/list/', params={'extended': int(extended)})
files = r['files']
if as_objects:
@@ -33,7 +33,7 @@ class MediaNodeClient:
return files
@staticmethod
- def record_list_from_serialized(files: Union[list[str], list[dict]]):
+ def record_list_from_serialized(files: Union[List[str], List[dict]]):
new_files = []
for f in files:
kwargs = {'remote': True}
diff --git a/src/home/media/record.py b/src/home/media/record.py
index fdb8382..cd7447a 100644
--- a/src/home/media/record.py
+++ b/src/home/media/record.py
@@ -5,7 +5,7 @@ import time
import subprocess
import signal
-from typing import Optional
+from typing import Optional, List, Dict
from ..util import find_child_processes, Addr
from ..config import config
from .storage import RecordFile, RecordStorage
@@ -22,7 +22,7 @@ class RecordHistoryItem:
request_time: float
start_time: float
stop_time: float
- relations: list[int]
+ relations: List[int]
status: RecordStatus
error: Optional[Exception]
file: Optional[RecordFile]
@@ -76,7 +76,7 @@ class RecordingNotFoundError(Exception):
class RecordHistory:
- history: dict[int, RecordHistoryItem]
+ history: Dict[int, RecordHistoryItem]
def __init__(self):
self.history = {}
diff --git a/src/home/media/record_client.py b/src/home/media/record_client.py
index f264155..322495c 100644
--- a/src/home/media/record_client.py
+++ b/src/home/media/record_client.py
@@ -7,7 +7,7 @@ from tempfile import gettempdir
from .record import RecordStatus
from .node_client import SoundNodeClient, MediaNodeClient, CameraNodeClient
from ..util import Addr
-from typing import Optional, Callable
+from typing import Optional, Callable, Dict
class RecordClient:
@@ -15,14 +15,14 @@ class RecordClient:
interrupted: bool
logger: logging.Logger
- clients: dict[str, MediaNodeClient]
- awaiting: dict[str, dict[int, Optional[dict]]]
+ clients: Dict[str, MediaNodeClient]
+ awaiting: Dict[str, Dict[int, Optional[dict]]]
error_handler: Optional[Callable]
finished_handler: Optional[Callable]
download_on_finish: bool
def __init__(self,
- nodes: dict[str, Addr],
+ nodes: Dict[str, Addr],
error_handler: Optional[Callable] = None,
finished_handler: Optional[Callable] = None,
download_on_finish=False):
@@ -50,7 +50,7 @@ class RecordClient:
self.stop()
self.logger.exception(exc)
- def make_clients(self, nodes: dict[str, Addr]):
+ def make_clients(self, nodes: Dict[str, Addr]):
pass
def stop(self):
@@ -148,9 +148,9 @@ class RecordClient:
class SoundRecordClient(RecordClient):
DOWNLOAD_EXTENSION = 'mp3'
- # clients: dict[str, SoundNodeClient]
+ # clients: Dict[str, SoundNodeClient]
- def make_clients(self, nodes: dict[str, Addr]):
+ def make_clients(self, nodes: Dict[str, Addr]):
for node, addr in nodes.items():
self.clients[node] = SoundNodeClient(addr)
self.awaiting[node] = {}
@@ -158,9 +158,9 @@ class SoundRecordClient(RecordClient):
class CameraRecordClient(RecordClient):
DOWNLOAD_EXTENSION = 'mp4'
- # clients: dict[str, CameraNodeClient]
+ # clients: Dict[str, CameraNodeClient]
- def make_clients(self, nodes: dict[str, Addr]):
+ def make_clients(self, nodes: Dict[str, Addr]):
for node, addr in nodes.items():
self.clients[node] = CameraNodeClient(addr)
self.awaiting[node] = {} \ No newline at end of file
diff --git a/src/home/media/storage.py b/src/home/media/storage.py
index 08ba06a..dd74ff8 100644
--- a/src/home/media/storage.py
+++ b/src/home/media/storage.py
@@ -3,7 +3,7 @@ import re
import shutil
import logging
-from typing import Optional, Union
+from typing import Optional, Union, List
from datetime import datetime
from ..util import strgen
@@ -149,7 +149,7 @@ class RecordStorage:
self.root = root
- def getfiles(self, as_objects=False) -> Union[list[str], list[RecordFile]]:
+ def getfiles(self, as_objects=False) -> Union[List[str], List[RecordFile]]:
files = []
for name in os.listdir(self.root):
path = os.path.join(self.root, name)
diff --git a/src/home/telegram/telegram.py b/src/home/telegram/telegram.py
index 9c7ea73..2f94f93 100644
--- a/src/home/telegram/telegram.py
+++ b/src/home/telegram/telegram.py
@@ -1,6 +1,7 @@
import requests
import logging
+from typing import Tuple
from ..config import config
@@ -29,7 +30,7 @@ def send_photo(filename: str):
def _send_telegram_data(text: str,
parse_mode: str = None,
- disable_web_page_preview: bool = False) -> tuple[dict, str]:
+ disable_web_page_preview: bool = False) -> Tuple[dict, str]:
data = {
'chat_id': config['telegram']['chat_id'],
'text': text
diff --git a/src/home/util.py b/src/home/util.py
index 9dd84f6..5050ebb 100644
--- a/src/home/util.py
+++ b/src/home/util.py
@@ -9,7 +9,7 @@ import random
from enum import Enum
from datetime import datetime
-from typing import Tuple, Optional
+from typing import Tuple, Optional, List
Addr = Tuple[str, int] # network address type (host, port)
@@ -96,7 +96,7 @@ def send_datagram(message: str, addr: Addr) -> None:
sock.sendto(message.encode(), addr)
-def format_tb(exc) -> Optional[list[str]]:
+def format_tb(exc) -> Optional[List[str]]:
tb = traceback.format_tb(exc.__traceback__)
if not tb:
return None
@@ -120,7 +120,7 @@ class ChildProcessInfo:
self.cmd = cmd
-def find_child_processes(ppid: int) -> list[ChildProcessInfo]:
+def find_child_processes(ppid: int) -> List[ChildProcessInfo]:
p = subprocess.run(['pgrep', '-P', str(ppid), '--list-full'], capture_output=True)
if p.returncode != 0:
raise OSError(f'pgrep returned {p.returncode}')
diff --git a/src/ipcam_server.py b/src/ipcam_server.py
index 9e72e68..47e7156 100755
--- a/src/ipcam_server.py
+++ b/src/ipcam_server.py
@@ -14,7 +14,7 @@ from home.database.sqlite import SQLiteBase
from home.camera import util as camutil
from enum import Enum
-from typing import Optional, Union
+from typing import Optional, Union, List
from datetime import datetime, timedelta
@@ -273,7 +273,7 @@ def get_motion_path(cam: int) -> str:
def get_recordings_files(cam: int,
- time_filter_type: Optional[TimeFilterType] = None) -> list[dict]:
+ time_filter_type: Optional[TimeFilterType] = None) -> List[dict]:
from_time = 0
to_time = int(time.time())
@@ -305,7 +305,7 @@ def get_recordings_files(cam: int,
async def process_fragments(camera: int,
filename: str,
- fragments: list[tuple[int, int]]) -> None:
+ fragments: List[Tuple[int, int]]) -> None:
time = filename_to_datetime(filename)
rec_dir = get_recordings_path(camera)
@@ -338,7 +338,7 @@ async def process_fragments(camera: int,
async def motion_notify_tg(camera: int,
filename: str,
- fragments: list[tuple[int, int]]):
+ fragments: List[Tuple[int, int]]):
dt_file = filename_to_datetime(filename)
fmt = '%H:%M:%S'
diff --git a/src/openwrt_logger.py b/src/openwrt_logger.py
index 4d3b310..05fedfe 100755
--- a/src/openwrt_logger.py
+++ b/src/openwrt_logger.py
@@ -5,6 +5,7 @@ from datetime import datetime
from home.config import config
from home.database import SimpleState
from home.api import WebAPIClient
+from typing import Tuple
log_file = '/var/log/openwrt.log'
@@ -24,7 +25,7 @@ $UDPServerRun 514
"""
-def parse_line(line: str) -> tuple[int, str]:
+def parse_line(line: str) -> Tuple[int, str]:
space_pos = line.index(' ')
date = line[:space_pos]
@@ -58,7 +59,7 @@ if __name__ == '__main__':
state['seek'] = f.tell()
state['size'] = fsize
- lines: list[tuple[int, str]] = []
+ lines: List[Tuple[int, str]] = []
if content != '':
for line in content.strip().split('\n'):
diff --git a/src/polaris/__init__.py b/src/polaris/__init__.py
index aa077ce..f1a7c1d 100644
--- a/src/polaris/__init__.py
+++ b/src/polaris/__init__.py
@@ -1,4 +1,12 @@
-# SPDX-License-Identifier: BSD-3-Clause
+# Polaris PWK 1725CGLD "smart" kettle python library
+# --------------------------------------------------
+# Copyright (C) Evgeny Zinoviev, 2022
+# License: BSD-3c
-from .kettle import Kettle
-from .protocol import Message, FrameType, PowerType \ No newline at end of file
+from .kettle import Kettle, DeviceListener
+from .protocol import (
+ PowerType,
+ IncomingMessageListener,
+ ConnectionStatusListener,
+ ConnectionStatus
+) \ No newline at end of file
diff --git a/src/polaris/kettle.py b/src/polaris/kettle.py
index 37f6813..1b995fd 100644
--- a/src/polaris/kettle.py
+++ b/src/polaris/kettle.py
@@ -1,97 +1,238 @@
-# SPDX-License-Identifier: BSD-3-Clause
+# Polaris PWK 1725CGLD smart kettle python library
+# ------------------------------------------------
+# Copyright (C) Evgeny Zinoviev, 2022
+# License: BSD-3c
from __future__ import annotations
+import threading
import logging
import zeroconf
-import cryptography.hazmat.primitives._serialization
-from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
-from cryptography.hazmat.primitives import hashes
-
-from functools import partial
-from abc import ABC
-from ipaddress import ip_address
-from typing import Optional
+from abc import abstractmethod
+from ipaddress import ip_address, IPv4Address, IPv6Address
+from typing import Optional, List, Union
from .protocol import (
- Connection,
+ UDPConnection,
ModeMessage,
- HandshakeMessage,
TargetTemperatureMessage,
- Message,
- PowerType
+ PowerType,
+ ConnectionStatus,
+ ConnectionStatusListener,
+ WrappedMessage
)
-_logger = logging.getLogger(__name__)
-
-# Polaris PWK 1725CGLD IoT kettle
-class Kettle(zeroconf.ServiceListener, ABC):
- macaddr: str
- device_token: str
- sb: Optional[zeroconf.ServiceBrowser]
- found_device: Optional[zeroconf.ServiceInfo]
- conn: Optional[Connection]
+class DeviceDiscover(threading.Thread, zeroconf.ServiceListener):
+ si: Optional[zeroconf.ServiceInfo]
+ _mac: str
+ _sb: Optional[zeroconf.ServiceBrowser]
+ _zc: Optional[zeroconf.Zeroconf]
+ _listeners: List[DeviceListener]
+ _valid_addresses: List[Union[IPv4Address, IPv6Address]]
+ _only_ipv4: bool
- def __init__(self, mac: str, device_token: str):
+ def __init__(self, mac: str,
+ listener: Optional[DeviceListener] = None,
+ only_ipv4=True):
super().__init__()
- self.zeroconf = zeroconf.Zeroconf()
- self.sb = None
- self.macaddr = mac
- self.device_token = device_token
- self.found_device = None
- self.conn = None
+ self.si = None
+ self._mac = mac
+ self._zc = None
+ self._sb = None
+ self._only_ipv4 = only_ipv4
+ self._valid_addresses = []
+ self._listeners = []
+ if isinstance(listener, DeviceListener):
+ self._listeners.append(listener)
+ self._logger = logging.getLogger(f'{__name__}.{self.__class__.__name__}')
+
+ def add_listener(self, listener: DeviceListener):
+ if listener not in self._listeners:
+ self._listeners.append(listener)
+ else:
+ self._logger.warning(f'add_listener: listener {listener} already in the listeners list')
+
+ def set_info(self, info: zeroconf.ServiceInfo):
+ valid_addresses = self._get_valid_addresses(info)
+ if not valid_addresses:
+ raise ValueError('no valid addresses')
+ self._valid_addresses = valid_addresses
+ self.si = info
+ for f in self._listeners:
+ try:
+ f.device_updated()
+ except Exception as exc:
+ self._logger.error(f'set_info: error while calling device_updated on {f}')
+ self._logger.exception(exc)
- def find(self) -> zeroconf.ServiceInfo:
- self.sb = zeroconf.ServiceBrowser(self.zeroconf, "_syncleo._udp.local.", self)
- self.sb.join()
+ def add_service(self, zc: zeroconf.Zeroconf, type_: str, name: str) -> None:
+ self._add_update_service('add_service', zc, type_, name)
- return self.found_device
+ def update_service(self, zc: zeroconf.Zeroconf, type_: str, name: str) -> None:
+ self._add_update_service('update_service', zc, type_, name)
- # zeroconf.ServiceListener implementation
- def add_service(self,
- zc: zeroconf.Zeroconf,
- type_: str,
- name: str) -> None:
- if name.startswith(f'{self.macaddr}.'):
- info = zc.get_service_info(type_, name)
+ def _add_update_service(self, method: str, zc: zeroconf.Zeroconf, type_: str, name: str) -> None:
+ info = zc.get_service_info(type_, name)
+ if name.startswith(f'{self._mac}.'):
+ self._logger.info(f'{method}: type={type_} name={name}')
+ try:
+ self.set_info(info)
+ except ValueError as exc:
+ self._logger.error(f'{method}: rejected: {str(exc)}')
+ else:
+ self._logger.debug(f'{method}: mac not matched: {info}')
+
+ def remove_service(self, zc: zeroconf.Zeroconf, type_: str, name: str) -> None:
+ if name.startswith(f'{self._mac}.'):
+ self._logger.info(f'remove_service: type={type_} name={name}')
+ # TODO what to do here?!
+
+ def run(self):
+ self._logger.info('starting zeroconf service browser')
+ ip_version = zeroconf.IPVersion.V4Only if self._only_ipv4 else zeroconf.IPVersion.All
+ self._zc = zeroconf.Zeroconf(ip_version=ip_version)
+ self._sb = zeroconf.ServiceBrowser(self._zc, "_syncleo._udp.local.", self)
+ self._sb.join()
+
+ def stop(self):
+ if self._sb:
try:
- self.sb.cancel()
+ self._sb.cancel()
except RuntimeError:
pass
- self.zeroconf.close()
- self.found_device = info
+ self._sb = None
+ self._zc.close()
+ self._zc = None
+
+ def _get_valid_addresses(self, si: zeroconf.ServiceInfo) -> List[Union[IPv4Address, IPv6Address]]:
+ valid = []
+ for addr in map(ip_address, si.addresses):
+ if self._only_ipv4 and not isinstance(addr, IPv4Address):
+ continue
+ if isinstance(addr, IPv4Address) and str(addr).startswith('169.254.'):
+ continue
+ valid.append(addr)
+ return valid
- assert self.device_curve == 29, f'curve type {self.device_curve} is not implemented'
-
- def start_server(self, callback: callable):
- addresses = list(map(ip_address, self.found_device.addresses))
- self.conn = Connection(addr=addresses[0],
- port=int(self.found_device.port),
- device_pubkey=self.device_pubkey,
- device_token=bytes.fromhex(self.device_token))
+ @property
+ def pubkey(self) -> bytes:
+ return bytes.fromhex(self.si.properties[b'public'].decode())
- # shake the kettle's hand
- self._pass_message(HandshakeMessage(), callback)
- self.conn.start()
+ @property
+ def curve(self) -> int:
+ return int(self.si.properties[b'curve'].decode())
- def stop_server(self):
- self.conn.interrupted = True
+ @property
+ def addr(self) -> Union[IPv4Address, IPv6Address]:
+ return self._valid_addresses[0]
@property
- def device_pubkey(self) -> bytes:
- return bytes.fromhex(self.found_device.properties[b'public'].decode())
+ def port(self) -> int:
+ return int(self.si.port)
@property
- def device_curve(self) -> int:
- return int(self.found_device.properties[b'curve'].decode())
+ def protocol(self) -> int:
+ return int(self.si.properties[b'protocol'].decode())
+
+
+class DeviceListener:
+ @abstractmethod
+ def device_updated(self):
+ pass
+
+
+class Kettle(DeviceListener, ConnectionStatusListener):
+ mac: str
+ device: Optional[DeviceDiscover]
+ device_token: str
+ conn: Optional[UDPConnection]
+ conn_status: Optional[ConnectionStatus]
+ _logger: logging.Logger
+ _find_evt: threading.Event
+
+ def __init__(self, mac: str, device_token: str):
+ super().__init__()
+ self.mac = mac
+ self.device = None
+ self.device_token = device_token
+ self.conn = None
+ self.conn_status = None
+ self._find_evt = threading.Event()
+ self._logger = logging.getLogger(f'{__name__}.{self.__class__.__name__}')
+
+ def device_updated(self):
+ self._find_evt.set()
+ self._logger.info(f'device updated, service info: {self.device.si}')
+
+ def connection_status_updated(self, status: ConnectionStatus):
+ self.conn_status = status
+
+ def discover(self, wait=True, timeout=None, listener=None) -> Optional[zeroconf.ServiceInfo]:
+ do_start = False
+ if not self.device:
+ self.device = DeviceDiscover(self.mac, listener=self, only_ipv4=True)
+ do_start = True
+ self._logger.debug('discover: started device discovery')
+ else:
+ self._logger.warning('discover: already started')
+
+ if listener is not None:
+ self.device.add_listener(listener)
+
+ if do_start:
+ self.device.start()
+
+ if wait:
+ self._find_evt.clear()
+ try:
+ self._find_evt.wait(timeout=timeout)
+ except KeyboardInterrupt:
+ self.device.stop()
+ return None
+ return self.device.si
+
+ def start_server_if_needed(self,
+ incoming_message_listener=None,
+ connection_status_listener=None):
+ if self.conn:
+ self._logger.warning('start_server_if_needed: server is already started!')
+ self.conn.set_address(self.device.addr, self.device.port)
+ self.conn.set_device_pubkey(self.device.pubkey)
+ return
+
+ assert self.device.curve == 29, f'curve type {self.device.curve} is not implemented'
+ assert self.device.protocol == 2, f'protocol {self.device.protocol} is not supported'
+
+ self.conn = UDPConnection(addr=self.device.addr,
+ port=self.device.port,
+ device_pubkey=self.device.pubkey,
+ device_token=bytes.fromhex(self.device_token))
+ if incoming_message_listener:
+ self.conn.add_incoming_message_listener(incoming_message_listener)
+
+ self.conn.add_connection_status_listener(self)
+ if connection_status_listener:
+ self.conn.add_connection_status_listener(connection_status_listener)
+
+ self.conn.start()
+
+ def stop_all(self):
+ # when we stop server, we should also stop device discovering service
+ if self.conn:
+ self.conn.interrupted = True
+ self.conn = None
+ self.device.stop()
+ self.device = None
+
+ def is_connected(self) -> bool:
+ return self.conn is not None and self.conn_status == ConnectionStatus.CONNECTED
def set_power(self, power_type: PowerType, callback: callable):
- self._pass_message(ModeMessage(power_type), callback)
+ message = ModeMessage(power_type)
+ self.conn.enqueue_message(WrappedMessage(message, handler=callback, ack=True))
def set_target_temperature(self, temp: int, callback: callable):
- self._pass_message(TargetTemperatureMessage(temp), callback)
-
- def _pass_message(self, message: Message, callback: callable):
- self.conn.send_message(message, partial(callback, self))
+ message = TargetTemperatureMessage(temp)
+ self.conn.enqueue_message(WrappedMessage(message, handler=callback, ack=True))
diff --git a/src/polaris/protocol.py b/src/polaris/protocol.py
index cc4e36a..5d7390f 100644
--- a/src/polaris/protocol.py
+++ b/src/polaris/protocol.py
@@ -1,39 +1,61 @@
-# SPDX-License-Identifier: BSD-3-Clause
+# Polaris PWK 1725CGLD "smart" kettle python library
+# --------------------------------------------------
+# Copyright (C) Evgeny Zinoviev, 2022
+# License: BSD-3c
from __future__ import annotations
+
import logging
import socket
import random
import struct
import threading
-import queue
+import time
-from enum import Enum
-from typing import Union, Optional, Any
+from abc import abstractmethod, ABC
+from enum import Enum, auto
+from typing import Union, Optional, Dict, Tuple, List
from ipaddress import IPv4Address, IPv6Address
-import cryptography.hazmat.primitives._serialization
+import cryptography.hazmat.primitives._serialization as srlz
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
from cryptography.hazmat.primitives import ciphers, padding, hashes
from cryptography.hazmat.primitives.ciphers import algorithms, modes
-
+ReprDict = Dict[str, Union[str, int, float, bool]]
_logger = logging.getLogger(__name__)
+PING_FREQUENCY = 3
+RESEND_ATTEMPTS = 5
+READ_TIMEOUT = 1
+ERROR_TIMEOUT = 15
+MESSAGE_QUEUE_REMOVE_DELAY = 13 # after what time to delete (and pass False to handlers, if needed) messages with phase=DONE from queue
+DISCONNECT_TIMEOUT = 15
+
-# drop-in replacement for Java API
+def safe_callback_call(f: callable,
+ *args,
+ logger: logging.Logger = None,
+ error_message: str = None):
+ try:
+ return f(*args)
+ except Exception as exc:
+ logger.error(f'{error_message}, see exception below:')
+ logger.exception(exc)
+ return None
+
+
+# drop-in replacement for java.lang.System.arraycopy
# TODO: rewrite
def arraycopy(src, src_pos, dest, dest_pos, length):
for i in range(length):
dest[i + dest_pos] = src[i + src_pos]
-class FrameType(Enum):
- ACK = 0
- CMD = 1
- AUX = 2
- NAK = 3
+# "convert" unsigned byte to signed
+def u8_to_s8(b: int) -> int:
+ return struct.unpack('b', bytes([b]))[0]
class PowerType(Enum):
@@ -44,23 +66,37 @@ class PowerType(Enum):
# update: if I set it to '2', it just resets to '0'
+# low-level protocol structures
+# -----------------------------
+
+class FrameType(Enum):
+ ACK = 0
+ CMD = 1
+ AUX = 2
+ NAK = 3
+
+
class FrameHead:
- seq: int # u8
+ seq: Optional[int] # u8
type: FrameType # u8
- length: int # u16
+ length: int # u16. This is the length of FrameItem's payload
@staticmethod
def from_bytes(buf: bytes) -> FrameHead:
seq, ft, length = struct.unpack('<BBH', buf)
return FrameHead(seq, FrameType(ft), length)
- def __init__(self, seq: int, frame_type: FrameType, length: Optional[int] = None):
+ def __init__(self,
+ seq: Optional[int],
+ frame_type: FrameType,
+ length: Optional[int] = None):
self.seq = seq
self.type = frame_type
self.length = length or 0
def pack(self) -> bytes:
assert self.length != 0, "FrameHead.length has not been set"
+ assert self.seq is not None, "FrameHead.seq has not been set"
return struct.pack('<BBH', self.seq, self.type.value, self.length)
@@ -84,30 +120,47 @@ class FrameItem:
return bytes(ba)
+# high-level wrappers around FrameItem
+# ------------------------------------
+
+class MessagePhase(Enum):
+ WAITING = 0
+ SENT = 1
+ DONE = 2
+
+
class Message:
frame: Optional[FrameItem]
+ id: int
+
+ _global_id = 0
def __init__(self):
self.frame = None
+ # global internal message id, only useful for debugging purposes
+ self.id = self.next_id()
+
+ def __repr__(self):
+ return f'<{self.__class__.__name__} id={self.id} seq={self.frame.head.seq}>'
+
@staticmethod
- def from_encrypted(buf: bytes,
- inkey: bytes,
- outkey: bytes) -> Message:
- # _logger.debug('[from_encrypted] buf='+buf.hex())
- # print(f'buf len={len(buf)}')
+ def next_id():
+ _id = Message._global_id
+ Message._global_id += 1
+ return _id
+
+ @staticmethod
+ def from_encrypted(buf: bytes, inkey: bytes, outkey: bytes) -> Message:
+ _logger.debug(f'Message:from_encrypted: buf={buf.hex()}')
+
assert len(buf) >= 4, 'invalid size'
head = FrameHead.from_bytes(buf[:4])
assert len(buf) == head.length + 4, f'invalid buf size ({len(buf)} != {head.length})'
-
payload = buf[4:]
-
- # byte b = paramFrameHead.seq;
b = head.seq
- # TODO check if protocol is 2, otherwise raise an exception
-
j = b & 0xF
k = b >> 4 & 0xF
@@ -123,15 +176,9 @@ class Message:
decryptor = cipher.decryptor()
decrypted_data = decryptor.update(payload) + decryptor.finalize()
- # print(f'head.length={head.length} len(decr)={len(decrypted_data)}')
- # if len(decrypted_data) > head.length:
unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
decrypted_data = unpadder.update(decrypted_data)
- # try:
decrypted_data += unpadder.finalize()
- # except ValueError as exc:
- # _logger.exception(exc)
- # pass
assert len(decrypted_data) != 0, 'decrypted data is null'
assert head.seq == decrypted_data[0], f'decrypted seq mismatch {head.seq} != {decrypted_data[0]}'
@@ -145,29 +192,54 @@ class Message:
return NakMessage(head.seq)
elif head.type == FrameType.AUX:
+ # TODO implement AUX
raise NotImplementedError('FrameType AUX is not yet implemented')
elif head.type == FrameType.CMD:
- cmd = decrypted_data[0]
+ type = decrypted_data[1]
data = decrypted_data[2:]
- return CmdMessage(head.seq, cmd, data)
+
+ cl = UnknownMessage
+
+ subclasses = [cl for cl in CmdIncomingMessage.__subclasses__() if cl is not SimpleBooleanMessage]
+ subclasses.extend(SimpleBooleanMessage.__subclasses__())
+
+ for _cl in subclasses:
+ # `UnknownMessage` is a special class that holds a packed command that we don't recognize.
+ # It will be used anyway if we don't find a match, so skip it here
+ if _cl == UnknownMessage:
+ continue
+
+ if _cl.TYPE == type:
+ cl = _cl
+ break
+
+ m = cl.from_packed_data(data, seq=head.seq)
+ if isinstance(m, UnknownMessage):
+ m.set_type(type)
+ return m
else:
raise NotImplementedError(f'Unexpected frame type: {head.type}')
- @property
- def data(self) -> bytes:
+ def pack_data(self) -> bytes:
return b''
- def encrypt(self,
- outkey: bytes,
- inkey: bytes,
- token: bytes,
- pubkey: bytes):
+ @property
+ def seq(self) -> Union[int, None]:
+ try:
+ return self.frame.head.seq
+ except:
+ return None
+
+ @seq.setter
+ def seq(self, seq: int):
+ self.frame.head.seq = seq
+ def encrypt(self, outkey: bytes, inkey: bytes, token: bytes, pubkey: bytes):
assert self.frame is not None
- data = self.data
+ data = self._get_data_to_encrypt()
assert data is not None
b = self.frame.head.seq
@@ -199,7 +271,7 @@ class Message:
arraycopy(data, 0, newdata, 1, len(data))
newdata = bytes(newdata)
- _logger.debug('payload to be sent: ' + newdata.hex())
+ _logger.debug('frame payload to be encrypted: ' + newdata.hex())
padder = padding.PKCS7(algorithms.AES.block_size).padder()
ciphertext = bytearray()
@@ -208,74 +280,143 @@ class Message:
self.frame.setpayload(ciphertext)
- def set_seq(self, seq: int):
- self.frame.head.seq = seq
-
- def __repr__(self):
- return f'<{self.__class__.__name__} seq={self.frame.head.seq}>'
+ def _get_data_to_encrypt(self) -> bytes:
+ return self.pack_data()
-class AckMessage(Message):
- def __init__(self, seq: int = 0):
+class AckMessage(Message, ABC):
+ def __init__(self, seq: Optional[int] = None):
super().__init__()
- self.frame = FrameItem(FrameHead(seq, FrameType.ACK, 0))
+ self.frame = FrameItem(FrameHead(seq, FrameType.ACK, None))
-class NakMessage(Message):
- def __init__(self, seq: int = 0):
+class NakMessage(Message, ABC):
+ def __init__(self, seq: Optional[int] = None):
super().__init__()
- self.frame = FrameItem(FrameHead(seq, FrameType.NAK, 0))
+ self.frame = FrameItem(FrameHead(seq, FrameType.NAK, None))
class CmdMessage(Message):
- _type: Optional[int]
- _data: bytes
+ type: Optional[int]
+ data: bytes
- def __init__(self, seq=0,
- type: Optional[int] = None,
- data: bytes = b''):
- super().__init__()
- self._data = data
- if type is not None:
- self.frame = FrameItem(FrameHead(seq, FrameType.CMD))
- self._type = type
- else:
- self._type = None
+ TYPE = None
- @property
- def data(self) -> bytes:
+ def _get_data_to_encrypt(self) -> bytes:
buf = bytearray()
- buf.append(self._type)
- buf.extend(self._data)
+ buf.append(self.get_type())
+ buf.extend(self.pack_data())
return bytes(buf)
+ def __init__(self, seq: Optional[int] = None):
+ super().__init__()
+ self.frame = FrameItem(FrameHead(seq, FrameType.CMD))
+ self.data = b''
+
+ def _repr_fields(self) -> ReprDict:
+ return {
+ 'cmd': self.get_type()
+ }
+
def __repr__(self):
params = [
__name__+'.'+self.__class__.__name__,
- f'seq={self.frame.head.seq}',
- # f'type={self.frame.head.type}',
- f'cmd={self._type}'
+ f'id={self.id}',
+ f'seq={self.seq}'
]
- if self._data:
- params.append(f'data={self._data.hex()}')
+ fields = self._repr_fields()
+ if fields:
+ for k, v in fields.items():
+ params.append(f'{k}={v}')
+ elif self.data:
+ params.append(f'data={self.data.hex()}')
return '<'+' '.join(params)+'>'
+ def get_type(self) -> int:
+ return self.__class__.TYPE
+
+
+class CmdIncomingMessage(CmdMessage):
+ @staticmethod
+ @abstractmethod
+ def from_packed_data(cls, data: bytes, seq: Optional[int] = None):
+ pass
+
+ @abstractmethod
+ def _repr_fields(self) -> ReprDict:
+ pass
+
+
+class CmdOutgoingMessage(CmdMessage):
+ @abstractmethod
+ def pack_data(self) -> bytes:
+ return b''
+
+
+class ModeMessage(CmdOutgoingMessage, CmdIncomingMessage):
+ TYPE = 1
+
+ pt: PowerType
+
+ def __init__(self, power_type: PowerType, seq: Optional[int] = None):
+ super().__init__(seq)
+ self.pt = power_type
+
+ @classmethod
+ def from_packed_data(cls, data: bytes, seq=0) -> ModeMessage:
+ assert len(data) == 1, 'data size expected to be 1'
+ mode, = struct.unpack('B', data)
+ return ModeMessage(PowerType(mode), seq=seq)
+
+ def pack_data(self) -> bytes:
+ return self.pt.value.to_bytes(1, byteorder='little')
-class ModeMessage(CmdMessage):
- def __init__(self, power_type: PowerType):
- super().__init__(type=1,
- data=(power_type.value).to_bytes(1, byteorder='little'))
+ def _repr_fields(self) -> ReprDict:
+ return {'mode': self.pt.name}
-class TargetTemperatureMessage(CmdMessage):
- def __init__(self, temp: int):
- super().__init__(type=2,
- data=bytes(bytearray([temp, 0])))
+class TargetTemperatureMessage(CmdOutgoingMessage, CmdIncomingMessage):
+ temperature: int
+ TYPE = 2
+ def __init__(self, temp: int, seq: Optional[int] = None):
+ super().__init__(seq)
+ self.temperature = temp
+
+ @classmethod
+ def from_packed_data(cls, data: bytes, seq=0) -> TargetTemperatureMessage:
+ assert len(data) == 2, 'data size expected to be 2'
+ nat, frac = struct.unpack('BB', data)
+ temp = int(nat + (frac / 100))
+ return TargetTemperatureMessage(temp, seq=seq)
+
+ def pack_data(self) -> bytes:
+ return bytes([self.temperature, 0])
+
+ def _repr_fields(self) -> ReprDict:
+ return {'temperature': self.temperature}
+
+
+class PingMessage(CmdIncomingMessage, CmdOutgoingMessage):
+ TYPE = 255
+
+ @classmethod
+ def from_packed_data(cls, data: bytes, seq=0) -> PingMessage:
+ assert len(data) == 0, 'no data expected'
+ return PingMessage(seq=seq)
+
+ def pack_data(self) -> bytes:
+ return b''
+
+ def _repr_fields(self) -> ReprDict:
+ return {}
+
+
+# This is the first protocol message. Sent by a client.
+# Kettle usually ACKs this, but sometimes i don't get any ACK and the very next message is HandshakeResponseMessage.
class HandshakeMessage(CmdMessage):
- def __init__(self):
- super().__init__(type=0)
+ TYPE = 0
def encrypt(self,
outkey: bytes,
@@ -297,20 +438,312 @@ class HandshakeMessage(CmdMessage):
self.frame.setpayload(pld)
-# TODO
-# implement resending UDP messages if no answer has been received in a second
-# try at least 5 times, then give up
-class Connection(threading.Thread):
- seq_no: int
+# Kettle either sends this right after the handshake, of first it ACKs the handshake then sends this.
+class HandshakeResponseMessage(CmdIncomingMessage):
+ TYPE = 0
+
+ protocol: int
+ fw_major: int
+ fw_minor: int
+ mode: int
+ token: bytes
+
+ def __init__(self,
+ protocol: int,
+ fw_major: int,
+ fw_minor: int,
+ mode: int,
+ token: bytes,
+ seq: Optional[int] = None):
+ super().__init__(seq)
+ self.protocol = protocol
+ self.fw_major = fw_major
+ self.fw_minor = fw_minor
+ self.mode = mode
+ self.token = token
+
+ @classmethod
+ def from_packed_data(cls, data: bytes, seq=0) -> HandshakeResponseMessage:
+ protocol, fw_major, fw_minor, mode = struct.unpack('<HBBB', data[:5])
+ return HandshakeResponseMessage(protocol, fw_major, fw_minor, mode, token=data[5:], seq=seq)
+
+ def _repr_fields(self) -> ReprDict:
+ return {
+ 'protocol': self.protocol,
+ 'fw': f'{self.fw_major}.{self.fw_minor}',
+ 'mode': self.mode,
+ 'token': self.token.hex()
+ }
+
+
+# Apparently, some hardware info.
+# On the other hand, if you look at com.syncleiot.iottransport.commands.CmdHardware, its mqtt topic is "mcu_firmware".
+# My device returns 1.1.1. The thing uses on ESP8266 MCU under the hood (or, more precisely, under a piece of cheap
+# plastic), so maybe 1.1.1 is the MCU fw revision.
+class DeviceHardwareMessage(CmdIncomingMessage):
+ TYPE = 143 # -113
+
+ hw: List[int]
+
+ def __init__(self, hw: List[int], seq: Optional[int] = None):
+ super().__init__(seq)
+ self.hw = hw
+
+ @classmethod
+ def from_packed_data(cls, data: bytes, seq=0) -> DeviceHardwareMessage:
+ assert len(data) == 3, 'invalid data size, expected 3'
+ hw = list(struct.unpack('<BBB', data))
+ return DeviceHardwareMessage(hw, seq=seq)
+
+ def _repr_fields(self) -> ReprDict:
+ return {'device_hardware': '.'.join(map(str, self.hw))}
+
+
+# This message is sent by kettle right after the HandshakeMessageResponse.
+# The diagnostic data is supposed to be sent to vendor, which we, obviously, not going to do.
+# So just ACK and skip it.
+class DeviceDiagnosticMessage(CmdIncomingMessage):
+ TYPE = 145 # -111
+
+ diag_data: bytes
+
+ def __init__(self, diag_data: bytes, seq: Optional[int] = None):
+ super().__init__(seq)
+ self.diag_data = diag_data
+
+ @classmethod
+ def from_packed_data(cls, data: bytes, seq=0) -> DeviceDiagnosticMessage:
+ return DeviceDiagnosticMessage(diag_data=data, seq=seq)
+
+ def _repr_fields(self) -> ReprDict:
+ return {'diag_data': self.diag_data.hex()}
+
+
+class SimpleBooleanMessage(ABC, CmdIncomingMessage):
+ value: bool
+
+ def __init__(self, value: bool, seq: Optional[int] = None):
+ super().__init__(seq)
+ self.value = value
+
+ @classmethod
+ def from_packed_data(cls, data: bytes, seq: Optional[int] = None):
+ assert len(data) == 1, 'invalid data size, expected 1'
+ enabled, = struct.unpack('<B', data)
+ return cls(value=enabled == 1, seq=seq)
+
+ @abstractmethod
+ def _repr_fields(self) -> ReprDict:
+ pass
+
+
+class AccessControlMessage(SimpleBooleanMessage):
+ TYPE = 133 # -123
+
+ def _repr_fields(self) -> ReprDict:
+ return {'acl_enabled': self.value}
+
+
+class ErrorMessage(SimpleBooleanMessage):
+ TYPE = 7
+
+ def _repr_fields(self) -> ReprDict:
+ return {'error': self.value}
+
+
+class ChildLockMessage(SimpleBooleanMessage):
+ TYPE = 30
+
+ def _repr_fields(self) -> ReprDict:
+ return {'child_lock': self.value}
+
+
+class VolumeMessage(SimpleBooleanMessage):
+ TYPE = 9
+
+ def _repr_fields(self) -> ReprDict:
+ return {'volume': self.value}
+
+
+class BacklightMessage(SimpleBooleanMessage):
+ TYPE = 28
+
+ def _repr_fields(self) -> ReprDict:
+ return {'backlight': self.value}
+
+
+class CurrentTemperatureMessage(CmdIncomingMessage):
+ TYPE = 20
+
+ current_temperature: int
+
+ def __init__(self, temp: int, seq: Optional[int] = None):
+ super().__init__(seq)
+ self.current_temperature = temp
+
+ @classmethod
+ def from_packed_data(cls, data: bytes, seq=0) -> CurrentTemperatureMessage:
+ assert len(data) == 2, 'data size expected to be 2'
+ nat, frac = struct.unpack('BB', data)
+ temp = int(nat + (frac / 100))
+ return CurrentTemperatureMessage(temp, seq=seq)
+
+ def pack_data(self) -> bytes:
+ return bytes([self.current_temperature, 0])
+
+ def _repr_fields(self) -> ReprDict:
+ return {'current_temperature': self.current_temperature}
+
+
+class UnknownMessage(CmdIncomingMessage):
+ type: Optional[int]
+ data: bytes
+
+ def __init__(self, data: bytes, **kwargs):
+ super().__init__(**kwargs)
+ self.type = None
+ self.data = data
+
+ @classmethod
+ def from_packed_data(cls, data: bytes, seq=0) -> UnknownMessage:
+ return UnknownMessage(data, seq=seq)
+
+ def set_type(self, type: int):
+ self.type = type
+
+ def get_type(self) -> int:
+ return self.type
+
+ def _repr_fields(self) -> ReprDict:
+ return {
+ 'type': self.type,
+ 'data': self.data.hex()
+ }
+
+
+class WrappedMessage:
+ _message: Message
+ _handler: Optional[callable]
+ _validator: Optional[callable]
+ _logger: Optional[logging.Logger]
+ _phase: MessagePhase
+ _phase_update_time: float
+
+ def __init__(self,
+ message: Message,
+ handler: Optional[callable] = None,
+ validator: Optional[callable] = None,
+ ack=False):
+ self._message = message
+ self._handler = handler
+ self._validator = validator
+ self._logger = None
+ self._phase = MessagePhase.WAITING
+ self._phase_update_time = 0
+ if not validator and ack:
+ self._validator = lambda m: isinstance(m, AckMessage)
+
+ def setlogger(self, logger: logging.Logger):
+ self._logger = logger
+
+ def validate(self, message: Message):
+ if not self._validator:
+ return True
+ return self._validator(message)
+
+ def call(self, *args, error_message: str = None) -> None:
+ if not self._handler:
+ return
+ try:
+ self._handler(*args)
+ except Exception as exc:
+ logger = self._logger or logging.getLogger(self.__class__.__name__)
+ logger.error(f'{error_message}, see exception below:')
+ logger.exception(exc)
+
+ @property
+ def phase(self) -> MessagePhase:
+ return self._phase
+
+ @phase.setter
+ def phase(self, phase: MessagePhase):
+ self._phase = phase
+ self._phase_update_time = 0 if phase == MessagePhase.WAITING else time.time()
+
+ @property
+ def phase_update_time(self) -> float:
+ return self._phase_update_time
+
+ @property
+ def message(self) -> Message:
+ return self._message
+
+ @property
+ def id(self) -> int:
+ return self._message.id
+
+ @property
+ def seq(self) -> int:
+ return self._message.seq
+
+ @seq.setter
+ def seq(self, seq: int):
+ self._message.seq = seq
+
+ def __repr__(self):
+ return f'<{__name__}.{self.__class__.__name__} message={self._message.__repr__()}>'
+
+
+# Connection stuff
+# Well, strictly speaking, as it's UDP, there's no connection, but who cares.
+# ---------------------------------------------------------------------------
+
+class IncomingMessageListener:
+ @abstractmethod
+ def incoming_message(self, message: Message) -> Optional[Message]:
+ pass
+
+
+class ConnectionStatus(Enum):
+ NOT_CONNECTED = auto()
+ CONNECTING = auto()
+ CONNECTED = auto()
+ RECONNECTING = auto()
+ DISCONNECTED = auto()
+
+
+class ConnectionStatusListener:
+ @abstractmethod
+ def connection_status_updated(self, status: ConnectionStatus):
+ pass
+
+
+class UDPConnection(threading.Thread, ConnectionStatusListener):
+ inseq: int
+ outseq: int
source_port: int
device_addr: str
device_port: int
device_token: bytes
+ device_pubkey: bytes
interrupted: bool
- waiting_for_response: dict[int, callable]
+ response_handlers: Dict[int, WrappedMessage]
+ outgoing_queue: List[WrappedMessage]
pubkey: Optional[bytes]
encinkey: Optional[bytes]
encoutkey: Optional[bytes]
+ inc_listeners: List[IncomingMessageListener]
+ conn_listeners: List[ConnectionStatusListener]
+ outgoing_time: float
+ outgoing_time_1st: float
+ incoming_time: float
+ status: ConnectionStatus
+ reconnect_tries: int
+
+ _addr_lock: threading.Lock
+ _iml_lock: threading.Lock
+ _csl_lock: threading.Lock
+ _st_lock: threading.Lock
def __init__(self,
addr: Union[IPv4Address, IPv6Address],
@@ -318,36 +751,105 @@ class Connection(threading.Thread):
device_pubkey: bytes,
device_token: bytes):
super().__init__()
- self.logger = logging.getLogger(__name__+'.'+self.__class__.__name__)
+ self._logger = logging.getLogger(f'{__name__}.{self.__class__.__name__} <{hex(id(self))}>')
self.setName(self.__class__.__name__)
- # self.daemon = True
- self.seq_no = -1
+ self.inseq = 0
+ self.outseq = 0
self.source_port = random.randint(1024, 65535)
self.device_addr = str(addr)
self.device_port = port
self.device_token = device_token
- self.lock = threading.Lock()
- self.outgoing_queue = queue.SimpleQueue()
- self.waiting_for_response = {}
+ self.device_pubkey = device_pubkey
+ self.outgoing_queue = []
+ self.response_handlers = {}
self.interrupted = False
+ self.outgoing_time = 0
+ self.outgoing_time_1st = 0
+ self.incoming_time = 0
+ self.inc_listeners = []
+ self.conn_listeners = [self]
+ self.status = ConnectionStatus.NOT_CONNECTED
+ self.reconnect_tries = 0
+
+ self._iml_lock = threading.Lock()
+ self._csl_lock = threading.Lock()
+ self._addr_lock = threading.Lock()
+ self._st_lock = threading.Lock()
self.pubkey = None
self.encinkey = None
self.encoutkey = None
- self.prepare_keys(device_pubkey)
-
- def prepare_keys(self, device_pubkey: bytes):
+ def connection_status_updated(self, status: ConnectionStatus):
+ # self._logger.info(f'connection_status_updated: status = {status}')
+ with self._st_lock:
+ # self._logger.debug(f'connection_status_updated: lock acquired')
+ self.status = status
+ if status == ConnectionStatus.RECONNECTING:
+ self.reconnect_tries += 1
+ if status in (ConnectionStatus.CONNECTED, ConnectionStatus.NOT_CONNECTED, ConnectionStatus.DISCONNECTED):
+ self.reconnect_tries = 0
+
+ def _cleanup(self):
+ # erase outgoing queue
+ for wm in self.outgoing_queue:
+ wm.call(False,
+ error_message=f'_cleanup: exception while calling cb(False) on message {wm.message}')
+ self.outgoing_queue = []
+ self.response_handlers = {}
+
+ # reset timestamps
+ self.incoming_time = 0
+ self.outgoing_time = 0
+ self.outgoing_time_1st = 0
+
+ self._logger.info('_cleanup: done')
+
+ def set_address(self, addr: Union[IPv4Address, IPv6Address], port: int):
+ with self._addr_lock:
+ if self.device_addr != str(addr) or self.device_port != port:
+ self.device_addr = str(addr)
+ self.device_port = port
+ self._logger.info(f'updated device network address: {self.device_addr}:{self.device_port}')
+
+ def set_device_pubkey(self, pubkey: bytes):
+ if self.device_pubkey.hex() != pubkey.hex():
+ self._logger.info(f'device pubkey has changed (old={self.device_pubkey.hex()}, new={pubkey.hex()})')
+ self.device_pubkey = pubkey
+ self._notify_cs(ConnectionStatus.RECONNECTING)
+
+ def get_address(self) -> Tuple[str, int]:
+ with self._addr_lock:
+ return self.device_addr, self.device_port
+
+ def add_incoming_message_listener(self, listener: IncomingMessageListener):
+ with self._iml_lock:
+ if listener not in self.inc_listeners:
+ self.inc_listeners.append(listener)
+
+ def add_connection_status_listener(self, listener: ConnectionStatusListener):
+ with self._csl_lock:
+ if listener not in self.conn_listeners:
+ self.conn_listeners.append(listener)
+
+ def _notify_cs(self, status: ConnectionStatus):
+ # self._logger.debug(f'_notify_cs: status={status}')
+ with self._csl_lock:
+ for obj in self.conn_listeners:
+ # self._logger.debug(f'_notify_cs: notifying {obj}')
+ obj.connection_status_updated(status)
+
+ def _prepare_keys(self):
# generate key pair
privkey = X25519PrivateKey.generate()
- self.pubkey = bytes(reversed(privkey.public_key().public_bytes(encoding=cryptography.hazmat.primitives._serialization.Encoding.Raw,
- format=cryptography.hazmat.primitives._serialization.PublicFormat.Raw)))
+ self.pubkey = bytes(reversed(privkey.public_key().public_bytes(encoding=srlz.Encoding.Raw,
+ format=srlz.PublicFormat.Raw)))
# generate shared key
device_pubkey = X25519PublicKey.from_public_bytes(
- bytes(reversed(device_pubkey))
+ bytes(reversed(self.device_pubkey))
)
shared_key = bytes(reversed(
privkey.exchange(device_pubkey)
@@ -362,51 +864,286 @@ class Connection(threading.Thread):
self.encinkey = shared_sha256[:16]
self.encoutkey = shared_sha256[16:]
+ self._logger.info('encryption keys have been created')
+
+ def _handshake_callback(self, r: MessageResponse):
+ # if got error for our HandshakeMessage, reset everything and try again
+ if r is False:
+ # self._logger.debug('_handshake_callback: set status=RECONNETING')
+ self._notify_cs(ConnectionStatus.RECONNECTING)
+ else:
+ # self._logger.debug('_handshake_callback: set status=CONNECTED')
+ self._notify_cs(ConnectionStatus.CONNECTED)
+
def run(self):
+ self._logger.info('starting server loop')
+
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.bind(('0.0.0.0', self.source_port))
- sock.settimeout(1)
+ sock.settimeout(READ_TIMEOUT)
while not self.interrupted:
- if not self.outgoing_queue.empty():
- message = self.outgoing_queue.get()
- message.encrypt(outkey=self.encoutkey,
- inkey=self.encinkey,
- token=self.device_token,
- pubkey=self.pubkey)
- buf = message.frame.pack()
- self.logger.debug('send: '+buf.hex())
- self.logger.debug(f'sendto: {self.device_addr}:{self.device_port}')
- sock.sendto(buf, (self.device_addr, self.device_port))
+ with self._st_lock:
+ status = self.status
+
+ if status in (ConnectionStatus.DISCONNECTED, ConnectionStatus.RECONNECTING):
+ self._cleanup()
+ if status == ConnectionStatus.DISCONNECTED:
+ break
+
+ # no activity for some time means connection is broken
+ fail = False
+ fail_path = 0
+ if self.incoming_time > 0 and time.time() - self.incoming_time >= DISCONNECT_TIMEOUT:
+ fail = True
+ fail_path = 1
+ elif self.outgoing_time_1st > 0 and self.incoming_time == 0 and time.time() - self.outgoing_time_1st >= DISCONNECT_TIMEOUT:
+ fail = True
+ fail_path = 2
+
+ if fail:
+ self._logger.debug(f'run: setting status=RECONNECTING because of long inactivity, fail_path={fail_path}')
+ self._notify_cs(ConnectionStatus.RECONNECTING)
+
+ # establishing a connection
+ if status in (ConnectionStatus.RECONNECTING, ConnectionStatus.NOT_CONNECTED):
+ if status == ConnectionStatus.RECONNECTING and self.reconnect_tries >= 3:
+ self._notify_cs(ConnectionStatus.DISCONNECTED)
+ continue
+
+ self._reset_outseq()
+ self._prepare_keys()
+
+ # shake the imaginary kettle's hand
+ wrapped = WrappedMessage(HandshakeMessage(),
+ handler=self._handshake_callback,
+ validator=lambda m: isinstance(m, (AckMessage, HandshakeResponseMessage)))
+ self.enqueue_message(wrapped, prepend=True)
+ self._notify_cs(ConnectionStatus.CONNECTING)
+
+ # pick next (wrapped) message to send
+ wm = self._get_next_message() # wm means "wrapped message"
+ if wm:
+ if not isinstance(wm.message, (AckMessage, NakMessage)):
+ old_seq = wm.seq
+ wm.seq = self.outseq
+ self._set_response_handler(wm, old_seq=old_seq)
+ elif wm.seq is None:
+ # ack/nak is a response to some incoming message (and it must have the same seqno that incoming
+ # message had)
+ raise RuntimeError(f'run: seq must be set for {wm.__class__.__name__}')
+
+ self._logger.debug(f'run: sending message: {wm.message}')
+ wm.message.encrypt(outkey=self.encoutkey, inkey=self.encinkey,
+ token=self.device_token, pubkey=self.pubkey)
+ buf = wm.message.frame.pack()
+ one_shot = isinstance(wm.message, (AckMessage, NakMessage, PingMessage))
+ # self._logger.debug(f'run: raw data to be sent: {buf.hex()}')
+
+ # sending the first time
+ if wm.phase == MessagePhase.WAITING:
+ sock.sendto(buf, self.get_address())
+ # resending
+ elif wm.phase == MessagePhase.SENT:
+ left = RESEND_ATTEMPTS
+ while left > 0:
+ sock.sendto(buf, self.get_address())
+ left -= 1
+ if left > 0:
+ time.sleep(0.05)
+
+ if one_shot or wm.phase == MessagePhase.SENT:
+ wm.phase = MessagePhase.DONE
+ else:
+ wm.phase = MessagePhase.SENT
+
+ now = time.time()
+ self.outgoing_time = now
+ if not self.outgoing_time_1st:
+ self.outgoing_time_1st = now
+
+ # receiving data
try:
data = sock.recv(4096)
- self.handle_incoming(data)
- except TimeoutError:
+ self._handle_incoming(data)
+ except (TimeoutError, socket.timeout):
pass
- def handle_incoming(self, buf: bytes):
- self.logger.debug('handle_incoming: '+buf.hex())
- message = Message.from_encrypted(buf, inkey=self.encinkey, outkey=self.encoutkey)
- if message.frame.head.seq in self.waiting_for_response:
- self.logger.info(f'received awaited message: {message}')
+ self._logger.info('bye...')
+
+ def _get_next_message(self) -> Optional[WrappedMessage]:
+ message = None
+ lpfx = '_get_next_message:'
+ remove_list = []
+ for wm in self.outgoing_queue:
+ if wm.phase == MessagePhase.DONE:
+ if isinstance(wm.message, (AckMessage, NakMessage)) or time.time() - wm.phase_update_time >= MESSAGE_QUEUE_REMOVE_DELAY:
+ remove_list.append(wm)
+ continue
+ message = wm
+ break
+
+ for wm in remove_list:
+ self._logger.debug(f'{lpfx} rm path: removing id={wm.id} seq={wm.seq}')
+
+ # clear message handler
+ if wm.seq in self.response_handlers:
+ self.response_handlers[wm.seq].call(
+ False, error_message=f'{lpfx} rm path: error while calling callback for seq={wm.seq}')
+ del self.response_handlers[wm.seq]
+
+ # remove from queue
try:
- f = self.waiting_for_response[message.frame.head.seq]
- f(message)
- except Exception as exc:
- self.logger.exception(exc)
- finally:
- del self.waiting_for_response[message.frame.head.seq]
+ self.outgoing_queue.remove(wm)
+ except ValueError as exc:
+ self._logger.error(f'{lpfx} rm path: removing from outgoing_queue raised an exception: {str(exc)}')
+
+ # ping pong
+ if self.outgoing_time_1st != 0 and self.status == ConnectionStatus.CONNECTED:
+ now = time.time()
+ out_delta = now - self.outgoing_time
+ in_delta = now - self.incoming_time
+ if not message and max(out_delta, in_delta) > PING_FREQUENCY:
+ self._logger.debug(f'{lpfx} no activity: in for {in_delta:.2f}s, out for {out_delta:.2f}s, time to ping the damn thing')
+ message = WrappedMessage(PingMessage(), ack=True)
+
+ return message
+
+ def _handle_incoming(self, buf: bytes):
+ self.incoming_time = time.time()
+
+ incoming_message = Message.from_encrypted(buf, inkey=self.encinkey, outkey=self.encoutkey)
+ seq = incoming_message.seq
+
+ lpfx = f'handle_incoming({incoming_message.id}):'
+ self._logger.debug(f'{lpfx} received: {incoming_message}')
+
+ if isinstance(incoming_message, (AckMessage, NakMessage)):
+ seq_max = self.outseq
+ seq_name = 'outseq'
+ else:
+ seq_max = self.inseq
+ seq_name = 'inseq'
+ self.inseq = seq
+
+ if seq < seq_max < 0xfd:
+ self._logger.warning(f'{lpfx} dropping: seq={seq}, {seq_name}={seq_name}')
+ return
+
+ if seq not in self.response_handlers:
+ self._handle_incoming_cmd(incoming_message)
+ return
+
+ callback_value = None # None means don't call a callback
+ handler = self.response_handlers[seq]
+
+ if handler.validate(incoming_message):
+ self._logger.info(f'{lpfx} response OK')
+ handler.phase = MessagePhase.DONE
+ callback_value = incoming_message
+ self._incr_outseq()
else:
- self.logger.info(f'received message (not awaited): {message}')
-
- def send_message(self, message: Message, callback: callable):
- seq = self.next_seqno()
- message.set_seq(seq)
- self.outgoing_queue.put(message)
- self.waiting_for_response[seq] = callback
-
- def next_seqno(self) -> int:
- with self.lock:
- self.seq_no += 1
- self.logger.debug(f'next_seqno: set to {self.seq_no}')
- return self.seq_no
+ self._logger.info(f'{lpfx} response is INVALID')
+
+ # It seems that we've received an incoming CmdMessage or PingMessage with the same seqno that our outgoing
+ # message had. Bad, but what can I say, this is quick-and-dirty made UDP based protocol and this sort of
+ # shit just happens.
+
+ # (To be fair, maybe my implementation is not perfect either. But hey, what did you expect from a
+ # reverse-engineered re-implementation of custom UDP-based protocol that some noname vendor uses for their
+ # cheap IoT devices? I think _that_ is _the_ definition of shit. At least my implementation is FOSS, which
+ # is more than you'll ever be able to say about them.)
+
+ # All this crapload of code below might not be needed at all, 'cause the protocol uses separate frame seq
+ # numbers for IN and OUT frames and this situation is not highly likely, as Theresa May could argue.
+ # After a handshake, a kettle sends us 10 or so CmdMessages, and then either we continuously ping it every
+ # 3 seconds, or kettle pings us. This in any case widens the gap between inseq and outseq.
+
+ # But! the seqno is only 1 byte in size and once it reaches 0xff, it circles back to zero. And that (plus,
+ # perhaps, some bad luck) gives a chance for a collision.
+
+ if handler.phase == MessagePhase.DONE or isinstance(handler.message, HandshakeMessage):
+ # no more attempts left, returning error back to user
+ # as to handshake, it cannot fail.
+ callback_value = False
+
+ # else:
+ # # try resending the message
+ # handler.phase_reset()
+ # max_seq = self.outseq
+ # wait_remap = {}
+ # for m in self.outgoing_queue:
+ # if m.seq in self.waiting_for_response:
+ # wait_remap[m.seq] = (m.seq+1) % 256
+ # m.set_seq((m.seq+1) % 256)
+ # if m.seq > max_seq:
+ # max_seq = m.seq
+ # if max_seq > self.outseq:
+ # self.outseq = max_seq % 256
+ # if wait_remap:
+ # waiting_new = {}
+ # for old_seq, new_seq in wait_remap.items():
+ # waiting_new[new_seq] = self.waiting_for_response[old_seq]
+ # self.waiting_for_response = waiting_new
+
+ if isinstance(incoming_message, (PingMessage, CmdIncomingMessage)):
+ # handle incoming message as usual, as we need to ack/nak it anyway
+ self._handle_incoming_cmd(incoming_message)
+
+ if callback_value is not None:
+ handler.call(callback_value,
+ error_message=f'{lpfx} error while calling callback for msg id={handler.message.id} seq={seq}')
+ del self.response_handlers[seq]
+
+ def _handle_incoming_cmd(self, incoming_message: Message):
+ if isinstance(incoming_message, (AckMessage, NakMessage)):
+ self._logger.debug(f'_handle_incoming_cmd({incoming_message.id}, seq={incoming_message.seq}): it\'s {incoming_message.__class__.__name__}, ignoring')
+ return
+
+ replied = False
+ with self._iml_lock:
+ for f in self.inc_listeners:
+ retval = safe_callback_call(f.incoming_message, incoming_message,
+ logger=self._logger,
+ error_message=f'_handle_incoming_cmd({incoming_message.id}, seq={incoming_message.seq}): error while calling message listener')
+ if isinstance(retval, Message):
+ if isinstance(retval, (AckMessage, NakMessage)):
+ retval.seq = incoming_message.seq
+ self.enqueue_message(WrappedMessage(retval), prepend=True)
+ replied = True
+ break
+ else:
+ raise RuntimeError('are you sure your response is correct? only ack/nak are allowed')
+
+ if not replied:
+ self.enqueue_message(WrappedMessage(AckMessage(incoming_message.seq)), prepend=True)
+
+ def enqueue_message(self, wrapped: WrappedMessage, prepend=False):
+ self._logger.debug(f'enqueue_message: {wrapped.message}')
+ if not prepend:
+ self.outgoing_queue.append(wrapped)
+ else:
+ self.outgoing_queue.insert(0, wrapped)
+
+ def _set_response_handler(self, wm: WrappedMessage, old_seq=None):
+ if old_seq in self.response_handlers:
+ del self.response_handlers[old_seq]
+
+ seq = wm.seq
+ assert seq is not None, 'seq is not set'
+
+ if seq in self.response_handlers:
+ self._logger.warning(f'_set_response_handler(seq={seq}): handler is already set, cancelling it')
+ self.response_handlers[seq].call(False,
+ error_message=f'_set_response_handler({seq}): error while calling old callback')
+ self.response_handlers[seq] = wm
+
+ def _incr_outseq(self) -> None:
+ self.outseq = (self.outseq + 1) % 256
+
+ def _reset_outseq(self):
+ self.outseq = 0
+ self._logger.debug(f'_reset_outseq: set 0')
+
+
+MessageResponse = Union[Message, bool]
diff --git a/src/polaris_kettle_bot.py b/src/polaris_kettle_bot.py
new file mode 100644
index 0000000..bad1b3a
--- /dev/null
+++ b/src/polaris_kettle_bot.py
@@ -0,0 +1,684 @@
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import logging
+import locale
+import queue
+import time
+import threading
+import paho.mqtt.client as mqtt
+
+from home.bot import Wrapper, Context, text_filter, handlermethod
+from home.api.types import BotType
+from home.mqtt import MQTTBase
+from home.config import config
+from polaris import (
+ Kettle,
+ PowerType,
+ DeviceListener,
+ IncomingMessageListener,
+ ConnectionStatusListener,
+ ConnectionStatus
+)
+import polaris.protocol as kettle_proto
+from typing import Optional, Tuple, List
+from collections import namedtuple
+from functools import partial
+from datetime import datetime
+from abc import abstractmethod
+from telegram.error import TelegramError
+from telegram import (
+ ReplyKeyboardMarkup,
+ InlineKeyboardMarkup,
+ InlineKeyboardButton,
+ Message
+)
+from telegram.ext import (
+ CallbackQueryHandler,
+ MessageHandler,
+ CommandHandler
+)
+
+logger = logging.getLogger(__name__)
+kc: Optional[KettleController] = None
+bot: Optional[Wrapper] = None
+RenderedContent = Tuple[str, Optional[InlineKeyboardMarkup]]
+tasks_lock = threading.Lock()
+
+
+class KettleInfoListener:
+ @abstractmethod
+ def info_updated(self, field: str):
+ pass
+
+
+# class that holds data coming from the kettle over mqtt
+class KettleInfo:
+ update_time: int
+ _mode: Optional[PowerType]
+ _temperature: Optional[int]
+ _target_temperature: Optional[int]
+ _update_listener: KettleInfoListener
+
+ def __init__(self, update_listener: KettleInfoListener):
+ self.update_time = 0
+ self._mode = None
+ self._temperature = None
+ self._target_temperature = None
+ self._update_listener = update_listener
+
+ def _update(self, field: str):
+ self.update_time = int(time.time())
+ if self._update_listener:
+ self._update_listener.info_updated(field)
+
+ @property
+ def temperature(self) -> int:
+ return self._temperature
+
+ @temperature.setter
+ def temperature(self, value: int):
+ self._temperature = value
+ self._update('temperature')
+
+ @property
+ def mode(self) -> PowerType:
+ return self._mode
+
+ @mode.setter
+ def mode(self, value: PowerType):
+ self._mode = value
+ self._update('mode')
+
+ @property
+ def target_temperature(self) -> int:
+ return self._target_temperature
+
+ @target_temperature.setter
+ def target_temperature(self, value: int):
+ self._target_temperature = value
+ self._update('target_temperature')
+
+
+class KettleController(threading.Thread,
+ MQTTBase,
+ DeviceListener,
+ IncomingMessageListener,
+ KettleInfoListener,
+ ConnectionStatusListener):
+ kettle: Kettle
+ info: KettleInfo
+
+ _logger: logging.Logger
+ _stopped: bool
+ _restart_server_at: int
+ _lock: threading.Lock
+ _info_lock: threading.Lock
+ _accumulated_updates: dict
+ _info_flushed_time: float
+ _mqtt_root_topic: str
+ _muts: List[MessageUpdatingTarget]
+
+ def __init__(self):
+ # basic setup
+ MQTTBase.__init__(self, clean_session=False)
+ threading.Thread.__init__(self)
+
+ self._logger = logging.getLogger(self.__class__.__name__)
+
+ self.kettle = Kettle(mac=config['kettle']['mac'],
+ device_token=config['kettle']['token'])
+ self.kettle_reconnect()
+
+ # info
+ self.info = KettleInfo(update_listener=self)
+ self._accumulated_updates = {}
+ self._info_flushed_time = 0
+
+ # mqtt
+ self._mqtt_root_topic = '/polaris/6/'+config['kettle']['token']+'/#'
+ self.connect_and_loop(loop_forever=False)
+
+ # thread loop related
+ self._stopped = False
+ # self._lock = threading.Lock()
+ self._info_lock = threading.Lock()
+ self._restart_server_at = 0
+
+ # bot
+ self._muts = []
+ self._muts_lock = threading.Lock()
+
+ self.start()
+
+ def kettle_reconnect(self):
+ self.kettle.discover(wait=False, listener=self)
+
+ def stop_all(self):
+ self.kettle.stop_all()
+ self._stopped = True
+
+ def add_updating_message(self, mut: MessageUpdatingTarget):
+ with self._muts_lock:
+ for m in self._muts:
+ if m.user_id == m.user_id and m.user_did_turn_on() or m.user_did_turn_on() != mut.user_did_turn_on():
+ m.delete()
+ self._muts.append(mut)
+
+ # ---------------------
+ # threading.Thread impl
+
+ def run(self):
+ while not self._stopped:
+ # do_restart_srv = False
+ #
+ # with self._lock:
+ # if self._restart_server_at != 0 and time.time() - self._restart_server_at:
+ # self._restart_server_at = 0
+ # do_restart_srv = True
+ #
+ # if do_restart_srv:
+ # self.kettle_connect()
+
+ updates = []
+ deletions = []
+ with self._muts_lock and self._info_lock:
+ # self._logger.debug('muts size: '+str(len(self._muts)))
+ if self._muts and self._accumulated_updates and (self._info_flushed_time == 0 or time.time() - self._info_flushed_time >= 1):
+ forget = []
+ deletions = []
+
+ for mut in self._muts:
+ upd = mut.update(
+ mode=self.info.mode,
+ current_temp=self.info.temperature,
+ target_temp=self.info.target_temperature)
+
+ if upd.finished or upd.delete:
+ forget.append(mut)
+
+ if upd.delete:
+ deletions.append(upd)
+ elif upd.changed:
+ updates.append(upd)
+
+ if forget:
+ for mut in forget:
+ self._logger.debug(f'loop: removing mut {mut}')
+ self._muts.remove(mut)
+
+ self._info_flushed_time = time.time()
+ self._accumulated_updates = {}
+
+ for upd in updates:
+ self._logger.debug(f'loop: got update: {upd}')
+ try:
+ bot.edit_message_text(upd.user_id, upd.message_id,
+ text=upd.html,
+ reply_markup=upd.markup)
+ except TelegramError as exc:
+ self._logger.error(f'loop: edit_message_text failed for update: {upd}')
+ self._logger.exception(exc)
+
+ for upd in deletions:
+ self._logger.debug(f'loop: got deletion: {upd}')
+ try:
+ bot.delete_message(upd.user_id, upd.message_id)
+ except TelegramError as exc:
+ self._logger.error(f'loop: delete_message failed for update: {upd}')
+ self._logger.exception(exc)
+
+ time.sleep(0.5)
+
+ # -------------------
+ # DeviceListener impl
+
+ def device_updated(self):
+ self._logger.info(f'device updated: {self.kettle.device.si}')
+ self.kettle.start_server_if_needed(incoming_message_listener=self,
+ connection_status_listener=self)
+
+ # -----------------------
+ # KettleInfoListener impl
+
+ def info_updated(self, field: str):
+ with self._info_lock:
+ newval = getattr(self.info, field)
+ self._logger.debug(f'info_updated: updated {field}, new value is {newval}')
+ self._accumulated_updates[field] = newval
+
+ # ----------------------------
+ # IncomingMessageListener impl
+
+ def incoming_message(self, message: kettle_proto.Message) -> Optional[kettle_proto.Message]:
+ self._logger.info(f'incoming message: {message}')
+
+ if isinstance(message, kettle_proto.ModeMessage):
+ self.info.mode = message.pt
+ elif isinstance(message, kettle_proto.CurrentTemperatureMessage):
+ self.info.temperature = message.current_temperature
+ elif isinstance(message, kettle_proto.TargetTemperatureMessage):
+ self.info.target_temperature = message.temperature
+
+ return kettle_proto.AckMessage()
+
+ # -----------------------------
+ # ConnectionStatusListener impl
+
+ def connection_status_updated(self, status: ConnectionStatus):
+ self._logger.info(f'connection status updated: {status}')
+ if status == ConnectionStatus.DISCONNECTED:
+ self.kettle.stop_all()
+ self.kettle_reconnect()
+
+ # -------------
+ # MQTTBase impl
+
+ def on_connect(self, client: mqtt.Client, userdata, flags, rc):
+ super().on_connect(client, userdata, flags, rc)
+ client.subscribe(self._mqtt_root_topic, qos=1)
+ self._logger.info(f'subscribed to {self._mqtt_root_topic}')
+
+ def on_message(self, client: mqtt.Client, userdata, msg):
+ try:
+ topic = msg.topic[len(self._mqtt_root_topic)-2:]
+ pld = msg.payload.decode()
+
+ self._logger.debug(f'mqtt: on message: topic={topic} pld={pld}')
+
+ if topic == 'state/sensor/temperature':
+ self.info.temperature = int(float(pld))
+ elif topic == 'state/mode':
+ self.info.mode = PowerType(int(pld))
+ elif topic == 'state/temperature':
+ self.info.target_temperature = int(float(pld))
+
+ except Exception as e:
+ self._logger.exception(str(e))
+
+
+class Renderer:
+ @classmethod
+ def index(cls, ctx: Context) -> RenderedContent:
+ html = f'<b>{ctx.lang("settings")}</b>\n\n'
+ html += ctx.lang('select_place')
+ return html, None
+
+ @classmethod
+ def status(cls, ctx: Context,
+ connected: bool,
+ mode: PowerType,
+ current_temp: int,
+ target_temp: int,
+ update_time: int) -> RenderedContent:
+ if not connected:
+ return cls.not_connected(ctx)
+ else:
+ # power status
+ if mode != PowerType.OFF:
+ html = ctx.lang('status_on', target_temp)
+ else:
+ html = ctx.lang('status_off')
+
+ # current temperature
+ html += '\n'
+ html += ctx.lang('status_current_temp', current_temp)
+
+ # updated on
+ html += '\n'
+ html += cls.updated(ctx, update_time)
+
+ return html, None
+
+ @classmethod
+ def turned_on(cls, ctx: Context,
+ target_temp: int,
+ current_temp: int,
+ mode: PowerType,
+ update_time: Optional[int] = None,
+ reached=False,
+ no_keyboard=False) -> RenderedContent:
+ if mode == PowerType.OFF and not reached:
+ html = ctx.lang('enabling')
+ else:
+ if not reached:
+ emoji = '♨️' if current_temp <= 90 else '🔥'
+ html = ctx.lang('enabled', emoji, target_temp)
+
+ # current temperature
+ html += '\n'
+ html += ctx.lang('status_current_temp', current_temp)
+ else:
+ html = ctx.lang('enabled_reached', current_temp)
+
+ # updated on
+ if not reached and update_time is not None:
+ html += '\n'
+ html += cls.updated(ctx, update_time)
+
+ return html, None if no_keyboard else cls.wait_buttons(ctx)
+
+ @classmethod
+ def turned_off(cls, ctx: Context,
+ mode: PowerType,
+ update_time: Optional[int] = None,
+ reached=False,
+ no_keyboard=False) -> RenderedContent:
+ if mode != PowerType.OFF:
+ html = ctx.lang('disabling')
+ else:
+ html = ctx.lang('disabled')
+
+ # updated on
+ if not reached and update_time is not None:
+ html += '\n'
+ html += cls.updated(ctx, update_time)
+
+ return html, None if no_keyboard else cls.wait_buttons(ctx)
+
+ @classmethod
+ def not_connected(cls, ctx: Context) -> RenderedContent:
+ return ctx.lang('status_not_connected'), None
+
+ @classmethod
+ def smth_went_wrong(cls, ctx: Context) -> RenderedContent:
+ html = ctx.lang('smth_went_wrong')
+ return html, None
+
+ @classmethod
+ def updated(cls, ctx: Context, update_time: int):
+ locale_bak = locale.getlocale(locale.LC_TIME)
+ locale.setlocale(locale.LC_TIME, 'ru_RU.UTF-8' if ctx.user_lang == 'ru' else 'en_US.UTF-8')
+ dt = datetime.fromtimestamp(update_time)
+ html = ctx.lang('status_update_time', dt.strftime(ctx.lang('status_update_time_fmt')))
+ locale.setlocale(locale.LC_TIME, locale_bak)
+ return html
+
+ @classmethod
+ def wait_buttons(cls, ctx: Context):
+ return InlineKeyboardMarkup([
+ [
+ InlineKeyboardButton(ctx.lang('please_wait'), callback_data='wait')
+ ]
+ ])
+
+
+def run_tasks(tasks: queue.SimpleQueue, done: callable):
+ def next_task(r: Optional[kettle_proto.MessageResponse]):
+ if r is not None:
+ try:
+ assert r is not False, 'server error'
+ except AssertionError as exc:
+ logger.exception(exc)
+ tasks_lock.release()
+ return done(False)
+
+ if not tasks.empty():
+ task = tasks.get()
+ args = task[1:]
+ args.append(next_task)
+ f = getattr(kc.kettle, task[0])
+ f(*args)
+ else:
+ tasks_lock.release()
+ return done(True)
+
+ tasks_lock.acquire()
+ next_task(None)
+
+
+MUTUpdate = namedtuple('MUTUpdate', 'message_id, user_id, finished, changed, delete, html, markup')
+
+
+class MessageUpdatingTarget:
+ ctx: Context
+ message: Message
+ user_target_temp: Optional[int]
+ user_enabled_power_mode: PowerType
+ initial_power_mode: PowerType
+ need_to_delete: bool
+ rendered_content: Optional[RenderedContent]
+
+ def __init__(self,
+ ctx: Context,
+ message: Message,
+ user_enabled_power_mode: PowerType,
+ initial_power_mode: PowerType,
+ user_target_temp: Optional[int] = None):
+ self.ctx = ctx
+ self.message = message
+ self.initial_power_mode = initial_power_mode
+ self.user_enabled_power_mode = user_enabled_power_mode
+ self.ignore_pm = initial_power_mode is PowerType.OFF and self.user_did_turn_on()
+ self.user_target_temp = user_target_temp
+ self.need_to_delete = False
+ self.rendered_content = None
+ self.last_reported_temp = None
+
+ def set_rendered_content(self, content: RenderedContent):
+ self.rendered_content = content
+
+ def rendered_content_changed(self, content: RenderedContent) -> bool:
+ return content != self.rendered_content
+
+ def update(self,
+ mode: PowerType,
+ current_temp: int,
+ target_temp: int) -> MUTUpdate:
+
+ # determine whether status updating is finished
+ finished = False
+ reached = False
+ if self.ignore_pm:
+ if mode != PowerType.OFF:
+ self.ignore_pm = False
+ elif mode == PowerType.OFF:
+ reached = True
+ if self.user_did_turn_on():
+ # when target is 100 degrees, this kettle sometimes turns off at 91, sometimes at 95, sometimes at 98.
+ # it's totally unpredictable, so in this case, we keep updating the message until it reaches at least 97
+ # degrees, or if temperature started dropping.
+ if self.user_target_temp < 100 \
+ or current_temp >= self.user_target_temp - 3 \
+ or current_temp < self.last_reported_temp:
+ finished = True
+ else:
+ finished = True
+
+ self.last_reported_temp = current_temp
+
+ # render message
+ if self.user_did_turn_on():
+ rc = Renderer.turned_on(self.ctx,
+ target_temp=target_temp,
+ current_temp=current_temp,
+ mode=mode,
+ reached=reached,
+ no_keyboard=finished)
+ else:
+ rc = Renderer.turned_off(self.ctx,
+ mode=mode,
+ reached=reached,
+ no_keyboard=finished)
+
+ changed = self.rendered_content_changed(rc)
+ update = MUTUpdate(message_id=self.message.message_id,
+ user_id=self.ctx.user_id,
+ finished=finished,
+ changed=changed,
+ delete=self.need_to_delete,
+ html=rc[0],
+ markup=rc[1])
+ if changed:
+ self.set_rendered_content(rc)
+ return update
+
+ def user_did_turn_on(self) -> bool:
+ return self.user_enabled_power_mode in (PowerType.ON, PowerType.CUSTOM)
+
+ def delete(self):
+ self.need_to_delete = True
+
+ @property
+ def user_id(self) -> int:
+ return self.ctx.user_id
+
+
+class KettleBot(Wrapper):
+ def __init__(self):
+ super().__init__()
+
+ self.lang.ru(
+ start_message="Выберите команду на клавиатуре",
+ unknown_command="Неизвестная команда",
+ unexpected_callback_data="Ошибка: неверные данные",
+ enable_70="♨️ 70 °C",
+ enable_80="♨️ 80 °C",
+ enable_90="♨️ 90 °C",
+ enable_100="🔥 100 °C",
+ disable="❌ Выключить",
+ server_error="Ошибка сервера",
+
+ # /status
+ status_not_connected="😟 Связь с чайником не установлена",
+ status_on="✅ Чайник <b>включён</b> (до <b>%d °C</b>)",
+ status_off="❌ Чайник <b>выключен</b>",
+ status_current_temp="Сейчас: <b>%d °C</b>",
+ status_update_time="<i>Обновлено %s</i>",
+ status_update_time_fmt="%d %b в %H:%M:%S",
+
+ # enable
+ enabling="💤 Чайник включается...",
+ disabling="💤 Чайник выключается...",
+ enabled="%s Чайник <b>включён</b>.\nЦель: <b>%d °C</b>",
+ enabled_reached="✅ <b>Готово!</b> Чайник вскипел, температура <b>%d °C</b>.",
+ disabled="✅ Чайник <b>выключен</b>.",
+ please_wait="⏳ Ожидайте..."
+ )
+
+ self.lang.en(
+ start_message="Select command on the keyboard",
+ unknown_command="Unknown command",
+ unexpected_callback_data="Unexpected callback data",
+ enable_70="♨️ 70 °C",
+ enable_80="♨️ 80 °C",
+ enable_90="♨️ 90 °C",
+ enable_100="🔥 100 °C",
+ disable="❌ Turn OFF",
+ server_error="Server error",
+
+ # /status
+ not_connected="😟 Connection has not been established",
+ status_on="✅ Turned <b>ON</b>! Target: <b>%d °C</b>",
+ status_off="❌ Turned <b>OFF</b>",
+ status_current_temp="Now: <b>%d °C</b>",
+ status_update_time="<i>Updated on %s</i>",
+ status_update_time_fmt="%b %d, %Y at %H:%M:%S",
+
+ # enable
+ enabling="💤 Turning on...",
+ disabling="💤 Turning off...",
+ enabled="%s The kettle is <b>turned ON</b>.\nTarget: <b>%d °C</b>",
+ enabled_reached="✅ It's <b>done</b>! The kettle has boiled, the temperature is <b>%d °C</b>.",
+ disabled="✅ The kettle is <b>turned OFF</b>.",
+ please_wait="⏳ Please wait..."
+ )
+
+ # commands
+ self.add_handler(CommandHandler('status', self.status))
+
+ # messages
+ for temp in (70, 80, 90, 100):
+ self.add_handler(MessageHandler(text_filter(self.lang.all(f'enable_{temp}')), self.wrap(partial(self.on, temp))))
+
+ self.add_handler(MessageHandler(text_filter(self.lang.all('disable')), self.off))
+
+ def markup(self, ctx: Optional[Context]) -> Optional[ReplyKeyboardMarkup]:
+ buttons = [
+ [ctx.lang(f'enable_{x}') for x in (70, 80, 90, 100)],
+ [ctx.lang('disable')]
+ ]
+ return ReplyKeyboardMarkup(buttons, one_time_keyboard=False)
+
+ def on(self, temp: int, ctx: Context) -> None:
+ if not kc.kettle.is_connected():
+ text, markup = Renderer.not_connected(ctx)
+ ctx.reply(text, markup=markup)
+ return
+
+ tasks = queue.SimpleQueue()
+ if temp == 100:
+ power_mode = PowerType.ON
+ else:
+ power_mode = PowerType.CUSTOM
+ tasks.put(['set_target_temperature', temp])
+ tasks.put(['set_power', power_mode])
+
+ def done(ok: bool):
+ if not ok:
+ html, markup = Renderer.smth_went_wrong(ctx)
+ else:
+ html, markup = Renderer.turned_on(ctx,
+ target_temp=temp,
+ current_temp=kc.info.temperature,
+ mode=kc.info.mode)
+ message = ctx.reply(html, markup=markup)
+ logger.info(f'ctx.reply returned message: {message}')
+
+ mut = MessageUpdatingTarget(ctx, message,
+ initial_power_mode=kc.info.mode,
+ user_enabled_power_mode=power_mode,
+ user_target_temp=temp)
+ mut.set_rendered_content((html, markup))
+ kc.add_updating_message(mut)
+
+ run_tasks(tasks, done)
+
+ @handlermethod
+ def off(self, ctx: Context) -> None:
+ if not kc.kettle.is_connected():
+ text, markup = Renderer.not_connected(ctx)
+ ctx.reply(text, markup=markup)
+ return
+
+ def done(ok: bool):
+ if not ok:
+ html, markup = Renderer.smth_went_wrong(ctx)
+ else:
+ html, markup = Renderer.turned_off(ctx, mode=kc.info.mode)
+ message = ctx.reply(html, markup=markup)
+ logger.info(f'ctx.reply returned message: {message}')
+
+ mut = MessageUpdatingTarget(ctx, message,
+ initial_power_mode=kc.info.mode,
+ user_enabled_power_mode=PowerType.OFF)
+ mut.set_rendered_content((html, markup))
+ kc.add_updating_message(mut)
+
+ tasks = queue.SimpleQueue()
+ tasks.put(['set_power', PowerType.OFF])
+ run_tasks(tasks, done)
+
+ @handlermethod
+ def status(self, ctx: Context):
+ text, markup = Renderer.status(ctx,
+ connected=kc.kettle.is_connected(),
+ mode=kc.info.mode,
+ current_temp=kc.info.temperature,
+ target_temp=kc.info.target_temperature,
+ update_time=kc.info.update_time)
+ return ctx.reply(text, markup=markup)
+
+
+if __name__ == '__main__':
+ config.load('polaris_kettle_bot')
+
+ kc = KettleController()
+
+ bot = KettleBot()
+ if 'api' in config:
+ bot.enable_logging(BotType.POLARIS_KETTLE)
+ bot.run()
+
+ # bot library handles signals, so when sigterm or something like that happens, we should stop all other threads here
+ kc.stop_all()
diff --git a/src/polaris_kettle_util.py b/src/polaris_kettle_util.py
index 419739b..82b2588 100755
--- a/src/polaris_kettle_util.py
+++ b/src/polaris_kettle_util.py
@@ -3,35 +3,23 @@
import logging
import sys
-import time
import paho.mqtt.client as mqtt
-# from datetime import datetime
-# from html import escape
+from typing import Optional
from argparse import ArgumentParser
from queue import SimpleQueue
-# from home.bot import Wrapper, Context
-# from home.api.types import BotType
-# from home.util import parse_addr
from home.mqtt import MQTTBase
from home.config import config
-from polaris import Kettle, Message, FrameType, PowerType
-
-# from telegram.error import TelegramError
-# from telegram import ReplyKeyboardMarkup, InlineKeyboardMarkup, InlineKeyboardButton
-# from telegram.ext import (
-# CallbackQueryHandler,
-# MessageHandler,
-# CommandHandler
-# )
-
+from polaris import (
+ Kettle,
+ PowerType,
+ protocol as kettle_proto
+)
+k: Optional[Kettle] = None
logger = logging.getLogger(__name__)
control_tasks = SimpleQueue()
-# bot: Optional[Wrapper] = None
-# RenderedContent = tuple[str, Optional[InlineKeyboardMarkup]]
-
class MQTTServer(MQTTBase):
def __init__(self):
@@ -50,65 +38,29 @@ class MQTTServer(MQTTBase):
logger.exception(str(e))
-# class Renderer:
-# @classmethod
-# def index(cls, ctx: Context) -> RenderedContent:
-# html = f'<b>{ctx.lang("settings")}</b>\n\n'
-# html += ctx.lang('select_place')
-# return html, None
-
-
-# status handler
-# --------------
-
-# def status(ctx: Context):
-# text, markup = Renderer.index(ctx)
-# return ctx.reply(text, markup=markup)
-
-
-# class SoundBot(Wrapper):
-# def __init__(self):
-# super().__init__()
-#
-# self.lang.ru(
-# start_message="Выберите команду на клавиатуре",
-# unknown_command="Неизвестная команда",
-# unexpected_callback_data="Ошибка: неверные данные",
-# status="Статус",
-# )
-#
-# self.lang.en(
-# start_message="Select command on the keyboard",
-# unknown_command="Unknown command",
-# unexpected_callback_data="Unexpected callback data",
-# status="Status",
-# )
-#
-# self.add_handler(CommandHandler('status', self.wrap(status)))
-#
-# def markup(self, ctx: Optional[Context]) -> Optional[ReplyKeyboardMarkup]:
-# buttons = [
-# [ctx.lang('status')]
-# ]
-# return ReplyKeyboardMarkup(buttons, one_time_keyboard=False)
-
-def kettle_connection_established(k: Kettle, response: Message):
+def kettle_connection_established(response: kettle_proto.MessageResponse):
try:
- assert response.frame.head.type == FrameType.ACK, f'ACK expected, but received: {response}'
+ assert isinstance(response, kettle_proto.AckMessage), f'ACK expected, but received: {response}'
except AssertionError:
- k.stop_server()
+ k.stop_all()
return
- def next_task(k, response):
+ def next_task(response: kettle_proto.MessageResponse):
+ try:
+ assert response is not False, 'server error'
+ except AssertionError:
+ k.stop_all()
+ return
+
if not control_tasks.empty():
task = control_tasks.get()
f, args = task(k)
args.append(next_task)
f(*args)
else:
- k.stop_server()
+ k.stop_all()
- next_task(k, response)
+ next_task(response)
def main():
@@ -123,7 +75,7 @@ def main():
parser.add_argument('-t', '--temperature', dest='temp', type=int, default=tempmax,
choices=range(tempmin, tempmax+tempstep, tempstep))
- arg = config.load('polaris_kettle_bot', use_cli=True, parser=parser)
+ arg = config.load('polaris_kettle_util', use_cli=True, parser=parser)
if arg.mode == 'mqtt':
server = MQTTServer()
@@ -145,19 +97,17 @@ def main():
control_tasks.put(lambda k: (k.set_target_temperature, [arg.temp]))
control_tasks.put(lambda k: (k.set_power, [PowerType.CUSTOM]))
- k = Kettle(mac='40f52018dec1', device_token='3a5865f015950cae82cd120e76a80d28')
- info = k.find()
- print('found service:', info)
+ k = Kettle(mac=config['kettle']['mac'], device_token=config['kettle']['token'])
+ info = k.discover()
+ if not info:
+ print('no device found.')
+ return 1
- k.start_server(kettle_connection_established)
+ print('found service:', info)
+ k.start_server_if_needed(kettle_connection_established)
return 0
if __name__ == '__main__':
sys.exit(main())
-
- # bot = SoundBot()
- # if 'api' in config:
- # bot.enable_logging(BotType.POLARIS_KETTLE)
- # bot.run()
diff --git a/src/sound_bot.py b/src/sound_bot.py
index b515ae7..91e51f0 100755
--- a/src/sound_bot.py
+++ b/src/sound_bot.py
@@ -6,7 +6,7 @@ import tempfile
from enum import Enum
from datetime import datetime, timedelta
from html import escape
-from typing import Optional
+from typing import Optional, List, Dict, Tuple
from home.config import config
from home.bot import Wrapper, Context, text_filter, user_any_name
@@ -27,11 +27,11 @@ from telegram.ext import (
from PIL import Image
logger = logging.getLogger(__name__)
-RenderedContent = tuple[str, Optional[InlineKeyboardMarkup]]
+RenderedContent = Tuple[str, Optional[InlineKeyboardMarkup]]
record_client: Optional[SoundRecordClient] = None
bot: Optional[Wrapper] = None
-node_client_links: dict[str, SoundNodeClient] = {}
-cam_client_links: dict[str, CameraNodeClient] = {}
+node_client_links: Dict[str, SoundNodeClient] = {}
+cam_client_links: Dict[str, CameraNodeClient] = {}
def node_client(node: str) -> SoundNodeClient:
@@ -73,7 +73,7 @@ def interval_defined(interval: int) -> bool:
return interval in config['bot']['record_intervals']
-def callback_unpack(ctx: Context) -> list[str]:
+def callback_unpack(ctx: Context) -> List[str]:
return ctx.callback_query.data[3:].split('/')
@@ -115,7 +115,7 @@ class SettingsRenderer(Renderer):
@classmethod
def node(cls, ctx: Context,
- controls: list[dict]) -> RenderedContent:
+ controls: List[dict]) -> RenderedContent:
node, = callback_unpack(ctx)
html = []
@@ -169,7 +169,7 @@ class RecordRenderer(Renderer):
return html, cls.places_markup(ctx, callback_prefix='r0')
@classmethod
- def node(cls, ctx: Context, durations: list[int]) -> RenderedContent:
+ def node(cls, ctx: Context, durations: List[int]) -> RenderedContent:
node, = callback_unpack(ctx)
html = ctx.lang('select_interval')
@@ -241,7 +241,7 @@ class FilesRenderer(Renderer):
return html, cls.places_markup(ctx, callback_prefix='f0')
@classmethod
- def filelist(cls, ctx: Context, files: list[SoundRecordFile]) -> RenderedContent:
+ def filelist(cls, ctx: Context, files: List[SoundRecordFile]) -> RenderedContent:
node, = callback_unpack(ctx)
html_files = map(lambda file: cls.file(ctx, file, node), files)
@@ -936,7 +936,6 @@ class SoundBot(Wrapper):
# cheese
self.add_handler(CallbackQueryHandler(self.wrap(camera_capture), pattern=r'^c1/.*'))
-
def markup(self, ctx: Optional[Context]) -> Optional[ReplyKeyboardMarkup]:
buttons = [
[ctx.lang('record'), ctx.lang('settings')],
diff --git a/src/sound_sensor_server.py b/src/sound_sensor_server.py
index 0303d6d..9495678 100755
--- a/src/sound_sensor_server.py
+++ b/src/sound_sensor_server.py
@@ -3,7 +3,7 @@ import logging
import threading
from time import sleep
-from typing import Optional
+from typing import Optional, List, Dict, Tuple
from functools import partial
from home.config import config
from home.util import parse_addr
@@ -18,7 +18,7 @@ server: SoundSensorServer
def get_related_nodes(node_type: MediaNodeType,
- sensor_name: str) -> list[str]:
+ sensor_name: str) -> List[str]:
if sensor_name not in config[f'sensor_to_{node_type.name.lower()}_nodes_relations']:
raise ValueError(f'unexpected sensor name {sensor_name}')
return config[f'sensor_to_{node_type.name.lower()}_nodes_relations'][sensor_name]
@@ -52,7 +52,7 @@ class HitCounter:
with self.lock:
self.sensors[name] += hits
- def get_all(self) -> list[tuple[str, int]]:
+ def get_all(self) -> List[Tuple[str, int]]:
vals = []
with self.lock:
for name, hits in self.sensors.items():
@@ -119,7 +119,7 @@ def hits_sender():
api: Optional[WebAPIClient] = None
hc: Optional[HitCounter] = None
-record_clients: dict[MediaNodeType, RecordClient] = {}
+record_clients: Dict[MediaNodeType, RecordClient] = {}
# record callbacks