diff options
author | Evgeny Zinoviev <me@ch1p.io> | 2023-05-31 09:22:00 +0300 |
---|---|---|
committer | Evgeny Zinoviev <me@ch1p.io> | 2023-06-10 02:07:23 +0300 |
commit | f29e139cbb7e4a4d539cba6e894ef4a6acd312d6 (patch) | |
tree | 6246f126325c5c36fb573134a05f2771cd747966 /src/home | |
parent | 3e3753d726f8a02d98368f20f77dd9fa739e3d80 (diff) |
WIP: big refactoring
Diffstat (limited to 'src/home')
40 files changed, 1525 insertions, 699 deletions
diff --git a/src/home/audio/amixer.py b/src/home/audio/amixer.py index 53e6bce..5133c97 100644 --- a/src/home/audio/amixer.py +++ b/src/home/audio/amixer.py @@ -1,6 +1,6 @@ import subprocess -from ..config import config +from ..config import app_config as config from threading import Lock from typing import Union, List diff --git a/src/home/config/__init__.py b/src/home/config/__init__.py index cc9c091..2fa5214 100644 --- a/src/home/config/__init__.py +++ b/src/home/config/__init__.py @@ -1 +1,13 @@ -from .config import ConfigStore, config, is_development_mode, setup_logging +from .config import ( + Config, + ConfigUnit, + AppConfigUnit, + Translation, + config, + is_development_mode, + setup_logging +) +from ._configs import ( + LinuxBoardsConfig, + ServicesListConfig +)
\ No newline at end of file diff --git a/src/home/config/_configs.py b/src/home/config/_configs.py new file mode 100644 index 0000000..3a1aae5 --- /dev/null +++ b/src/home/config/_configs.py @@ -0,0 +1,55 @@ +from .config import ConfigUnit +from typing import Optional + + +class ServicesListConfig(ConfigUnit): + NAME = 'services_list' + + @staticmethod + def schema() -> Optional[dict]: + return { + 'type': 'list', + 'empty': False, + 'schema': { + 'type': 'string' + } + } + + +class LinuxBoardsConfig(ConfigUnit): + NAME = 'linux_boards' + + @staticmethod + def schema() -> Optional[dict]: + return { + 'type': 'dict', + 'schema': { + 'mdns': {'type': 'string', 'required': True}, + 'board': {'type': 'string', 'required': True}, + 'network': { + 'type': 'list', + 'required': True, + 'empty': False, + 'allowed': ['wifi', 'ethernet'] + }, + 'ram': {'type': 'integer', 'required': True}, + 'online': {'type': 'boolean', 'required': True}, + + # optional + 'services': { + 'type': 'list', + 'empty': False, + 'allowed': ServicesListConfig().get() + }, + 'ext_hdd': { + 'type': 'list', + 'schema': { + 'type': 'dict', + 'schema': { + 'mountpoint': {'type': 'string', 'required': True}, + 'size': {'type': 'integer', 'required': True} + } + }, + }, + } + } diff --git a/src/home/config/config.py b/src/home/config/config.py index 4681685..aef9ee7 100644 --- a/src/home/config/config.py +++ b/src/home/config/config.py @@ -1,58 +1,256 @@ -import toml import yaml import logging import os +import pprint -from os.path import join, isdir, isfile -from typing import Optional, Any, MutableMapping +from abc import ABC +from cerberus import Validator, DocumentError +from typing import Optional, Any, MutableMapping, Union from argparse import ArgumentParser -from ..util import parse_addr +from enum import Enum, auto +from os.path import join, isdir, isfile +from ..util import Addr + + +CONFIG_DIRECTORIES = ( + join(os.environ['HOME'], '.config', 'homekit'), + '/etc/homekit' +) + +class RootSchemaType(Enum): + DEFAULT = auto() + DICT = auto() + LIST = auto() + + +class BaseConfigUnit(ABC): + _data: MutableMapping[str, Any] + _logger: logging.Logger + def __init__(self): + self._data = {} + self._logger = logging.getLogger(self.__class__.__name__) + + def __getitem__(self, key): + return self._data[key] + + def __setitem__(self, key, value): + raise NotImplementedError('overwriting config values is prohibited') -def _get_config_path(name: str) -> str: - formats = ['toml', 'yaml'] + def __contains__(self, key): + return key in self._data - dirname = join(os.environ['HOME'], '.config', name) + def load_from(self, path: str): + with open(path, 'r') as fd: + self._data = yaml.safe_load(fd) - if isdir(dirname): - for fmt in formats: - filename = join(dirname, f'config.{fmt}') - if isfile(filename): - return filename + def get(self, + key: Optional[str] = None, + default=None): + if key is None: + return self._data - raise IOError(f'config not found in {dirname}') + cur = self._data + pts = key.split('.') + for i in range(len(pts)): + k = pts[i] + if i < len(pts)-1: + if k not in cur: + raise KeyError(f'key {k} not found') + else: + return cur[k] if k in cur else default + cur = self._data[k] - else: - filenames = [join(os.environ['HOME'], '.config', f'{name}.{format}') for format in formats] - for file in filenames: - if isfile(file): - return file + raise KeyError(f'option {key} not found') - raise IOError(f'config not found') +class ConfigUnit(BaseConfigUnit): + NAME = 'dumb' + + def __init__(self, name=None, load=True): + super().__init__() + + self._data = {} + self._logger = logging.getLogger(self.__class__.__name__) + + if self.NAME != 'dumb' and load: + self.load_from(self.get_config_path()) + self.validate() + + elif name is not None: + self.NAME = name + + @classmethod + def get_config_path(cls, name=None) -> str: + if name is None: + name = cls.NAME + if name is None: + raise ValueError('get_config_path: name is none') + + for dirname in CONFIG_DIRECTORIES: + if isdir(dirname): + filename = join(dirname, f'{name}.yaml') + if isfile(filename): + return filename + + raise IOError(f'\'{name}.yaml\' not found') + + @staticmethod + def schema() -> Optional[dict]: + return None + + def validate(self): + schema = self.schema() + if not schema: + self._logger.warning('validate: no schema') + return + + if isinstance(self, AppConfigUnit): + schema['logging'] = { + 'type': 'dict', + 'schema': { + 'logging': {'type': 'bool'} + } + } + + rst = RootSchemaType.DEFAULT + try: + if schema['type'] == 'dict': + rst = RootSchemaType.DICT + elif schema['type'] == 'list': + rst = RootSchemaType.LIST + elif schema['roottype'] == 'dict': + del schema['roottype'] + rst = RootSchemaType.DICT + except KeyError: + pass + + if rst == RootSchemaType.DICT: + v = Validator({'document': { + 'type': 'dict', + 'keysrules': {'type': 'string'}, + 'valuesrules': schema + }}) + result = v.validate({'document': self._data}) + elif rst == RootSchemaType.LIST: + v = Validator({'document': schema}) + result = v.validate({'document': self._data}) + else: + v = Validator(schema) + result = v.validate(self._data) + # pprint.pprint(self._data) + if not result: + # pprint.pprint(v.errors) + raise DocumentError(f'{self.__class__.__name__}: failed to validate data:\n{pprint.pformat(v.errors)}') + try: + self.custom_validator(self._data) + except Exception as e: + raise DocumentError(f'{self.__class__.__name__}: {str(e)}') + + @staticmethod + def custom_validator(data): + pass -class ConfigStore: - data: MutableMapping[str, Any] + def get_addr(self, key: str): + return Addr.fromstring(self.get(key)) + + +class AppConfigUnit(ConfigUnit): + _logging_verbose: bool + _logging_fmt: Optional[str] + _logging_file: Optional[str] + + def __init__(self, *args, **kwargs): + super().__init__(load=False, *args, **kwargs) + self._logging_verbose = False + self._logging_fmt = None + self._logging_file = None + + def logging_set_fmt(self, fmt: str) -> None: + self._logging_fmt = fmt + + def logging_get_fmt(self) -> Optional[str]: + try: + return self['logging']['default_fmt'] + except KeyError: + return self._logging_fmt + + def logging_set_file(self, file: str) -> None: + self._logging_file = file + + def logging_get_file(self) -> Optional[str]: + try: + return self['logging']['file'] + except KeyError: + return self._logging_file + + def logging_set_verbose(self): + self._logging_verbose = True + + def logging_is_verbose(self) -> bool: + try: + return bool(self['logging']['verbose']) + except KeyError: + return self._logging_verbose + + +class TranslationUnit(BaseConfigUnit): + pass + + +class Translation: + LANGUAGES = ('en', 'ru') + _langs: dict[str, TranslationUnit] + + def __init__(self, name: str): + super().__init__() + self._langs = {} + for lang in self.LANGUAGES: + for dirname in CONFIG_DIRECTORIES: + if isdir(dirname): + filename = join(dirname, f'i18n-{lang}', f'{name}.yaml') + if lang in self._langs: + raise RuntimeError(f'{name}: translation unit for lang \'{lang}\' already loaded') + self._langs[lang] = TranslationUnit() + self._langs[lang].load_from(filename) + diff = set() + for data in self._langs.values(): + diff ^= data.get().keys() + if len(diff) > 0: + raise RuntimeError(f'{name}: translation units have difference in keys: ' + ', '.join(diff)) + + def get(self, lang: str) -> TranslationUnit: + return self._langs[lang] + + +class Config: app_name: Optional[str] + app_config: AppConfigUnit def __init__(self): - self.data = {} self.app_name = None + self.app_config = AppConfigUnit() + + def load_app(self, + name: Optional[Union[str, AppConfigUnit, bool]] = None, + use_cli=True, + parser: ArgumentParser = None, + no_config=False): + global app_config + + if issubclass(name, AppConfigUnit) or name == AppConfigUnit: + self.app_name = name.NAME + self.app_config = name() + app_config = self.app_config + else: + self.app_name = name if isinstance(name, str) else None - def load(self, name: Optional[str] = None, - use_cli=True, - parser: ArgumentParser = None): - self.app_name = name - - if (name is None) and (not use_cli): + if self.app_name is None and not use_cli: raise RuntimeError('either config name must be none or use_cli must be True') - log_default_fmt = False - log_file = None - log_verbose = False - no_config = name is False - + no_config = name is False or no_config path = None + if use_cli: if parser is None: parser = ArgumentParser() @@ -68,75 +266,38 @@ class ConfigStore: path = args.config if args.verbose: - log_verbose = True + self.app_config.logging_set_verbose() if args.log_file: - log_file = args.log_file + self.app_config.logging_set_file(args.log_file) if args.log_default_fmt: - log_default_fmt = args.log_default_fmt + self.app_config.logging_set_fmt(args.log_default_fmt) - if not no_config and path is None: - path = _get_config_path(name) + if not isinstance(name, ConfigUnit): + if not no_config and path is None: + path = ConfigUnit.get_config_path(name=self.app_name) - if no_config: - self.data = {} - else: - if path.endswith('.toml'): - self.data = toml.load(path) - elif path.endswith('.yaml'): - with open(path, 'r') as fd: - self.data = yaml.safe_load(fd) - - if 'logging' in self: - if not log_file and 'file' in self['logging']: - log_file = self['logging']['file'] - if log_default_fmt and 'default_fmt' in self['logging']: - log_default_fmt = self['logging']['default_fmt'] + if not no_config: + self.app_config.load_from(path) - setup_logging(log_verbose, log_file, log_default_fmt) + setup_logging(self.app_config.logging_is_verbose(), + self.app_config.logging_get_file(), + self.app_config.logging_get_fmt()) if use_cli: return args - def __getitem__(self, key): - return self.data[key] - - def __setitem__(self, key, value): - raise NotImplementedError('overwriting config values is prohibited') - - def __contains__(self, key): - return key in self.data - - def get(self, key: str, default=None): - cur = self.data - pts = key.split('.') - for i in range(len(pts)): - k = pts[i] - if i < len(pts)-1: - if k not in cur: - raise KeyError(f'key {k} not found') - else: - return cur[k] if k in cur else default - cur = self.data[k] - raise KeyError(f'option {key} not found') - - def get_addr(self, key: str): - return parse_addr(self.get(key)) - - def items(self): - return self.data.items() - -config = ConfigStore() +config = Config() def is_development_mode() -> bool: if 'HK_MODE' in os.environ and os.environ['HK_MODE'] == 'dev': return True - return ('logging' in config) and ('verbose' in config['logging']) and (config['logging']['verbose'] is True) + return ('logging' in config.app_config) and ('verbose' in config.app_config['logging']) and (config.app_config['logging']['verbose'] is True) -def setup_logging(verbose=False, log_file=None, default_fmt=False): +def setup_logging(verbose=False, log_file=None, default_fmt=None): logging_level = logging.INFO if is_development_mode() or verbose: logging_level = logging.DEBUG diff --git a/src/home/database/clickhouse.py b/src/home/database/clickhouse.py index ca81628..d0ec283 100644 --- a/src/home/database/clickhouse.py +++ b/src/home/database/clickhouse.py @@ -1,7 +1,7 @@ import logging from zoneinfo import ZoneInfo -from datetime import datetime, timedelta +from datetime import datetime from clickhouse_driver import Client as ClickhouseClient from ..config import is_development_mode diff --git a/src/home/database/sqlite.py b/src/home/database/sqlite.py index bfba929..8c6145c 100644 --- a/src/home/database/sqlite.py +++ b/src/home/database/sqlite.py @@ -5,24 +5,27 @@ import logging from ..config import config, is_development_mode -def _get_database_path(name: str, dbname: str) -> str: - return os.path.join(os.environ['HOME'], '.config', name, f'{dbname}.db') +def _get_database_path(name: str) -> str: + return os.path.join( + os.environ['HOME'], + '.config', + 'homekit', + 'data', + f'{name}.db') class SQLiteBase: SCHEMA = 1 - def __init__(self, name=None, dbname='bot', check_same_thread=False): - db_path = config.get('db_path', default=None) - if db_path is None: - if not name: - name = config.app_name - if not dbname: - dbname = name - db_path = _get_database_path(name, dbname) + def __init__(self, name=None, check_same_thread=False): + if name is None: + name = config.app_config['database_name'] + database_path = _get_database_path(name) + if not os.path.exists(os.path.dirname(database_path)): + os.makedirs(os.path.dirname(database_path)) self.logger = logging.getLogger(self.__class__.__name__) - self.sqlite = sqlite3.connect(db_path, check_same_thread=check_same_thread) + self.sqlite = sqlite3.connect(database_path, check_same_thread=check_same_thread) if is_development_mode(): self.sql_logger = logging.getLogger(self.__class__.__name__) diff --git a/src/home/inverter/config.py b/src/home/inverter/config.py new file mode 100644 index 0000000..62b8859 --- /dev/null +++ b/src/home/inverter/config.py @@ -0,0 +1,13 @@ +from ..config import ConfigUnit +from typing import Optional + + +class InverterdConfig(ConfigUnit): + NAME = 'inverterd' + + @staticmethod + def schema() -> Optional[dict]: + return { + 'remote_addr': {'type': 'string'}, + 'local_addr': {'type': 'string'}, + }
\ No newline at end of file diff --git a/src/home/media/__init__.py b/src/home/media/__init__.py index 976c990..6923105 100644 --- a/src/home/media/__init__.py +++ b/src/home/media/__init__.py @@ -12,6 +12,7 @@ __map__ = { __all__ = list(itertools.chain(*__map__.values())) + def __getattr__(name): if name in __all__: for file, names in __map__.items(): diff --git a/src/home/mqtt/__init__.py b/src/home/mqtt/__init__.py index 982e2b6..707d59c 100644 --- a/src/home/mqtt/__init__.py +++ b/src/home/mqtt/__init__.py @@ -1,4 +1,7 @@ -from .mqtt import MqttBase -from .util import poll_tick -from .relay import MqttRelay, MqttRelayState -from .temphum import MqttTempHum
\ No newline at end of file +from ._mqtt import Mqtt +from ._node import MqttNode +from ._module import MqttModule +from ._wrapper import MqttWrapper +from ._config import MqttConfig, MqttCreds, MqttNodesConfig +from ._payload import MqttPayload, MqttPayloadCustomField +from ._util import get_modules as get_mqtt_modules
\ No newline at end of file diff --git a/src/home/mqtt/_config.py b/src/home/mqtt/_config.py new file mode 100644 index 0000000..f9047b4 --- /dev/null +++ b/src/home/mqtt/_config.py @@ -0,0 +1,165 @@ +from ..config import ConfigUnit +from typing import Optional, Union +from ..util import Addr +from collections import namedtuple + +MqttCreds = namedtuple('MqttCreds', 'username, password') + + +class MqttConfig(ConfigUnit): + NAME = 'mqtt' + + @staticmethod + def schema() -> Optional[dict]: + addr_schema = { + 'type': 'dict', + 'required': True, + 'schema': { + 'host': {'type': 'string', 'required': True}, + 'port': {'type': 'integer', 'required': True} + } + } + + schema = {} + for key in ('local', 'remote'): + schema[f'{key}_addr'] = addr_schema + + schema['creds'] = { + 'type': 'dict', + 'required': True, + 'keysrules': {'type': 'string'}, + 'valuesrules': { + 'type': 'dict', + 'schema': { + 'username': {'type': 'string', 'required': True}, + 'password': {'type': 'string', 'required': True}, + } + } + } + + for key in ('client', 'server'): + schema[f'default_{key}_creds'] = {'type': 'string', 'required': True} + + return schema + + def remote_addr(self) -> Addr: + return Addr(host=self['remote_addr']['host'], + port=self['remote_addr']['port']) + + def local_addr(self) -> Addr: + return Addr(host=self['local_addr']['host'], + port=self['local_addr']['port']) + + def creds_by_name(self, name: str) -> MqttCreds: + return MqttCreds(username=self['creds'][name]['username'], + password=self['creds'][name]['password']) + + def creds(self) -> MqttCreds: + return self.creds_by_name(self['default_client_creds']) + + def server_creds(self) -> MqttCreds: + return self.creds_by_name(self['default_server_creds']) + + +class MqttNodesConfig(ConfigUnit): + NAME = 'mqtt_nodes' + + @staticmethod + def schema() -> Optional[dict]: + return { + 'common': { + 'type': 'dict', + 'schema': { + 'temphum': { + 'type': 'dict', + 'schema': { + 'interval': {'type': 'integer'} + } + }, + 'password': {'type': 'string'} + } + }, + 'nodes': { + 'type': 'dict', + 'required': True, + 'keysrules': {'type': 'string'}, + 'valuesrules': { + 'type': 'dict', + 'schema': { + 'type': {'type': 'string', 'required': True, 'allowed': ['esp8266', 'linux', 'none'],}, + 'board': {'type': 'string', 'allowed': ['nodemcu', 'd1_mini_lite', 'esp12e']}, + 'temphum': { + 'type': 'dict', + 'schema': { + 'module': {'type': 'string', 'required': True, 'allowed': ['si7021', 'dht12']}, + 'interval': {'type': 'integer'}, + 'i2c_bus': {'type': 'integer'}, + 'tcpserver': { + 'type': 'dict', + 'schema': { + 'port': {'type': 'integer', 'required': True} + } + } + } + }, + 'relay': { + 'type': 'dict', + 'schema': { + 'device_type': {'type': 'string', 'allowed': ['lamp', 'pump', 'solenoid'], 'required': True}, + 'legacy_topics': {'type': 'boolean'} + } + }, + 'password': {'type': 'string'} + } + } + } + } + + @staticmethod + def custom_validator(data): + for name, node in data['nodes'].items(): + if 'temphum' in node: + if node['type'] == 'linux': + if 'i2c_bus' not in node['temphum']: + raise KeyError(f'nodes.{name}.temphum: i2c_bus is missing but required for type=linux') + if node['type'] in ('esp8266',) and 'board' not in node: + raise KeyError(f'nodes.{name}: board is missing but required for type={node["type"]}') + + def get_node(self, name: str) -> dict: + node = self['nodes'][name] + if node['type'] == 'none': + return node + + try: + if 'password' not in node: + node['password'] = self['common']['password'] + except KeyError: + pass + + try: + if 'temphum' in node: + for ckey, cval in self['common']['temphum'].items(): + if ckey not in node['temphum']: + node['temphum'][ckey] = cval + except KeyError: + pass + + return node + + def get_nodes(self, + filters: Optional[Union[list[str], tuple[str]]] = None, + only_names=False) -> Union[dict, list[str]]: + if filters: + for f in filters: + if f not in ('temphum', 'relay'): + raise ValueError(f'{self.__class__.__name__}::get_node(): invalid filter {f}') + reslist = [] + resdict = {} + for name in self['nodes'].keys(): + node = self.get_node(name) + if (not filters) or ('temphum' in filters and 'temphum' in node) or ('relay' in filters and 'relay' in node): + if only_names: + reslist.append(name) + else: + resdict[name] = node + return reslist if only_names else resdict diff --git a/src/home/mqtt/_module.py b/src/home/mqtt/_module.py new file mode 100644 index 0000000..80f27bb --- /dev/null +++ b/src/home/mqtt/_module.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import abc +import logging +import threading + +from time import sleep +from ..util import next_tick_gen + +from typing import TYPE_CHECKING, Optional +if TYPE_CHECKING: + from ._node import MqttNode + from ._payload import MqttPayload + + +class MqttModule(abc.ABC): + _tick_interval: int + _initialized: bool + _connected: bool + _ticker: Optional[threading.Thread] + _mqtt_node_ref: Optional[MqttNode] + + def __init__(self, tick_interval=0): + self._tick_interval = tick_interval + self._initialized = False + self._ticker = None + self._logger = logging.getLogger(self.__class__.__name__) + self._connected = False + self._mqtt_node_ref = None + + def on_connect(self, mqtt: MqttNode): + self._connected = True + self._mqtt_node_ref = mqtt + if self._tick_interval: + self._start_ticker() + + def on_disconnect(self, mqtt: MqttNode): + self._connected = False + self._mqtt_node_ref = None + + def is_initialized(self): + return self._initialized + + def set_initialized(self): + self._initialized = True + + def unset_initialized(self): + self._initialized = False + + def tick(self): + pass + + def _tick(self): + g = next_tick_gen(self._tick_interval) + while self._connected: + sleep(next(g)) + if not self._connected: + break + self.tick() + + def _start_ticker(self): + if not self._ticker or not self._ticker.is_alive(): + name_part = f'{self._mqtt_node_ref.id}/' if self._mqtt_node_ref else '' + self._ticker = None + self._ticker = threading.Thread(target=self._tick, + name=f'mqtt:{self.__class__.__name__}/{name_part}ticker') + self._ticker.start() + + def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: + pass diff --git a/src/home/mqtt/mqtt.py b/src/home/mqtt/_mqtt.py index 4acd4f6..746ae2e 100644 --- a/src/home/mqtt/mqtt.py +++ b/src/home/mqtt/_mqtt.py @@ -3,19 +3,24 @@ import paho.mqtt.client as mqtt import ssl import logging -from typing import Tuple -from ..config import config +from ._config import MqttCreds, MqttConfig +from typing import Optional -def username_and_password() -> Tuple[str, str]: - username = config['mqtt']['username'] if 'username' in config['mqtt'] else None - password = config['mqtt']['password'] if 'password' in config['mqtt'] else None - return username, password +class Mqtt: + _connected: bool + _is_server: bool + _mqtt_config: MqttConfig + def __init__(self, + clean_session=True, + client_id='', + creds: Optional[MqttCreds] = None, + is_server=False): + if not client_id: + raise ValueError('client_id must not be empty') -class MqttBase: - def __init__(self, clean_session=True): - self._client = mqtt.Client(client_id=config['mqtt']['client_id'], + self._client = mqtt.Client(client_id=client_id, protocol=mqtt.MQTTv311, clean_session=clean_session) self._client.on_connect = self.on_connect @@ -24,15 +29,17 @@ class MqttBase: self._client.on_log = self.on_log self._client.on_publish = self.on_publish self._loop_started = False - + self._connected = False + self._is_server = is_server + self._mqtt_config = MqttConfig() self._logger = logging.getLogger(self.__class__.__name__) - username, password = username_and_password() - if username and password: - self._logger.debug(f'username={username} password={password}') - self._client.username_pw_set(username, password) + if not creds: + creds = self._mqtt_config.creds() if not is_server else self._mqtt_config.server_creds() + + self._client.username_pw_set(creds.username, creds.password) - def configure_tls(self): + def _configure_tls(self): ca_certs = os.path.realpath(os.path.join( os.path.dirname(os.path.realpath(__file__)), '..', @@ -41,13 +48,14 @@ class MqttBase: 'assets', 'mqtt_ca.crt' )) - self._client.tls_set(ca_certs=ca_certs, cert_reqs=ssl.CERT_REQUIRED, tls_version=ssl.PROTOCOL_TLSv1_2) + self._client.tls_set(ca_certs=ca_certs, + cert_reqs=ssl.CERT_REQUIRED, + tls_version=ssl.PROTOCOL_TLSv1_2) def connect_and_loop(self, loop_forever=True): - host = config['mqtt']['host'] - port = config['mqtt']['port'] - - self._client.connect(host, port, 60) + self._configure_tls() + addr = self._mqtt_config.local_addr() if self._is_server else self._mqtt_config.remote_addr() + self._client.connect(addr.host, addr.port, 60) if loop_forever: self._client.loop_forever() else: @@ -61,9 +69,11 @@ class MqttBase: def on_connect(self, client: mqtt.Client, userdata, flags, rc): self._logger.info("Connected with result code " + str(rc)) + self._connected = True def on_disconnect(self, client: mqtt.Client, userdata, rc): self._logger.info("Disconnected with result code " + str(rc)) + self._connected = False def on_log(self, client: mqtt.Client, userdata, level, buf): level = mqtt.LOGGING_LEVEL[level] if level in mqtt.LOGGING_LEVEL else logging.INFO @@ -73,4 +83,4 @@ class MqttBase: self._logger.debug(msg.topic + ": " + str(msg.payload)) def on_publish(self, client: mqtt.Client, userdata, mid): - self._logger.debug(f'publish done, mid={mid}')
\ No newline at end of file + self._logger.debug(f'publish done, mid={mid}') diff --git a/src/home/mqtt/_node.py b/src/home/mqtt/_node.py new file mode 100644 index 0000000..4e259a4 --- /dev/null +++ b/src/home/mqtt/_node.py @@ -0,0 +1,92 @@ +import logging +import importlib + +from typing import List, TYPE_CHECKING, Optional +from ._payload import MqttPayload +from ._module import MqttModule +if TYPE_CHECKING: + from ._wrapper import MqttWrapper +else: + MqttWrapper = None + + +class MqttNode: + _modules: List[MqttModule] + _module_subscriptions: dict[str, MqttModule] + _node_id: str + _node_secret: str + _payload_callbacks: list[callable] + _wrapper: Optional[MqttWrapper] + + def __init__(self, + node_id: str, + node_secret: Optional[str] = None): + self._modules = [] + self._module_subscriptions = {} + self._node_id = node_id + self._node_secret = node_secret + self._payload_callbacks = [] + self._logger = logging.getLogger(self.__class__.__name__) + self._wrapper = None + + def on_connect(self, wrapper: MqttWrapper): + self._wrapper = wrapper + for module in self._modules: + if not module.is_initialized(): + module.on_connect(self) + module.set_initialized() + + def on_disconnect(self): + self._wrapper = None + for module in self._modules: + module.unset_initialized() + + def on_message(self, topic, payload): + if topic in self._module_subscriptions: + payload = self._module_subscriptions[topic].handle_payload(self, topic, payload) + if isinstance(payload, MqttPayload): + for f in self._payload_callbacks: + f(self, payload) + + def load_module(self, module_name: str, *args, **kwargs) -> MqttModule: + module = importlib.import_module(f'..module.{module_name}', __name__) + if not hasattr(module, 'MODULE_NAME'): + raise RuntimeError(f'MODULE_NAME not found in module {module}') + cl = getattr(module, getattr(module, 'MODULE_NAME')) + instance = cl(*args, **kwargs) + self.add_module(instance) + return instance + + def add_module(self, module: MqttModule): + self._modules.append(module) + if self._wrapper and self._wrapper._connected: + module.on_connect(self) + module.set_initialized() + + def subscribe_module(self, topic: str, module: MqttModule, qos: int = 1): + if not self._wrapper or not self._wrapper._connected: + raise RuntimeError('not connected') + + self._module_subscriptions[topic] = module + self._wrapper.subscribe(self.id, topic, qos) + + def publish(self, + topic: str, + payload: bytes, + qos: int = 1): + self._wrapper.publish(self.id, topic, payload, qos) + + def add_payload_callback(self, callback: callable): + self._payload_callbacks.append(callback) + + @property + def id(self) -> str: + return self._node_id + + @property + def secret(self) -> str: + return self._node_secret + + @secret.setter + def secret(self, secret: str) -> None: + self._node_secret = secret diff --git a/src/home/mqtt/payload/base_payload.py b/src/home/mqtt/_payload.py index 1abd898..58eeae3 100644 --- a/src/home/mqtt/payload/base_payload.py +++ b/src/home/mqtt/_payload.py @@ -1,5 +1,5 @@ -import abc import struct +import abc import re from typing import Optional, Tuple @@ -142,4 +142,4 @@ def _bit_field_params(cl) -> Optional[Tuple[int, ...]]: match = re.match(r'MQTTPayloadBitField_(\d+)_(\d+)_(\d)$', cl.__name__) if match is not None: return tuple([int(match.group(i)) for i in range(1, 4)]) - return None + return None
\ No newline at end of file diff --git a/src/home/mqtt/_util.py b/src/home/mqtt/_util.py new file mode 100644 index 0000000..390d463 --- /dev/null +++ b/src/home/mqtt/_util.py @@ -0,0 +1,15 @@ +import os +import re + +from typing import List + + +def get_modules() -> List[str]: + modules = [] + modules_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'module') + for name in os.listdir(modules_dir): + if os.path.isdir(os.path.join(modules_dir, name)): + continue + name = re.sub(r'\.py$', '', name) + modules.append(name) + return modules diff --git a/src/home/mqtt/_wrapper.py b/src/home/mqtt/_wrapper.py new file mode 100644 index 0000000..f858f88 --- /dev/null +++ b/src/home/mqtt/_wrapper.py @@ -0,0 +1,59 @@ +import paho.mqtt.client as mqtt + +from ._mqtt import Mqtt +from ._node import MqttNode +from ..config import config +from ..util import strgen + + +class MqttWrapper(Mqtt): + _nodes: list[MqttNode] + + def __init__(self, + client_id: str, + topic_prefix='hk', + randomize_client_id=False, + clean_session=True): + if randomize_client_id: + client_id += '_'+strgen(6) + super().__init__(clean_session=clean_session, + client_id=client_id) + self._nodes = [] + self._topic_prefix = topic_prefix + + def on_connect(self, client: mqtt.Client, userdata, flags, rc): + super().on_connect(client, userdata, flags, rc) + for node in self._nodes: + node.on_connect(self) + + def on_disconnect(self, client: mqtt.Client, userdata, rc): + super().on_disconnect(client, userdata, rc) + for node in self._nodes: + node.on_disconnect() + + def on_message(self, client: mqtt.Client, userdata, msg): + try: + topic = msg.topic + for node in self._nodes: + node.on_message(topic[len(f'{self._topic_prefix}/{node.id}/'):], msg.payload) + except Exception as e: + self._logger.exception(str(e)) + + def add_node(self, node: MqttNode): + self._nodes.append(node) + if self._connected: + node.on_connect(self) + + def subscribe(self, + node_id: str, + topic: str, + qos: int): + self._client.subscribe(f'{self._topic_prefix}/{node_id}/{topic}', qos) + + def publish(self, + node_id: str, + topic: str, + payload: bytes, + qos: int): + self._client.publish(f'{self._topic_prefix}/{node_id}/{topic}', payload, qos) + self._client.loop_write() diff --git a/src/home/mqtt/esp.py b/src/home/mqtt/esp.py deleted file mode 100644 index 56ced83..0000000 --- a/src/home/mqtt/esp.py +++ /dev/null @@ -1,106 +0,0 @@ -import re -import paho.mqtt.client as mqtt - -from .mqtt import MqttBase -from typing import Optional, Union -from .payload.esp import ( - OTAPayload, - OTAResultPayload, - DiagnosticsPayload, - InitialDiagnosticsPayload -) - - -class MqttEspDevice: - id: str - secret: Optional[str] - - def __init__(self, id: str, secret: Optional[str] = None): - self.id = id - self.secret = secret - - -class MqttEspBase(MqttBase): - _devices: list[MqttEspDevice] - _message_callback: Optional[callable] - _ota_publish_callback: Optional[callable] - - TOPIC_LEAF = 'esp' - - def __init__(self, - devices: Union[MqttEspDevice, list[MqttEspDevice]], - subscribe_to_updates=True): - super().__init__(clean_session=True) - if not isinstance(devices, list): - devices = [devices] - self._devices = devices - self._message_callback = None - self._ota_publish_callback = None - self._subscribe_to_updates = subscribe_to_updates - self._ota_mid = None - - def on_connect(self, client: mqtt.Client, userdata, flags, rc): - super().on_connect(client, userdata, flags, rc) - - if self._subscribe_to_updates: - for device in self._devices: - topic = f'hk/{device.id}/{self.TOPIC_LEAF}/#' - self._logger.debug(f"subscribing to {topic}") - client.subscribe(topic, qos=1) - - def on_publish(self, client: mqtt.Client, userdata, mid): - if self._ota_mid is not None and mid == self._ota_mid and self._ota_publish_callback: - self._ota_publish_callback() - - def set_message_callback(self, callback: callable): - self._message_callback = callback - - def on_message(self, client: mqtt.Client, userdata, msg): - try: - match = re.match(self.get_mqtt_topics(), msg.topic) - self._logger.debug(f'topic: {msg.topic}') - if not match: - return - - device_id = match.group(1) - subtopic = match.group(2) - - # try: - next(d for d in self._devices if d.id == device_id) - # except StopIteration:h - # return - - message = None - if subtopic == 'stat': - message = DiagnosticsPayload.unpack(msg.payload) - elif subtopic == 'stat1': - message = InitialDiagnosticsPayload.unpack(msg.payload) - elif subtopic == 'otares': - message = OTAResultPayload.unpack(msg.payload) - - if message and self._message_callback: - self._message_callback(device_id, message) - return True - - except Exception as e: - self._logger.exception(str(e)) - - def push_ota(self, - device_id, - filename: str, - publish_callback: callable, - qos: int): - device = next(d for d in self._devices if d.id == device_id) - assert device.secret is not None, 'device secret not specified' - - self._ota_publish_callback = publish_callback - payload = OTAPayload(secret=device.secret, filename=filename) - publish_result = self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/admin/ota', - payload=payload.pack(), - qos=qos) - self._ota_mid = publish_result.mid - self._client.loop_write() - - @classmethod - def get_mqtt_topics(cls, additional_topics: Optional[list[str]] = None): - return rf'^hk/(.*?)/{cls.TOPIC_LEAF}/(stat|stat1|otares'+('|'+('|'.join(additional_topics)) if additional_topics else '')+')$'
\ No newline at end of file diff --git a/src/home/mqtt/payload/esp.py b/src/home/mqtt/module/diagnostics.py index 171cdb9..5db5e99 100644 --- a/src/home/mqtt/payload/esp.py +++ b/src/home/mqtt/module/diagnostics.py @@ -1,39 +1,8 @@ -import hashlib +from .._payload import MqttPayload, MqttPayloadCustomField +from .._node import MqttNode, MqttModule +from typing import Optional -from .base_payload import MqttPayload, MqttPayloadCustomField - - -class OTAResultPayload(MqttPayload): - FORMAT = '=BB' - result: int - error_code: int - - -class OTAPayload(MqttPayload): - secret: str - filename: str - - # structure of returned data: - # - # uint8_t[len(secret)] secret; - # uint8_t[16] md5; - # *uint8_t data - - def pack(self): - buf = bytearray(self.secret.encode()) - m = hashlib.md5() - with open(self.filename, 'rb') as fd: - content = fd.read() - m.update(content) - buf.extend(m.digest()) - buf.extend(content) - return buf - - def unpack(cls, buf: bytes): - raise RuntimeError(f'{cls.__class__.__name__}.unpack: not implemented') - # secret = buf[:12].decode() - # filename = buf[12:].decode() - # return OTAPayload(secret=secret, filename=filename) +MODULE_NAME = 'MqttDiagnosticsModule' class DiagnosticsFlags(MqttPayloadCustomField): @@ -76,3 +45,20 @@ class DiagnosticsPayload(MqttPayload): rssi: int free_heap: int flags: DiagnosticsFlags + + +class MqttDiagnosticsModule(MqttModule): + def on_connect(self, mqtt: MqttNode): + super().on_connect(mqtt) + for topic in ('diag', 'd1ag', 'stat', 'stat1'): + mqtt.subscribe_module(topic, self) + + def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: + message = None + if topic in ('stat', 'diag'): + message = DiagnosticsPayload.unpack(payload) + elif topic in ('stat1', 'd1ag'): + message = InitialDiagnosticsPayload.unpack(payload) + if message: + self._logger.debug(message) + return message diff --git a/src/home/mqtt/module/inverter.py b/src/home/mqtt/module/inverter.py new file mode 100644 index 0000000..d927a06 --- /dev/null +++ b/src/home/mqtt/module/inverter.py @@ -0,0 +1,195 @@ +import time +import json +import datetime +try: + import inverterd +except: + pass + +from typing import Optional +from .._module import MqttModule +from .._node import MqttNode +from .._payload import MqttPayload, bit_field +try: + from home.database import InverterDatabase +except: + pass + +_mult_10 = lambda n: int(n*10) +_div_10 = lambda n: n/10 + + +MODULE_NAME = 'MqttInverterModule' + +STATUS_TOPIC = 'status' +GENERATION_TOPIC = 'generation' + + +class MqttInverterStatusPayload(MqttPayload): + # 46 bytes + FORMAT = 'IHHHHHHBHHHHHBHHHHHHHH' + + PACKER = { + 'grid_voltage': _mult_10, + 'grid_freq': _mult_10, + 'ac_output_voltage': _mult_10, + 'ac_output_freq': _mult_10, + 'battery_voltage': _mult_10, + 'battery_voltage_scc': _mult_10, + 'battery_voltage_scc2': _mult_10, + 'pv1_input_voltage': _mult_10, + 'pv2_input_voltage': _mult_10 + } + UNPACKER = { + 'grid_voltage': _div_10, + 'grid_freq': _div_10, + 'ac_output_voltage': _div_10, + 'ac_output_freq': _div_10, + 'battery_voltage': _div_10, + 'battery_voltage_scc': _div_10, + 'battery_voltage_scc2': _div_10, + 'pv1_input_voltage': _div_10, + 'pv2_input_voltage': _div_10 + } + + time: int + grid_voltage: float + grid_freq: float + ac_output_voltage: float + ac_output_freq: float + ac_output_apparent_power: int + ac_output_active_power: int + output_load_percent: int + battery_voltage: float + battery_voltage_scc: float + battery_voltage_scc2: float + battery_discharge_current: int + battery_charge_current: int + battery_capacity: int + inverter_heat_sink_temp: int + mppt1_charger_temp: int + mppt2_charger_temp: int + pv1_input_power: int + pv2_input_power: int + pv1_input_voltage: float + pv2_input_voltage: float + + # H + mppt1_charger_status: bit_field(0, 16, 2) + mppt2_charger_status: bit_field(0, 16, 2) + battery_power_direction: bit_field(0, 16, 2) + dc_ac_power_direction: bit_field(0, 16, 2) + line_power_direction: bit_field(0, 16, 2) + load_connected: bit_field(0, 16, 1) + + +class MqttInverterGenerationPayload(MqttPayload): + # 8 bytes + FORMAT = 'II' + + time: int + wh: int + + +class MqttInverterModule(MqttModule): + _status_poll_freq: int + _generation_poll_freq: int + _inverter: Optional[inverterd.Client] + _database: Optional[InverterDatabase] + _gen_prev: float + + def __init__(self, status_poll_freq=0, generation_poll_freq=0): + super().__init__(tick_interval=status_poll_freq) + self._status_poll_freq = status_poll_freq + self._generation_poll_freq = generation_poll_freq + + # this defines whether this is a publisher or a subscriber + if status_poll_freq > 0: + self._inverter = inverterd.Client() + self._inverter.connect() + self._inverter.format(inverterd.Format.SIMPLE_JSON) + self._database = None + else: + self._inverter = None + self._database = InverterDatabase() + + self._gen_prev = 0 + + def on_connect(self, mqtt: MqttNode): + super().on_connect(mqtt) + if not self._inverter: + mqtt.subscribe_module(STATUS_TOPIC, self) + mqtt.subscribe_module(GENERATION_TOPIC, self) + + def tick(self): + if not self._inverter: + return + + # read status + now = time.time() + try: + raw = self._inverter.exec('get-status') + except inverterd.InverterError as e: + self._logger.error(f'inverter error: {str(e)}') + # TODO send to server + return + + data = json.loads(raw)['data'] + status = MqttInverterStatusPayload(time=round(now), **data) + self._mqtt_node_ref.publish(STATUS_TOPIC, status.pack()) + + # read today's generation stat + now = time.time() + if self._gen_prev == 0 or now - self._gen_prev >= self._generation_poll_freq: + self._gen_prev = now + today = datetime.date.today() + try: + raw = self._inverter.exec('get-day-generated', (today.year, today.month, today.day)) + except inverterd.InverterError as e: + self._logger.error(f'inverter error: {str(e)}') + # TODO send to server + return + + data = json.loads(raw)['data'] + gen = MqttInverterGenerationPayload(time=round(now), wh=data['wh']) + self._mqtt_node_ref.publish(GENERATION_TOPIC, gen.pack()) + + def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: + home_id = 1 # legacy compat + + if topic == STATUS_TOPIC: + s = MqttInverterStatusPayload.unpack(payload) + self._database.add_status(home_id=home_id, + client_time=s.time, + grid_voltage=int(s.grid_voltage*10), + grid_freq=int(s.grid_freq * 10), + ac_output_voltage=int(s.ac_output_voltage * 10), + ac_output_freq=int(s.ac_output_freq * 10), + ac_output_apparent_power=s.ac_output_apparent_power, + ac_output_active_power=s.ac_output_active_power, + output_load_percent=s.output_load_percent, + battery_voltage=int(s.battery_voltage * 10), + battery_voltage_scc=int(s.battery_voltage_scc * 10), + battery_voltage_scc2=int(s.battery_voltage_scc2 * 10), + battery_discharge_current=s.battery_discharge_current, + battery_charge_current=s.battery_charge_current, + battery_capacity=s.battery_capacity, + inverter_heat_sink_temp=s.inverter_heat_sink_temp, + mppt1_charger_temp=s.mppt1_charger_temp, + mppt2_charger_temp=s.mppt2_charger_temp, + pv1_input_power=s.pv1_input_power, + pv2_input_power=s.pv2_input_power, + pv1_input_voltage=int(s.pv1_input_voltage * 10), + pv2_input_voltage=int(s.pv2_input_voltage * 10), + mppt1_charger_status=s.mppt1_charger_status, + mppt2_charger_status=s.mppt2_charger_status, + battery_power_direction=s.battery_power_direction, + dc_ac_power_direction=s.dc_ac_power_direction, + line_power_direction=s.line_power_direction, + load_connected=s.load_connected) + return s + + elif topic == GENERATION_TOPIC: + gen = MqttInverterGenerationPayload.unpack(payload) + self._database.add_generation(home_id, gen.time, gen.wh) + return gen diff --git a/src/home/mqtt/module/ota.py b/src/home/mqtt/module/ota.py new file mode 100644 index 0000000..cd34332 --- /dev/null +++ b/src/home/mqtt/module/ota.py @@ -0,0 +1,77 @@ +import hashlib + +from typing import Optional +from .._payload import MqttPayload +from .._node import MqttModule, MqttNode + +MODULE_NAME = 'MqttOtaModule' + + +class OtaResultPayload(MqttPayload): + FORMAT = '=BB' + result: int + error_code: int + + +class OtaPayload(MqttPayload): + secret: str + filename: str + + # structure of returned data: + # + # uint8_t[len(secret)] secret; + # uint8_t[16] md5; + # *uint8_t data + + def pack(self): + buf = bytearray(self.secret.encode()) + m = hashlib.md5() + with open(self.filename, 'rb') as fd: + content = fd.read() + m.update(content) + buf.extend(m.digest()) + buf.extend(content) + return buf + + def unpack(cls, buf: bytes): + raise RuntimeError(f'{cls.__class__.__name__}.unpack: not implemented') + # secret = buf[:12].decode() + # filename = buf[12:].decode() + # return OTAPayload(secret=secret, filename=filename) + + +class MqttOtaModule(MqttModule): + _ota_request: Optional[tuple[str, int]] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._ota_request = None + + def on_connect(self, mqtt: MqttNode): + super().on_connect(mqtt) + mqtt.subscribe_module("otares", self) + + if self._ota_request is not None: + filename, qos = self._ota_request + self._ota_request = None + self.do_push_ota(self._mqtt_node_ref.secret, filename, qos) + + def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: + if topic == 'otares': + message = OtaResultPayload.unpack(payload) + self._logger.debug(message) + return message + + def do_push_ota(self, secret: str, filename: str, qos: int): + payload = OtaPayload(secret=secret, filename=filename) + self._mqtt_node_ref.publish('ota', + payload=payload.pack(), + qos=qos) + + def push_ota(self, + filename: str, + qos: int): + if not self._initialized: + self._ota_request = (filename, qos) + else: + self.do_push_ota(filename, qos) diff --git a/src/home/mqtt/module/relay.py b/src/home/mqtt/module/relay.py new file mode 100644 index 0000000..e968031 --- /dev/null +++ b/src/home/mqtt/module/relay.py @@ -0,0 +1,92 @@ +import datetime + +from typing import Optional +from .. import MqttModule, MqttPayload, MqttNode + +MODULE_NAME = 'MqttRelayModule' + + +class MqttPowerSwitchPayload(MqttPayload): + FORMAT = '=12sB' + PACKER = { + 'state': lambda n: int(n), + 'secret': lambda s: s.encode('utf-8') + } + UNPACKER = { + 'state': lambda n: bool(n), + 'secret': lambda s: s.decode('utf-8') + } + + secret: str + state: bool + + +class MqttPowerStatusPayload(MqttPayload): + FORMAT = '=B' + PACKER = { + 'opened': lambda n: int(n), + } + UNPACKER = { + 'opened': lambda n: bool(n), + } + + opened: bool + + +class MqttRelayState: + enabled: bool + update_time: datetime.datetime + rssi: int + fw_version: int + ever_updated: bool + + def __init__(self): + self.ever_updated = False + self.enabled = False + self.rssi = 0 + + def update(self, + enabled: bool, + rssi: int, + fw_version=None): + self.ever_updated = True + self.enabled = enabled + self.rssi = rssi + self.update_time = datetime.datetime.now() + if fw_version: + self.fw_version = fw_version + + +class MqttRelayModule(MqttModule): + _legacy_topics: bool + + def __init__(self, legacy_topics=False, *args, **kwargs): + super().__init__(*args, **kwargs) + self._legacy_topics = legacy_topics + + def on_connect(self, mqtt: MqttNode): + super().on_connect(mqtt) + mqtt.subscribe_module(self._get_switch_topic(), self) + mqtt.subscribe_module('relay/status', self) + + def switchpower(self, + enable: bool): + payload = MqttPowerSwitchPayload(secret=self._mqtt_node_ref.secret, + state=enable) + self._mqtt_node_ref.publish(self._get_switch_topic(), + payload=payload.pack()) + + def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: + message = None + + if topic == self._get_switch_topic(): + message = MqttPowerSwitchPayload.unpack(payload) + elif topic == 'relay/status': + message = MqttPowerStatusPayload.unpack(payload) + + if message is not None: + self._logger.debug(message) + return message + + def _get_switch_topic(self) -> str: + return 'relay/power' if self._legacy_topics else 'relay/switch' diff --git a/src/home/mqtt/module/temphum.py b/src/home/mqtt/module/temphum.py new file mode 100644 index 0000000..fd02cca --- /dev/null +++ b/src/home/mqtt/module/temphum.py @@ -0,0 +1,82 @@ +from .._node import MqttNode +from .._module import MqttModule +from .._payload import MqttPayload +from typing import Optional +from ...temphum import BaseSensor + +two_digits_precision = lambda x: round(x, 2) + +MODULE_NAME = 'MqttTempHumModule' +DATA_TOPIC = 'temphum/data' + + +class MqttTemphumDataPayload(MqttPayload): + FORMAT = '=ddb' + UNPACKER = { + 'temp': two_digits_precision, + 'rh': two_digits_precision + } + + temp: float + rh: float + error: int + + +# class MqttTempHumNodes(HashableEnum): +# KBN_SH_HALL = auto() +# KBN_SH_BATHROOM = auto() +# KBN_SH_LIVINGROOM = auto() +# KBN_SH_BEDROOM = auto() +# +# KBN_BH_2FL = auto() +# KBN_BH_2FL_STREET = auto() +# KBN_BH_1FL_LIVINGROOM = auto() +# KBN_BH_1FL_BEDROOM = auto() +# KBN_BH_1FL_BATHROOM = auto() +# +# KBN_NH_1FL_INV = auto() +# KBN_NH_1FL_CENTER = auto() +# KBN_NH_1LF_KT = auto() +# KBN_NH_1FL_DS = auto() +# KBN_NH_1FS_EZ = auto() +# +# SPB_FLAT120_CABINET = auto() + + +class MqttTempHumModule(MqttModule): + def __init__(self, + sensor: Optional[BaseSensor] = None, + write_to_database=False, + *args, **kwargs): + if sensor is not None: + kwargs['tick_interval'] = 10 + super().__init__(*args, **kwargs) + self._sensor = sensor + + def on_connect(self, mqtt: MqttNode): + super().on_connect(mqtt) + mqtt.subscribe_module(DATA_TOPIC, self) + + def tick(self): + if not self._sensor: + return + + error = 0 + temp = 0 + rh = 0 + try: + temp = self._sensor.temperature() + rh = self._sensor.humidity() + except: + error = 1 + pld = MqttTemphumDataPayload(temp=temp, rh=rh, error=error) + self._mqtt_node_ref.publish(DATA_TOPIC, pld.pack()) + + def handle_payload(self, + mqtt: MqttNode, + topic: str, + payload: bytes) -> Optional[MqttPayload]: + if topic == DATA_TOPIC: + message = MqttTemphumDataPayload.unpack(payload) + self._logger.debug(message) + return message diff --git a/src/home/mqtt/payload/__init__.py b/src/home/mqtt/payload/__init__.py deleted file mode 100644 index eee6709..0000000 --- a/src/home/mqtt/payload/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .base_payload import MqttPayload
\ No newline at end of file diff --git a/src/home/mqtt/payload/inverter.py b/src/home/mqtt/payload/inverter.py deleted file mode 100644 index 09388df..0000000 --- a/src/home/mqtt/payload/inverter.py +++ /dev/null @@ -1,73 +0,0 @@ -import struct - -from .base_payload import MqttPayload, bit_field -from typing import Tuple - -_mult_10 = lambda n: int(n*10) -_div_10 = lambda n: n/10 - - -class Status(MqttPayload): - # 46 bytes - FORMAT = 'IHHHHHHBHHHHHBHHHHHHHH' - - PACKER = { - 'grid_voltage': _mult_10, - 'grid_freq': _mult_10, - 'ac_output_voltage': _mult_10, - 'ac_output_freq': _mult_10, - 'battery_voltage': _mult_10, - 'battery_voltage_scc': _mult_10, - 'battery_voltage_scc2': _mult_10, - 'pv1_input_voltage': _mult_10, - 'pv2_input_voltage': _mult_10 - } - UNPACKER = { - 'grid_voltage': _div_10, - 'grid_freq': _div_10, - 'ac_output_voltage': _div_10, - 'ac_output_freq': _div_10, - 'battery_voltage': _div_10, - 'battery_voltage_scc': _div_10, - 'battery_voltage_scc2': _div_10, - 'pv1_input_voltage': _div_10, - 'pv2_input_voltage': _div_10 - } - - time: int - grid_voltage: float - grid_freq: float - ac_output_voltage: float - ac_output_freq: float - ac_output_apparent_power: int - ac_output_active_power: int - output_load_percent: int - battery_voltage: float - battery_voltage_scc: float - battery_voltage_scc2: float - battery_discharge_current: int - battery_charge_current: int - battery_capacity: int - inverter_heat_sink_temp: int - mppt1_charger_temp: int - mppt2_charger_temp: int - pv1_input_power: int - pv2_input_power: int - pv1_input_voltage: float - pv2_input_voltage: float - - # H - mppt1_charger_status: bit_field(0, 16, 2) - mppt2_charger_status: bit_field(0, 16, 2) - battery_power_direction: bit_field(0, 16, 2) - dc_ac_power_direction: bit_field(0, 16, 2) - line_power_direction: bit_field(0, 16, 2) - load_connected: bit_field(0, 16, 1) - - -class Generation(MqttPayload): - # 8 bytes - FORMAT = 'II' - - time: int - wh: int diff --git a/src/home/mqtt/payload/relay.py b/src/home/mqtt/payload/relay.py deleted file mode 100644 index 4902991..0000000 --- a/src/home/mqtt/payload/relay.py +++ /dev/null @@ -1,22 +0,0 @@ -from .base_payload import MqttPayload -from .esp import ( - OTAResultPayload, - OTAPayload, - InitialDiagnosticsPayload, - DiagnosticsPayload -) - - -class PowerPayload(MqttPayload): - FORMAT = '=12sB' - PACKER = { - 'state': lambda n: int(n), - 'secret': lambda s: s.encode('utf-8') - } - UNPACKER = { - 'state': lambda n: bool(n), - 'secret': lambda s: s.decode('utf-8') - } - - secret: str - state: bool diff --git a/src/home/mqtt/payload/sensors.py b/src/home/mqtt/payload/sensors.py deleted file mode 100644 index f99b307..0000000 --- a/src/home/mqtt/payload/sensors.py +++ /dev/null @@ -1,20 +0,0 @@ -from .base_payload import MqttPayload - -_mult_100 = lambda n: int(n*100) -_div_100 = lambda n: n/100 - - -class Temperature(MqttPayload): - FORMAT = 'IhH' - PACKER = { - 'temp': _mult_100, - 'rh': _mult_100, - } - UNPACKER = { - 'temp': _div_100, - 'rh': _div_100, - } - - time: int - temp: float - rh: float diff --git a/src/home/mqtt/payload/temphum.py b/src/home/mqtt/payload/temphum.py deleted file mode 100644 index c0b744e..0000000 --- a/src/home/mqtt/payload/temphum.py +++ /dev/null @@ -1,15 +0,0 @@ -from .base_payload import MqttPayload - -two_digits_precision = lambda x: round(x, 2) - - -class TempHumDataPayload(MqttPayload): - FORMAT = '=ddb' - UNPACKER = { - 'temp': two_digits_precision, - 'rh': two_digits_precision - } - - temp: float - rh: float - error: int diff --git a/src/home/mqtt/relay.py b/src/home/mqtt/relay.py deleted file mode 100644 index a90f19c..0000000 --- a/src/home/mqtt/relay.py +++ /dev/null @@ -1,71 +0,0 @@ -import paho.mqtt.client as mqtt -import re -import datetime - -from .payload.relay import ( - PowerPayload, -) -from .esp import MqttEspBase - - -class MqttRelay(MqttEspBase): - TOPIC_LEAF = 'relay' - - def set_power(self, device_id, enable: bool, secret=None): - device = next(d for d in self._devices if d.id == device_id) - secret = secret if secret else device.secret - - assert secret is not None, 'device secret not specified' - - payload = PowerPayload(secret=secret, - state=enable) - self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/power', - payload=payload.pack(), - qos=1) - self._client.loop_write() - - def on_message(self, client: mqtt.Client, userdata, msg): - if super().on_message(client, userdata, msg): - return - - try: - match = re.match(self.get_mqtt_topics(['power']), msg.topic) - if not match: - return - - device_id = match.group(1) - subtopic = match.group(2) - - message = None - if subtopic == 'power': - message = PowerPayload.unpack(msg.payload) - - if message and self._message_callback: - self._message_callback(device_id, message) - - except Exception as e: - self._logger.exception(str(e)) - - -class MqttRelayState: - enabled: bool - update_time: datetime.datetime - rssi: int - fw_version: int - ever_updated: bool - - def __init__(self): - self.ever_updated = False - self.enabled = False - self.rssi = 0 - - def update(self, - enabled: bool, - rssi: int, - fw_version=None): - self.ever_updated = True - self.enabled = enabled - self.rssi = rssi - self.update_time = datetime.datetime.now() - if fw_version: - self.fw_version = fw_version diff --git a/src/home/mqtt/temphum.py b/src/home/mqtt/temphum.py deleted file mode 100644 index 44810ef..0000000 --- a/src/home/mqtt/temphum.py +++ /dev/null @@ -1,54 +0,0 @@ -import paho.mqtt.client as mqtt -import re - -from enum import auto -from .payload.temphum import TempHumDataPayload -from .esp import MqttEspBase -from ..util import HashableEnum - - -class MqttTempHumNodes(HashableEnum): - KBN_SH_HALL = auto() - KBN_SH_BATHROOM = auto() - KBN_SH_LIVINGROOM = auto() - KBN_SH_BEDROOM = auto() - - KBN_BH_2FL = auto() - KBN_BH_2FL_STREET = auto() - KBN_BH_1FL_LIVINGROOM = auto() - KBN_BH_1FL_BEDROOM = auto() - KBN_BH_1FL_BATHROOM = auto() - - KBN_NH_1FL_INV = auto() - KBN_NH_1FL_CENTER = auto() - KBN_NH_1LF_KT = auto() - KBN_NH_1FL_DS = auto() - KBN_NH_1FS_EZ = auto() - - SPB_FLAT120_CABINET = auto() - - -class MqttTempHum(MqttEspBase): - TOPIC_LEAF = 'temphum' - - def on_message(self, client: mqtt.Client, userdata, msg): - if super().on_message(client, userdata, msg): - return - - try: - match = re.match(self.get_mqtt_topics(['data']), msg.topic) - if not match: - return - - device_id = match.group(1) - subtopic = match.group(2) - - message = None - if subtopic == 'data': - message = TempHumDataPayload.unpack(msg.payload) - - if message and self._message_callback: - self._message_callback(device_id, message) - - except Exception as e: - self._logger.exception(str(e)) diff --git a/src/home/mqtt/util.py b/src/home/mqtt/util.py deleted file mode 100644 index f71ffd8..0000000 --- a/src/home/mqtt/util.py +++ /dev/null @@ -1,8 +0,0 @@ -import time - - -def poll_tick(freq): - t = time.time() - while True: - t += freq - yield max(t - time.time(), 0) diff --git a/src/home/pio/products.py b/src/home/pio/products.py index 7649078..388da03 100644 --- a/src/home/pio/products.py +++ b/src/home/pio/products.py @@ -16,10 +16,6 @@ _products_dir = os.path.join( def get_products(): products = [] for f in os.listdir(_products_dir): - # temp hack - if f.endswith('-esp01'): - continue - # skip the common dir if f in ('common',): continue diff --git a/src/home/telegram/_botcontext.py b/src/home/telegram/_botcontext.py index f343eeb..a143bfe 100644 --- a/src/home/telegram/_botcontext.py +++ b/src/home/telegram/_botcontext.py @@ -1,6 +1,7 @@ from typing import Optional, List -from telegram import Update, ParseMode, User, CallbackQuery +from telegram import Update, User, CallbackQuery +from telegram.constants import ParseMode from telegram.ext import CallbackContext from ._botdb import BotDatabase @@ -26,25 +27,25 @@ class Context: self._store = store self._user_lang = None - def reply(self, text, markup=None): + async def reply(self, text, markup=None): if markup is None: markup = self._markup_getter(self) kwargs = dict(parse_mode=ParseMode.HTML) if not isinstance(markup, IgnoreMarkup): kwargs['reply_markup'] = markup - return self._update.message.reply_text(text, **kwargs) + return await self._update.message.reply_text(text, **kwargs) - def reply_exc(self, e: Exception) -> None: - self.reply(exc2text(e), markup=IgnoreMarkup()) + async def reply_exc(self, e: Exception) -> None: + await self.reply(exc2text(e), markup=IgnoreMarkup()) - def answer(self, text: str = None): - self.callback_query.answer(text) + async def answer(self, text: str = None): + await self.callback_query.answer(text) - def edit(self, text, markup=None): + async def edit(self, text, markup=None): kwargs = dict(parse_mode=ParseMode.HTML) if not isinstance(markup, IgnoreMarkup): kwargs['reply_markup'] = markup - self.callback_query.edit_message_text(text, **kwargs) + await self.callback_query.edit_message_text(text, **kwargs) @property def text(self) -> str: diff --git a/src/home/telegram/bot.py b/src/home/telegram/bot.py index 10bfe06..7e22263 100644 --- a/src/home/telegram/bot.py +++ b/src/home/telegram/bot.py @@ -5,19 +5,19 @@ import itertools from enum import Enum, auto from functools import wraps -from typing import Optional, Union, Tuple +from typing import Optional, Union, Tuple, Coroutine from telegram import Update, ReplyKeyboardMarkup from telegram.ext import ( - Updater, - Filters, - BaseFilter, + Application, + filters, CommandHandler, MessageHandler, CallbackQueryHandler, CallbackContext, ConversationHandler ) +from telegram.ext.filters import BaseFilter from telegram.error import TimedOut from home.config import config @@ -33,26 +33,26 @@ from ._botcontext import Context db: Optional[BotDatabase] = None _user_filter: Optional[BaseFilter] = None -_cancel_filter = Filters.text(lang.all('cancel')) -_back_filter = Filters.text(lang.all('back')) -_cancel_and_back_filter = Filters.text(lang.all('back') + lang.all('cancel')) +_cancel_filter = filters.Text(lang.all('cancel')) +_back_filter = filters.Text(lang.all('back')) +_cancel_and_back_filter = filters.Text(lang.all('back') + lang.all('cancel')) _logger = logging.getLogger(__name__) -_updater: Optional[Updater] = None +_application: Optional[Application] = None _reporting: Optional[ReportingHelper] = None -_exception_handler: Optional[callable] = None +_exception_handler: Optional[Coroutine] = None _dispatcher = None _markup_getter: Optional[callable] = None -_start_handler_ref: Optional[callable] = None +_start_handler_ref: Optional[Coroutine] = None def text_filter(*args): if not _user_filter: raise RuntimeError('user_filter is not initialized') - return Filters.text(args[0] if isinstance(args[0], list) else [*args]) & _user_filter + return filters.Text(args[0] if isinstance(args[0], list) else [*args]) & _user_filter -def _handler_of_handler(*args, **kwargs): +async def _handler_of_handler(*args, **kwargs): self = None context = None update = None @@ -99,7 +99,7 @@ def _handler_of_handler(*args, **kwargs): if self: _args.insert(0, self) - result = f(*_args, **kwargs) + result = await f(*_args, **kwargs) return result if not return_with_context else (result, ctx) except Exception as e: @@ -107,7 +107,7 @@ def _handler_of_handler(*args, **kwargs): if not _exception_handler(e, ctx) and not isinstance(e, TimedOut): _logger.exception(e) if not ctx.is_callback_context(): - ctx.reply_exc(e) + await ctx.reply_exc(e) else: notify_user(ctx.user_id, exc2text(e)) else: @@ -117,10 +117,10 @@ def _handler_of_handler(*args, **kwargs): def handler(**kwargs): def inner(f): @wraps(f) - def _handler(*args, **inner_kwargs): + async def _handler(*args, **inner_kwargs): if 'argument' in kwargs and kwargs['argument'] == 'message_key': inner_kwargs['argument'] = 'message_key' - return _handler_of_handler(f=f, *args, **inner_kwargs) + return await _handler_of_handler(f=f, *args, **inner_kwargs) messages = [] texts = [] @@ -139,43 +139,43 @@ def handler(**kwargs): new_messages = list(itertools.chain.from_iterable([lang.all(m) for m in messages])) texts += new_messages texts = list(set(texts)) - _updater.dispatcher.add_handler( + _application.add_handler( MessageHandler(text_filter(*texts), _handler), group=0 ) if 'command' in kwargs: - _updater.dispatcher.add_handler(CommandHandler(kwargs['command'], _handler), group=0) + _application.add_handler(CommandHandler(kwargs['command'], _handler), group=0) if 'callback' in kwargs: - _updater.dispatcher.add_handler(CallbackQueryHandler(_handler, pattern=kwargs['callback']), group=0) + _application.add_handler(CallbackQueryHandler(_handler, pattern=kwargs['callback']), group=0) return _handler return inner -def simplehandler(f: callable): +def simplehandler(f: Coroutine): @wraps(f) - def _handler(*args, **kwargs): - return _handler_of_handler(f=f, *args, **kwargs) + async def _handler(*args, **kwargs): + return await _handler_of_handler(f=f, *args, **kwargs) return _handler def callbackhandler(*args, **kwargs): def inner(f): @wraps(f) - def _handler(*args, **kwargs): - return _handler_of_handler(f=f, *args, **kwargs) + async def _handler(*args, **kwargs): + return await _handler_of_handler(f=f, *args, **kwargs) pattern_kwargs = {} if kwargs['callback'] != '*': pattern_kwargs['pattern'] = kwargs['callback'] - _updater.dispatcher.add_handler(CallbackQueryHandler(_handler, **pattern_kwargs), group=0) + _application.add_handler(CallbackQueryHandler(_handler, **pattern_kwargs), group=0) return _handler return inner -def exceptionhandler(f: callable): +async def exceptionhandler(f: callable): global _exception_handler if _exception_handler: _logger.warning('exception handler already set, we will overwrite it') @@ -198,10 +198,10 @@ def convinput(state, is_enter=False, **kwargs): ) @wraps(f) - def _impl(*args, **kwargs): - result, ctx = _handler_of_handler(f=f, *args, **kwargs, return_with_context=True) + async def _impl(*args, **kwargs): + result, ctx = await _handler_of_handler(f=f, *args, **kwargs, return_with_context=True) if result == conversation.END: - start(ctx) + await start(ctx) return result return _impl @@ -252,7 +252,7 @@ class conversation: handlers.append(MessageHandler(text_filter(lang.all(message) if 'messages_lang_completed' not in kwargs else message), self.make_invoker(target_state))) if 'regex' in kwargs: - handlers.append(MessageHandler(Filters.regex(kwargs['regex']) & _user_filter, f)) + handlers.append(MessageHandler(filters.Regex(kwargs['regex']) & _user_filter, f)) if 'command' in kwargs: handlers.append(CommandHandler(kwargs['command'], f, _user_filter)) @@ -327,21 +327,21 @@ class conversation: @staticmethod @simplehandler - def invalid(ctx: Context): - ctx.reply(ctx.lang('invalid_input'), markup=IgnoreMarkup()) + async def invalid(ctx: Context): + await ctx.reply(ctx.lang('invalid_input'), markup=IgnoreMarkup()) # return 0 # FIXME is this needed @simplehandler - def cancel(self, ctx: Context): - start(ctx) + async def cancel(self, ctx: Context): + await start(ctx) self.set_user_state(ctx.user_id, None) return conversation.END @simplehandler - def back(self, ctx: Context): + async def back(self, ctx: Context): cur_state = self.get_user_state(ctx.user_id) if cur_state is None: - start(ctx) + await start(ctx) self.set_user_state(ctx.user_id, None) return conversation.END @@ -411,7 +411,7 @@ class LangConversation(conversation): START, = range(1) @conventer(START, command='lang') - def entry(self, ctx: Context): + async def entry(self, ctx: Context): self._logger.debug(f'current language: {ctx.user_lang}') buttons = [] @@ -419,11 +419,11 @@ class LangConversation(conversation): buttons.append(name) markup = ReplyKeyboardMarkup([buttons, [ctx.lang('cancel')]], one_time_keyboard=False) - ctx.reply(ctx.lang('select_language'), markup=markup) + await ctx.reply(ctx.lang('select_language'), markup=markup) return self.START @convinput(START, messages=lang.languages) - def input(self, ctx: Context): + async def input(self, ctx: Context): selected_lang = None for key, value in languages.items(): if value == ctx.text: @@ -434,30 +434,34 @@ class LangConversation(conversation): raise ValueError('could not find the language') db.set_user_lang(ctx.user_id, selected_lang) - ctx.reply(ctx.lang('saved'), markup=IgnoreMarkup()) + await ctx.reply(ctx.lang('saved'), markup=IgnoreMarkup()) return self.END def initialize(): global _user_filter - global _updater + global _application + # global _updater global _dispatcher # init user_filter - if 'users' in config['bot']: - _logger.info('allowed users: ' + str(config['bot']['users'])) - _user_filter = Filters.user(config['bot']['users']) + _user_ids = config.app_config.get_user_ids() + if len(_user_ids) > 0: + _logger.info('allowed users: ' + str(_user_ids)) + _user_filter = filters.User(_user_ids) else: - _user_filter = Filters.all # not sure if this is correct + _user_filter = filters.ALL # not sure if this is correct - # init updater - _updater = Updater(config['bot']['token'], - request_kwargs={'read_timeout': 6, 'connect_timeout': 7}) + _application = Application.builder()\ + .token(config.app_config.get('bot.token'))\ + .connect_timeout(7)\ + .read_timeout(6)\ + .build() # transparently log all messages - _updater.dispatcher.add_handler(MessageHandler(Filters.all & _user_filter, _logging_message_handler), group=10) - _updater.dispatcher.add_handler(CallbackQueryHandler(_logging_callback_handler), group=10) + # _application.dispatcher.add_handler(MessageHandler(filters.ALL & _user_filter, _logging_message_handler), group=10) + # _application.dispatcher.add_handler(CallbackQueryHandler(_logging_callback_handler), group=10) def run(start_handler=None, any_handler=None): @@ -473,37 +477,38 @@ def run(start_handler=None, any_handler=None): _start_handler_ref = start_handler - _updater.dispatcher.add_handler(LangConversation().get_handler(), group=0) - _updater.dispatcher.add_handler(CommandHandler('start', simplehandler(start_handler), _user_filter)) - _updater.dispatcher.add_handler(MessageHandler(Filters.all & _user_filter, any_handler)) + _application.add_handler(LangConversation().get_handler(), group=0) + _application.add_handler(CommandHandler('start', + callback=simplehandler(start_handler), + filters=_user_filter)) + _application.add_handler(MessageHandler(filters.ALL & _user_filter, any_handler)) - _updater.start_polling() - _updater.idle() + _application.run_polling() def add_conversation(conv: conversation) -> None: - _updater.dispatcher.add_handler(conv.get_handler(), group=0) + _application.add_handler(conv.get_handler(), group=0) def add_handler(h): - _updater.dispatcher.add_handler(h, group=0) + _application.add_handler(h, group=0) -def start(ctx: Context): - return _start_handler_ref(ctx) +async def start(ctx: Context): + return await _start_handler_ref(ctx) -def _default_start_handler(ctx: Context): +async def _default_start_handler(ctx: Context): if 'start_message' not in lang: - return ctx.reply('Please define start_message or override start()') - ctx.reply(ctx.lang('start_message')) + return await ctx.reply('Please define start_message or override start()') + await ctx.reply(ctx.lang('start_message')) @simplehandler -def _default_any_handler(ctx: Context): +async def _default_any_handler(ctx: Context): if 'invalid_command' not in lang: - return ctx.reply('Please define invalid_command or override any()') - ctx.reply(ctx.lang('invalid_command')) + return await ctx.reply('Please define invalid_command or override any()') + await ctx.reply(ctx.lang('invalid_command')) def _logging_message_handler(update: Update, context: CallbackContext): @@ -535,7 +540,7 @@ def notify_all(text_getter: callable, continue text = text_getter(db.get_user_lang(user_id)) - _updater.bot.send_message(chat_id=user_id, + _application.bot.send_message(chat_id=user_id, text=text, parse_mode='HTML') @@ -543,33 +548,33 @@ def notify_all(text_getter: callable, def notify_user(user_id: int, text: Union[str, Exception], **kwargs) -> None: if isinstance(text, Exception): text = exc2text(text) - _updater.bot.send_message(chat_id=user_id, + _application.bot.send_message(chat_id=user_id, text=text, parse_mode='HTML', **kwargs) def send_photo(user_id, **kwargs): - _updater.bot.send_photo(chat_id=user_id, **kwargs) + _application.bot.send_photo(chat_id=user_id, **kwargs) def send_audio(user_id, **kwargs): - _updater.bot.send_audio(chat_id=user_id, **kwargs) + _application.bot.send_audio(chat_id=user_id, **kwargs) def send_file(user_id, **kwargs): - _updater.bot.send_document(chat_id=user_id, **kwargs) + _application.bot.send_document(chat_id=user_id, **kwargs) def edit_message_text(user_id, message_id, *args, **kwargs): - _updater.bot.edit_message_text(chat_id=user_id, + _application.bot.edit_message_text(chat_id=user_id, message_id=message_id, parse_mode='HTML', *args, **kwargs) def delete_message(user_id, message_id): - _updater.bot.delete_message(chat_id=user_id, message_id=message_id) + _application.bot.delete_message(chat_id=user_id, message_id=message_id) def set_database(_db: BotDatabase): diff --git a/src/home/telegram/config.py b/src/home/telegram/config.py new file mode 100644 index 0000000..7a46087 --- /dev/null +++ b/src/home/telegram/config.py @@ -0,0 +1,75 @@ +from ..config import ConfigUnit +from typing import Optional, Union +from abc import ABC +from enum import Enum + + +class TelegramUserListType(Enum): + USERS = 'users' + NOTIFY = 'notify_users' + + +class TelegramUserIdsConfig(ConfigUnit): + NAME = 'telegram_user_ids' + + @staticmethod + def schema() -> Optional[dict]: + return { + 'roottype': 'dict', + 'type': 'integer' + } + + +_user_ids_config = TelegramUserIdsConfig() + + +def _user_id_mapper(user: Union[str, int]) -> int: + if isinstance(user, int): + return user + return _user_ids_config[user] + + +class TelegramChatsConfig(ConfigUnit): + NAME = 'telegram_chats' + + @staticmethod + def schema() -> Optional[dict]: + return { + 'type': 'dict', + 'schema': { + 'id': {'type': 'string', 'required': True}, + 'token': {'type': 'string', 'required': True}, + } + } + + +class TelegramBotConfig(ConfigUnit, ABC): + @staticmethod + def schema() -> Optional[dict]: + return { + 'bot': { + 'type': 'dict', + 'schema': { + 'token': {'type': 'string', 'required': True}, + TelegramUserListType.USERS: {**TelegramBotConfig._userlist_schema(), 'required': True}, + TelegramUserListType.NOTIFY: TelegramBotConfig._userlist_schema(), + } + } + } + + @staticmethod + def _userlist_schema() -> dict: + return {'type': 'list', 'schema': {'type': ['string', 'int']}} + + @staticmethod + def custom_validator(data): + for ult in TelegramUserListType: + users = data['bot'][ult.value] + for user in users: + if isinstance(user, str): + if user not in _user_ids_config: + raise ValueError(f'user {user} not found in {TelegramUserIdsConfig.NAME}') + + def get_user_ids(self, + ult: TelegramUserListType = TelegramUserListType.USERS) -> list[int]: + return list(map(_user_id_mapper, self['bot'][ult.value]))
\ No newline at end of file diff --git a/src/home/temphum/__init__.py b/src/home/temphum/__init__.py index 55a7e1f..46d14e6 100644 --- a/src/home/temphum/__init__.py +++ b/src/home/temphum/__init__.py @@ -1,18 +1 @@ -from .base import SensorType, TempHumSensor -from .si7021 import Si7021 -from .dht12 import DHT12 - -__all__ = [ - 'SensorType', - 'TempHumSensor', - 'create_sensor' -] - - -def create_sensor(type: SensorType, bus: int) -> TempHumSensor: - if type == SensorType.Si7021: - return Si7021(bus) - elif type == SensorType.DHT12: - return DHT12(bus) - else: - raise ValueError('unexpected sensor type') +from .base import SensorType, BaseSensor diff --git a/src/home/temphum/base.py b/src/home/temphum/base.py index e774433..602cab7 100644 --- a/src/home/temphum/base.py +++ b/src/home/temphum/base.py @@ -1,25 +1,19 @@ -import smbus - -from abc import abstractmethod, ABC +from abc import ABC from enum import Enum -class TempHumSensor: - @abstractmethod +class BaseSensor(ABC): + def __init__(self, bus: int): + super().__init__() + self.bus = smbus.SMBus(bus) + def humidity(self) -> float: pass - @abstractmethod def temperature(self) -> float: pass -class I2CTempHumSensor(TempHumSensor, ABC): - def __init__(self, bus: int): - super().__init__() - self.bus = smbus.SMBus(bus) - - class SensorType(Enum): Si7021 = 'si7021' - DHT12 = 'dht12' + DHT12 = 'dht12'
\ No newline at end of file diff --git a/src/home/temphum/dht12.py b/src/home/temphum/dht12.py deleted file mode 100644 index d495766..0000000 --- a/src/home/temphum/dht12.py +++ /dev/null @@ -1,22 +0,0 @@ -from .base import I2CTempHumSensor - - -class DHT12(I2CTempHumSensor): - i2c_addr = 0x5C - - def _measure(self): - raw = self.bus.read_i2c_block_data(self.i2c_addr, 0, 5) - if (raw[0] + raw[1] + raw[2] + raw[3]) & 0xff != raw[4]: - raise ValueError("checksum error") - return raw - - def temperature(self) -> float: - raw = self._measure() - temp = raw[2] + (raw[3] & 0x7f) * 0.1 - if raw[3] & 0x80: - temp *= -1 - return temp - - def humidity(self) -> float: - raw = self._measure() - return raw[0] + raw[1] * 0.1 diff --git a/src/home/temphum/i2c.py b/src/home/temphum/i2c.py new file mode 100644 index 0000000..7d8e2e3 --- /dev/null +++ b/src/home/temphum/i2c.py @@ -0,0 +1,52 @@ +import abc +import smbus + +from .base import BaseSensor, SensorType + + +class I2CSensor(BaseSensor, abc.ABC): + def __init__(self, bus: int): + super().__init__() + self.bus = smbus.SMBus(bus) + + +class DHT12(I2CSensor): + i2c_addr = 0x5C + + def _measure(self): + raw = self.bus.read_i2c_block_data(self.i2c_addr, 0, 5) + if (raw[0] + raw[1] + raw[2] + raw[3]) & 0xff != raw[4]: + raise ValueError("checksum error") + return raw + + def temperature(self) -> float: + raw = self._measure() + temp = raw[2] + (raw[3] & 0x7f) * 0.1 + if raw[3] & 0x80: + temp *= -1 + return temp + + def humidity(self) -> float: + raw = self._measure() + return raw[0] + raw[1] * 0.1 + + +class Si7021(I2CSensor): + i2c_addr = 0x40 + + def temperature(self) -> float: + raw = self.bus.read_i2c_block_data(self.i2c_addr, 0xE3, 2) + return 175.72 * (raw[0] << 8 | raw[1]) / 65536.0 - 46.85 + + def humidity(self) -> float: + raw = self.bus.read_i2c_block_data(self.i2c_addr, 0xE5, 2) + return 125.0 * (raw[0] << 8 | raw[1]) / 65536.0 - 6.0 + + +def create_sensor(type: SensorType, bus: int) -> BaseSensor: + if type == SensorType.Si7021: + return Si7021(bus) + elif type == SensorType.DHT12: + return DHT12(bus) + else: + raise ValueError('unexpected sensor type') diff --git a/src/home/temphum/si7021.py b/src/home/temphum/si7021.py deleted file mode 100644 index 6289e15..0000000 --- a/src/home/temphum/si7021.py +++ /dev/null @@ -1,13 +0,0 @@ -from .base import I2CTempHumSensor - - -class Si7021(I2CTempHumSensor): - i2c_addr = 0x40 - - def temperature(self) -> float: - raw = self.bus.read_i2c_block_data(self.i2c_addr, 0xE3, 2) - return 175.72 * (raw[0] << 8 | raw[1]) / 65536.0 - 46.85 - - def humidity(self) -> float: - raw = self.bus.read_i2c_block_data(self.i2c_addr, 0xE5, 2) - return 125.0 * (raw[0] << 8 | raw[1]) / 65536.0 - 6.0 diff --git a/src/home/util.py b/src/home/util.py index 93a9d8f..35505bc 100644 --- a/src/home/util.py +++ b/src/home/util.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import socket import time @@ -6,17 +8,57 @@ import traceback import logging import string import random +import re from enum import Enum from datetime import datetime from typing import Tuple, Optional, List from zlib import adler32 -Addr = Tuple[str, int] # network address type (host, port) - logger = logging.getLogger(__name__) +def validate_ipv4_or_hostname(address: str, raise_exception: bool = False) -> bool: + if re.match(r'^(\d{1,3}\.){3}\d{1,3}$', address): + parts = address.split('.') + if all(0 <= int(part) < 256 for part in parts): + return True + else: + if raise_exception: + raise ValueError(f"invalid IPv4 address: {address}") + return False + + if re.match(r'^[a-zA-Z0-9.-]+$', address): + return True + else: + if raise_exception: + raise ValueError(f"invalid hostname: {address}") + return False + + +class Addr: + host: str + port: int + + def __init__(self, host: str, port: int): + self.host = host + self.port = port + + @staticmethod + def fromstring(addr: str) -> Addr: + if addr.count(':') != 1: + raise ValueError('invalid host:port format') + + host, port = addr.split(':') + validate_ipv4_or_hostname(host, raise_exception=True) + + port = int(port) + if not 0 <= port <= 65535: + raise ValueError(f'invalid port {port}') + + return Addr(host, port) + + # https://stackoverflow.com/questions/312443/how-do-you-split-a-list-into-evenly-sized-chunks def chunks(lst, n): """Yield successive n-sized chunks from lst.""" @@ -45,21 +87,6 @@ def ipv4_valid(ip: str) -> bool: return False -def parse_addr(addr: str) -> Addr: - if addr.count(':') != 1: - raise ValueError('invalid host:port format') - - host, port = addr.split(':') - if not ipv4_valid(host): - raise ValueError('invalid ipv4 address') - - port = int(port) - if not 0 <= port <= 65535: - raise ValueError('invalid port') - - return host, port - - def strgen(n: int): return ''.join(random.choices(string.ascii_letters + string.digits, k=n)) @@ -193,4 +220,11 @@ def filesize_fmt(num, suffix="B") -> str: class HashableEnum(Enum): def hash(self) -> int: - return adler32(self.name.encode())
\ No newline at end of file + return adler32(self.name.encode()) + + +def next_tick_gen(freq): + t = time.time() + while True: + t += freq + yield max(t - time.time(), 0)
\ No newline at end of file |