From f29e139cbb7e4a4d539cba6e894ef4a6acd312d6 Mon Sep 17 00:00:00 2001 From: Evgeny Zinoviev Date: Wed, 31 May 2023 09:22:00 +0300 Subject: WIP: big refactoring --- src/home/audio/amixer.py | 2 +- src/home/config/__init__.py | 14 +- src/home/config/_configs.py | 55 ++++++ src/home/config/config.py | 329 +++++++++++++++++++++++++--------- src/home/database/clickhouse.py | 2 +- src/home/database/sqlite.py | 25 +-- src/home/inverter/config.py | 13 ++ src/home/media/__init__.py | 1 + src/home/mqtt/__init__.py | 11 +- src/home/mqtt/_config.py | 165 +++++++++++++++++ src/home/mqtt/_module.py | 70 ++++++++ src/home/mqtt/_mqtt.py | 86 +++++++++ src/home/mqtt/_node.py | 92 ++++++++++ src/home/mqtt/_payload.py | 145 +++++++++++++++ src/home/mqtt/_util.py | 15 ++ src/home/mqtt/_wrapper.py | 59 ++++++ src/home/mqtt/esp.py | 106 ----------- src/home/mqtt/module/diagnostics.py | 64 +++++++ src/home/mqtt/module/inverter.py | 195 ++++++++++++++++++++ src/home/mqtt/module/ota.py | 77 ++++++++ src/home/mqtt/module/relay.py | 92 ++++++++++ src/home/mqtt/module/temphum.py | 82 +++++++++ src/home/mqtt/mqtt.py | 76 -------- src/home/mqtt/payload/__init__.py | 1 - src/home/mqtt/payload/base_payload.py | 145 --------------- src/home/mqtt/payload/esp.py | 78 -------- src/home/mqtt/payload/inverter.py | 73 -------- src/home/mqtt/payload/relay.py | 22 --- src/home/mqtt/payload/sensors.py | 20 --- src/home/mqtt/payload/temphum.py | 15 -- src/home/mqtt/relay.py | 71 -------- src/home/mqtt/temphum.py | 54 ------ src/home/mqtt/util.py | 8 - src/home/pio/products.py | 4 - src/home/telegram/_botcontext.py | 19 +- src/home/telegram/bot.py | 149 +++++++-------- src/home/telegram/config.py | 75 ++++++++ src/home/temphum/__init__.py | 19 +- src/home/temphum/base.py | 20 +-- src/home/temphum/dht12.py | 22 --- src/home/temphum/i2c.py | 52 ++++++ src/home/temphum/si7021.py | 13 -- src/home/util.py | 70 ++++++-- 43 files changed, 1766 insertions(+), 940 deletions(-) create mode 100644 src/home/config/_configs.py create mode 100644 src/home/inverter/config.py create mode 100644 src/home/mqtt/_config.py create mode 100644 src/home/mqtt/_module.py create mode 100644 src/home/mqtt/_mqtt.py create mode 100644 src/home/mqtt/_node.py create mode 100644 src/home/mqtt/_payload.py create mode 100644 src/home/mqtt/_util.py create mode 100644 src/home/mqtt/_wrapper.py delete mode 100644 src/home/mqtt/esp.py create mode 100644 src/home/mqtt/module/diagnostics.py create mode 100644 src/home/mqtt/module/inverter.py create mode 100644 src/home/mqtt/module/ota.py create mode 100644 src/home/mqtt/module/relay.py create mode 100644 src/home/mqtt/module/temphum.py delete mode 100644 src/home/mqtt/mqtt.py delete mode 100644 src/home/mqtt/payload/__init__.py delete mode 100644 src/home/mqtt/payload/base_payload.py delete mode 100644 src/home/mqtt/payload/esp.py delete mode 100644 src/home/mqtt/payload/inverter.py delete mode 100644 src/home/mqtt/payload/relay.py delete mode 100644 src/home/mqtt/payload/sensors.py delete mode 100644 src/home/mqtt/payload/temphum.py delete mode 100644 src/home/mqtt/relay.py delete mode 100644 src/home/mqtt/temphum.py delete mode 100644 src/home/mqtt/util.py create mode 100644 src/home/telegram/config.py delete mode 100644 src/home/temphum/dht12.py create mode 100644 src/home/temphum/i2c.py delete mode 100644 src/home/temphum/si7021.py (limited to 'src/home') diff --git a/src/home/audio/amixer.py b/src/home/audio/amixer.py index 53e6bce..5133c97 100644 --- a/src/home/audio/amixer.py +++ b/src/home/audio/amixer.py @@ -1,6 +1,6 @@ import subprocess -from ..config import config +from ..config import app_config as config from threading import Lock from typing import Union, List diff --git a/src/home/config/__init__.py b/src/home/config/__init__.py index cc9c091..2fa5214 100644 --- a/src/home/config/__init__.py +++ b/src/home/config/__init__.py @@ -1 +1,13 @@ -from .config import ConfigStore, config, is_development_mode, setup_logging +from .config import ( + Config, + ConfigUnit, + AppConfigUnit, + Translation, + config, + is_development_mode, + setup_logging +) +from ._configs import ( + LinuxBoardsConfig, + ServicesListConfig +) \ No newline at end of file diff --git a/src/home/config/_configs.py b/src/home/config/_configs.py new file mode 100644 index 0000000..3a1aae5 --- /dev/null +++ b/src/home/config/_configs.py @@ -0,0 +1,55 @@ +from .config import ConfigUnit +from typing import Optional + + +class ServicesListConfig(ConfigUnit): + NAME = 'services_list' + + @staticmethod + def schema() -> Optional[dict]: + return { + 'type': 'list', + 'empty': False, + 'schema': { + 'type': 'string' + } + } + + +class LinuxBoardsConfig(ConfigUnit): + NAME = 'linux_boards' + + @staticmethod + def schema() -> Optional[dict]: + return { + 'type': 'dict', + 'schema': { + 'mdns': {'type': 'string', 'required': True}, + 'board': {'type': 'string', 'required': True}, + 'network': { + 'type': 'list', + 'required': True, + 'empty': False, + 'allowed': ['wifi', 'ethernet'] + }, + 'ram': {'type': 'integer', 'required': True}, + 'online': {'type': 'boolean', 'required': True}, + + # optional + 'services': { + 'type': 'list', + 'empty': False, + 'allowed': ServicesListConfig().get() + }, + 'ext_hdd': { + 'type': 'list', + 'schema': { + 'type': 'dict', + 'schema': { + 'mountpoint': {'type': 'string', 'required': True}, + 'size': {'type': 'integer', 'required': True} + } + }, + }, + } + } diff --git a/src/home/config/config.py b/src/home/config/config.py index 4681685..aef9ee7 100644 --- a/src/home/config/config.py +++ b/src/home/config/config.py @@ -1,58 +1,256 @@ -import toml import yaml import logging import os +import pprint -from os.path import join, isdir, isfile -from typing import Optional, Any, MutableMapping +from abc import ABC +from cerberus import Validator, DocumentError +from typing import Optional, Any, MutableMapping, Union from argparse import ArgumentParser -from ..util import parse_addr +from enum import Enum, auto +from os.path import join, isdir, isfile +from ..util import Addr + + +CONFIG_DIRECTORIES = ( + join(os.environ['HOME'], '.config', 'homekit'), + '/etc/homekit' +) + +class RootSchemaType(Enum): + DEFAULT = auto() + DICT = auto() + LIST = auto() + + +class BaseConfigUnit(ABC): + _data: MutableMapping[str, Any] + _logger: logging.Logger + def __init__(self): + self._data = {} + self._logger = logging.getLogger(self.__class__.__name__) + + def __getitem__(self, key): + return self._data[key] + + def __setitem__(self, key, value): + raise NotImplementedError('overwriting config values is prohibited') -def _get_config_path(name: str) -> str: - formats = ['toml', 'yaml'] + def __contains__(self, key): + return key in self._data - dirname = join(os.environ['HOME'], '.config', name) + def load_from(self, path: str): + with open(path, 'r') as fd: + self._data = yaml.safe_load(fd) - if isdir(dirname): - for fmt in formats: - filename = join(dirname, f'config.{fmt}') - if isfile(filename): - return filename + def get(self, + key: Optional[str] = None, + default=None): + if key is None: + return self._data - raise IOError(f'config not found in {dirname}') + cur = self._data + pts = key.split('.') + for i in range(len(pts)): + k = pts[i] + if i < len(pts)-1: + if k not in cur: + raise KeyError(f'key {k} not found') + else: + return cur[k] if k in cur else default + cur = self._data[k] - else: - filenames = [join(os.environ['HOME'], '.config', f'{name}.{format}') for format in formats] - for file in filenames: - if isfile(file): - return file + raise KeyError(f'option {key} not found') - raise IOError(f'config not found') +class ConfigUnit(BaseConfigUnit): + NAME = 'dumb' + + def __init__(self, name=None, load=True): + super().__init__() + + self._data = {} + self._logger = logging.getLogger(self.__class__.__name__) + + if self.NAME != 'dumb' and load: + self.load_from(self.get_config_path()) + self.validate() + + elif name is not None: + self.NAME = name + + @classmethod + def get_config_path(cls, name=None) -> str: + if name is None: + name = cls.NAME + if name is None: + raise ValueError('get_config_path: name is none') + + for dirname in CONFIG_DIRECTORIES: + if isdir(dirname): + filename = join(dirname, f'{name}.yaml') + if isfile(filename): + return filename + + raise IOError(f'\'{name}.yaml\' not found') + + @staticmethod + def schema() -> Optional[dict]: + return None + + def validate(self): + schema = self.schema() + if not schema: + self._logger.warning('validate: no schema') + return + + if isinstance(self, AppConfigUnit): + schema['logging'] = { + 'type': 'dict', + 'schema': { + 'logging': {'type': 'bool'} + } + } + + rst = RootSchemaType.DEFAULT + try: + if schema['type'] == 'dict': + rst = RootSchemaType.DICT + elif schema['type'] == 'list': + rst = RootSchemaType.LIST + elif schema['roottype'] == 'dict': + del schema['roottype'] + rst = RootSchemaType.DICT + except KeyError: + pass + + if rst == RootSchemaType.DICT: + v = Validator({'document': { + 'type': 'dict', + 'keysrules': {'type': 'string'}, + 'valuesrules': schema + }}) + result = v.validate({'document': self._data}) + elif rst == RootSchemaType.LIST: + v = Validator({'document': schema}) + result = v.validate({'document': self._data}) + else: + v = Validator(schema) + result = v.validate(self._data) + # pprint.pprint(self._data) + if not result: + # pprint.pprint(v.errors) + raise DocumentError(f'{self.__class__.__name__}: failed to validate data:\n{pprint.pformat(v.errors)}') + try: + self.custom_validator(self._data) + except Exception as e: + raise DocumentError(f'{self.__class__.__name__}: {str(e)}') + + @staticmethod + def custom_validator(data): + pass -class ConfigStore: - data: MutableMapping[str, Any] + def get_addr(self, key: str): + return Addr.fromstring(self.get(key)) + + +class AppConfigUnit(ConfigUnit): + _logging_verbose: bool + _logging_fmt: Optional[str] + _logging_file: Optional[str] + + def __init__(self, *args, **kwargs): + super().__init__(load=False, *args, **kwargs) + self._logging_verbose = False + self._logging_fmt = None + self._logging_file = None + + def logging_set_fmt(self, fmt: str) -> None: + self._logging_fmt = fmt + + def logging_get_fmt(self) -> Optional[str]: + try: + return self['logging']['default_fmt'] + except KeyError: + return self._logging_fmt + + def logging_set_file(self, file: str) -> None: + self._logging_file = file + + def logging_get_file(self) -> Optional[str]: + try: + return self['logging']['file'] + except KeyError: + return self._logging_file + + def logging_set_verbose(self): + self._logging_verbose = True + + def logging_is_verbose(self) -> bool: + try: + return bool(self['logging']['verbose']) + except KeyError: + return self._logging_verbose + + +class TranslationUnit(BaseConfigUnit): + pass + + +class Translation: + LANGUAGES = ('en', 'ru') + _langs: dict[str, TranslationUnit] + + def __init__(self, name: str): + super().__init__() + self._langs = {} + for lang in self.LANGUAGES: + for dirname in CONFIG_DIRECTORIES: + if isdir(dirname): + filename = join(dirname, f'i18n-{lang}', f'{name}.yaml') + if lang in self._langs: + raise RuntimeError(f'{name}: translation unit for lang \'{lang}\' already loaded') + self._langs[lang] = TranslationUnit() + self._langs[lang].load_from(filename) + diff = set() + for data in self._langs.values(): + diff ^= data.get().keys() + if len(diff) > 0: + raise RuntimeError(f'{name}: translation units have difference in keys: ' + ', '.join(diff)) + + def get(self, lang: str) -> TranslationUnit: + return self._langs[lang] + + +class Config: app_name: Optional[str] + app_config: AppConfigUnit def __init__(self): - self.data = {} self.app_name = None + self.app_config = AppConfigUnit() + + def load_app(self, + name: Optional[Union[str, AppConfigUnit, bool]] = None, + use_cli=True, + parser: ArgumentParser = None, + no_config=False): + global app_config + + if issubclass(name, AppConfigUnit) or name == AppConfigUnit: + self.app_name = name.NAME + self.app_config = name() + app_config = self.app_config + else: + self.app_name = name if isinstance(name, str) else None - def load(self, name: Optional[str] = None, - use_cli=True, - parser: ArgumentParser = None): - self.app_name = name - - if (name is None) and (not use_cli): + if self.app_name is None and not use_cli: raise RuntimeError('either config name must be none or use_cli must be True') - log_default_fmt = False - log_file = None - log_verbose = False - no_config = name is False - + no_config = name is False or no_config path = None + if use_cli: if parser is None: parser = ArgumentParser() @@ -68,75 +266,38 @@ class ConfigStore: path = args.config if args.verbose: - log_verbose = True + self.app_config.logging_set_verbose() if args.log_file: - log_file = args.log_file + self.app_config.logging_set_file(args.log_file) if args.log_default_fmt: - log_default_fmt = args.log_default_fmt + self.app_config.logging_set_fmt(args.log_default_fmt) - if not no_config and path is None: - path = _get_config_path(name) + if not isinstance(name, ConfigUnit): + if not no_config and path is None: + path = ConfigUnit.get_config_path(name=self.app_name) - if no_config: - self.data = {} - else: - if path.endswith('.toml'): - self.data = toml.load(path) - elif path.endswith('.yaml'): - with open(path, 'r') as fd: - self.data = yaml.safe_load(fd) - - if 'logging' in self: - if not log_file and 'file' in self['logging']: - log_file = self['logging']['file'] - if log_default_fmt and 'default_fmt' in self['logging']: - log_default_fmt = self['logging']['default_fmt'] + if not no_config: + self.app_config.load_from(path) - setup_logging(log_verbose, log_file, log_default_fmt) + setup_logging(self.app_config.logging_is_verbose(), + self.app_config.logging_get_file(), + self.app_config.logging_get_fmt()) if use_cli: return args - def __getitem__(self, key): - return self.data[key] - - def __setitem__(self, key, value): - raise NotImplementedError('overwriting config values is prohibited') - - def __contains__(self, key): - return key in self.data - - def get(self, key: str, default=None): - cur = self.data - pts = key.split('.') - for i in range(len(pts)): - k = pts[i] - if i < len(pts)-1: - if k not in cur: - raise KeyError(f'key {k} not found') - else: - return cur[k] if k in cur else default - cur = self.data[k] - raise KeyError(f'option {key} not found') - - def get_addr(self, key: str): - return parse_addr(self.get(key)) - - def items(self): - return self.data.items() - -config = ConfigStore() +config = Config() def is_development_mode() -> bool: if 'HK_MODE' in os.environ and os.environ['HK_MODE'] == 'dev': return True - return ('logging' in config) and ('verbose' in config['logging']) and (config['logging']['verbose'] is True) + return ('logging' in config.app_config) and ('verbose' in config.app_config['logging']) and (config.app_config['logging']['verbose'] is True) -def setup_logging(verbose=False, log_file=None, default_fmt=False): +def setup_logging(verbose=False, log_file=None, default_fmt=None): logging_level = logging.INFO if is_development_mode() or verbose: logging_level = logging.DEBUG diff --git a/src/home/database/clickhouse.py b/src/home/database/clickhouse.py index ca81628..d0ec283 100644 --- a/src/home/database/clickhouse.py +++ b/src/home/database/clickhouse.py @@ -1,7 +1,7 @@ import logging from zoneinfo import ZoneInfo -from datetime import datetime, timedelta +from datetime import datetime from clickhouse_driver import Client as ClickhouseClient from ..config import is_development_mode diff --git a/src/home/database/sqlite.py b/src/home/database/sqlite.py index bfba929..8c6145c 100644 --- a/src/home/database/sqlite.py +++ b/src/home/database/sqlite.py @@ -5,24 +5,27 @@ import logging from ..config import config, is_development_mode -def _get_database_path(name: str, dbname: str) -> str: - return os.path.join(os.environ['HOME'], '.config', name, f'{dbname}.db') +def _get_database_path(name: str) -> str: + return os.path.join( + os.environ['HOME'], + '.config', + 'homekit', + 'data', + f'{name}.db') class SQLiteBase: SCHEMA = 1 - def __init__(self, name=None, dbname='bot', check_same_thread=False): - db_path = config.get('db_path', default=None) - if db_path is None: - if not name: - name = config.app_name - if not dbname: - dbname = name - db_path = _get_database_path(name, dbname) + def __init__(self, name=None, check_same_thread=False): + if name is None: + name = config.app_config['database_name'] + database_path = _get_database_path(name) + if not os.path.exists(os.path.dirname(database_path)): + os.makedirs(os.path.dirname(database_path)) self.logger = logging.getLogger(self.__class__.__name__) - self.sqlite = sqlite3.connect(db_path, check_same_thread=check_same_thread) + self.sqlite = sqlite3.connect(database_path, check_same_thread=check_same_thread) if is_development_mode(): self.sql_logger = logging.getLogger(self.__class__.__name__) diff --git a/src/home/inverter/config.py b/src/home/inverter/config.py new file mode 100644 index 0000000..62b8859 --- /dev/null +++ b/src/home/inverter/config.py @@ -0,0 +1,13 @@ +from ..config import ConfigUnit +from typing import Optional + + +class InverterdConfig(ConfigUnit): + NAME = 'inverterd' + + @staticmethod + def schema() -> Optional[dict]: + return { + 'remote_addr': {'type': 'string'}, + 'local_addr': {'type': 'string'}, + } \ No newline at end of file diff --git a/src/home/media/__init__.py b/src/home/media/__init__.py index 976c990..6923105 100644 --- a/src/home/media/__init__.py +++ b/src/home/media/__init__.py @@ -12,6 +12,7 @@ __map__ = { __all__ = list(itertools.chain(*__map__.values())) + def __getattr__(name): if name in __all__: for file, names in __map__.items(): diff --git a/src/home/mqtt/__init__.py b/src/home/mqtt/__init__.py index 982e2b6..707d59c 100644 --- a/src/home/mqtt/__init__.py +++ b/src/home/mqtt/__init__.py @@ -1,4 +1,7 @@ -from .mqtt import MqttBase -from .util import poll_tick -from .relay import MqttRelay, MqttRelayState -from .temphum import MqttTempHum \ No newline at end of file +from ._mqtt import Mqtt +from ._node import MqttNode +from ._module import MqttModule +from ._wrapper import MqttWrapper +from ._config import MqttConfig, MqttCreds, MqttNodesConfig +from ._payload import MqttPayload, MqttPayloadCustomField +from ._util import get_modules as get_mqtt_modules \ No newline at end of file diff --git a/src/home/mqtt/_config.py b/src/home/mqtt/_config.py new file mode 100644 index 0000000..f9047b4 --- /dev/null +++ b/src/home/mqtt/_config.py @@ -0,0 +1,165 @@ +from ..config import ConfigUnit +from typing import Optional, Union +from ..util import Addr +from collections import namedtuple + +MqttCreds = namedtuple('MqttCreds', 'username, password') + + +class MqttConfig(ConfigUnit): + NAME = 'mqtt' + + @staticmethod + def schema() -> Optional[dict]: + addr_schema = { + 'type': 'dict', + 'required': True, + 'schema': { + 'host': {'type': 'string', 'required': True}, + 'port': {'type': 'integer', 'required': True} + } + } + + schema = {} + for key in ('local', 'remote'): + schema[f'{key}_addr'] = addr_schema + + schema['creds'] = { + 'type': 'dict', + 'required': True, + 'keysrules': {'type': 'string'}, + 'valuesrules': { + 'type': 'dict', + 'schema': { + 'username': {'type': 'string', 'required': True}, + 'password': {'type': 'string', 'required': True}, + } + } + } + + for key in ('client', 'server'): + schema[f'default_{key}_creds'] = {'type': 'string', 'required': True} + + return schema + + def remote_addr(self) -> Addr: + return Addr(host=self['remote_addr']['host'], + port=self['remote_addr']['port']) + + def local_addr(self) -> Addr: + return Addr(host=self['local_addr']['host'], + port=self['local_addr']['port']) + + def creds_by_name(self, name: str) -> MqttCreds: + return MqttCreds(username=self['creds'][name]['username'], + password=self['creds'][name]['password']) + + def creds(self) -> MqttCreds: + return self.creds_by_name(self['default_client_creds']) + + def server_creds(self) -> MqttCreds: + return self.creds_by_name(self['default_server_creds']) + + +class MqttNodesConfig(ConfigUnit): + NAME = 'mqtt_nodes' + + @staticmethod + def schema() -> Optional[dict]: + return { + 'common': { + 'type': 'dict', + 'schema': { + 'temphum': { + 'type': 'dict', + 'schema': { + 'interval': {'type': 'integer'} + } + }, + 'password': {'type': 'string'} + } + }, + 'nodes': { + 'type': 'dict', + 'required': True, + 'keysrules': {'type': 'string'}, + 'valuesrules': { + 'type': 'dict', + 'schema': { + 'type': {'type': 'string', 'required': True, 'allowed': ['esp8266', 'linux', 'none'],}, + 'board': {'type': 'string', 'allowed': ['nodemcu', 'd1_mini_lite', 'esp12e']}, + 'temphum': { + 'type': 'dict', + 'schema': { + 'module': {'type': 'string', 'required': True, 'allowed': ['si7021', 'dht12']}, + 'interval': {'type': 'integer'}, + 'i2c_bus': {'type': 'integer'}, + 'tcpserver': { + 'type': 'dict', + 'schema': { + 'port': {'type': 'integer', 'required': True} + } + } + } + }, + 'relay': { + 'type': 'dict', + 'schema': { + 'device_type': {'type': 'string', 'allowed': ['lamp', 'pump', 'solenoid'], 'required': True}, + 'legacy_topics': {'type': 'boolean'} + } + }, + 'password': {'type': 'string'} + } + } + } + } + + @staticmethod + def custom_validator(data): + for name, node in data['nodes'].items(): + if 'temphum' in node: + if node['type'] == 'linux': + if 'i2c_bus' not in node['temphum']: + raise KeyError(f'nodes.{name}.temphum: i2c_bus is missing but required for type=linux') + if node['type'] in ('esp8266',) and 'board' not in node: + raise KeyError(f'nodes.{name}: board is missing but required for type={node["type"]}') + + def get_node(self, name: str) -> dict: + node = self['nodes'][name] + if node['type'] == 'none': + return node + + try: + if 'password' not in node: + node['password'] = self['common']['password'] + except KeyError: + pass + + try: + if 'temphum' in node: + for ckey, cval in self['common']['temphum'].items(): + if ckey not in node['temphum']: + node['temphum'][ckey] = cval + except KeyError: + pass + + return node + + def get_nodes(self, + filters: Optional[Union[list[str], tuple[str]]] = None, + only_names=False) -> Union[dict, list[str]]: + if filters: + for f in filters: + if f not in ('temphum', 'relay'): + raise ValueError(f'{self.__class__.__name__}::get_node(): invalid filter {f}') + reslist = [] + resdict = {} + for name in self['nodes'].keys(): + node = self.get_node(name) + if (not filters) or ('temphum' in filters and 'temphum' in node) or ('relay' in filters and 'relay' in node): + if only_names: + reslist.append(name) + else: + resdict[name] = node + return reslist if only_names else resdict diff --git a/src/home/mqtt/_module.py b/src/home/mqtt/_module.py new file mode 100644 index 0000000..80f27bb --- /dev/null +++ b/src/home/mqtt/_module.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import abc +import logging +import threading + +from time import sleep +from ..util import next_tick_gen + +from typing import TYPE_CHECKING, Optional +if TYPE_CHECKING: + from ._node import MqttNode + from ._payload import MqttPayload + + +class MqttModule(abc.ABC): + _tick_interval: int + _initialized: bool + _connected: bool + _ticker: Optional[threading.Thread] + _mqtt_node_ref: Optional[MqttNode] + + def __init__(self, tick_interval=0): + self._tick_interval = tick_interval + self._initialized = False + self._ticker = None + self._logger = logging.getLogger(self.__class__.__name__) + self._connected = False + self._mqtt_node_ref = None + + def on_connect(self, mqtt: MqttNode): + self._connected = True + self._mqtt_node_ref = mqtt + if self._tick_interval: + self._start_ticker() + + def on_disconnect(self, mqtt: MqttNode): + self._connected = False + self._mqtt_node_ref = None + + def is_initialized(self): + return self._initialized + + def set_initialized(self): + self._initialized = True + + def unset_initialized(self): + self._initialized = False + + def tick(self): + pass + + def _tick(self): + g = next_tick_gen(self._tick_interval) + while self._connected: + sleep(next(g)) + if not self._connected: + break + self.tick() + + def _start_ticker(self): + if not self._ticker or not self._ticker.is_alive(): + name_part = f'{self._mqtt_node_ref.id}/' if self._mqtt_node_ref else '' + self._ticker = None + self._ticker = threading.Thread(target=self._tick, + name=f'mqtt:{self.__class__.__name__}/{name_part}ticker') + self._ticker.start() + + def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: + pass diff --git a/src/home/mqtt/_mqtt.py b/src/home/mqtt/_mqtt.py new file mode 100644 index 0000000..746ae2e --- /dev/null +++ b/src/home/mqtt/_mqtt.py @@ -0,0 +1,86 @@ +import os.path +import paho.mqtt.client as mqtt +import ssl +import logging + +from ._config import MqttCreds, MqttConfig +from typing import Optional + + +class Mqtt: + _connected: bool + _is_server: bool + _mqtt_config: MqttConfig + + def __init__(self, + clean_session=True, + client_id='', + creds: Optional[MqttCreds] = None, + is_server=False): + if not client_id: + raise ValueError('client_id must not be empty') + + self._client = mqtt.Client(client_id=client_id, + protocol=mqtt.MQTTv311, + clean_session=clean_session) + self._client.on_connect = self.on_connect + self._client.on_disconnect = self.on_disconnect + self._client.on_message = self.on_message + self._client.on_log = self.on_log + self._client.on_publish = self.on_publish + self._loop_started = False + self._connected = False + self._is_server = is_server + self._mqtt_config = MqttConfig() + self._logger = logging.getLogger(self.__class__.__name__) + + if not creds: + creds = self._mqtt_config.creds() if not is_server else self._mqtt_config.server_creds() + + self._client.username_pw_set(creds.username, creds.password) + + def _configure_tls(self): + ca_certs = os.path.realpath(os.path.join( + os.path.dirname(os.path.realpath(__file__)), + '..', + '..', + '..', + 'assets', + 'mqtt_ca.crt' + )) + self._client.tls_set(ca_certs=ca_certs, + cert_reqs=ssl.CERT_REQUIRED, + tls_version=ssl.PROTOCOL_TLSv1_2) + + def connect_and_loop(self, loop_forever=True): + self._configure_tls() + addr = self._mqtt_config.local_addr() if self._is_server else self._mqtt_config.remote_addr() + self._client.connect(addr.host, addr.port, 60) + if loop_forever: + self._client.loop_forever() + else: + self._client.loop_start() + self._loop_started = True + + def disconnect(self): + self._client.disconnect() + self._client.loop_write() + self._client.loop_stop() + + def on_connect(self, client: mqtt.Client, userdata, flags, rc): + self._logger.info("Connected with result code " + str(rc)) + self._connected = True + + def on_disconnect(self, client: mqtt.Client, userdata, rc): + self._logger.info("Disconnected with result code " + str(rc)) + self._connected = False + + def on_log(self, client: mqtt.Client, userdata, level, buf): + level = mqtt.LOGGING_LEVEL[level] if level in mqtt.LOGGING_LEVEL else logging.INFO + self._logger.log(level, f'MQTT: {buf}') + + def on_message(self, client: mqtt.Client, userdata, msg): + self._logger.debug(msg.topic + ": " + str(msg.payload)) + + def on_publish(self, client: mqtt.Client, userdata, mid): + self._logger.debug(f'publish done, mid={mid}') diff --git a/src/home/mqtt/_node.py b/src/home/mqtt/_node.py new file mode 100644 index 0000000..4e259a4 --- /dev/null +++ b/src/home/mqtt/_node.py @@ -0,0 +1,92 @@ +import logging +import importlib + +from typing import List, TYPE_CHECKING, Optional +from ._payload import MqttPayload +from ._module import MqttModule +if TYPE_CHECKING: + from ._wrapper import MqttWrapper +else: + MqttWrapper = None + + +class MqttNode: + _modules: List[MqttModule] + _module_subscriptions: dict[str, MqttModule] + _node_id: str + _node_secret: str + _payload_callbacks: list[callable] + _wrapper: Optional[MqttWrapper] + + def __init__(self, + node_id: str, + node_secret: Optional[str] = None): + self._modules = [] + self._module_subscriptions = {} + self._node_id = node_id + self._node_secret = node_secret + self._payload_callbacks = [] + self._logger = logging.getLogger(self.__class__.__name__) + self._wrapper = None + + def on_connect(self, wrapper: MqttWrapper): + self._wrapper = wrapper + for module in self._modules: + if not module.is_initialized(): + module.on_connect(self) + module.set_initialized() + + def on_disconnect(self): + self._wrapper = None + for module in self._modules: + module.unset_initialized() + + def on_message(self, topic, payload): + if topic in self._module_subscriptions: + payload = self._module_subscriptions[topic].handle_payload(self, topic, payload) + if isinstance(payload, MqttPayload): + for f in self._payload_callbacks: + f(self, payload) + + def load_module(self, module_name: str, *args, **kwargs) -> MqttModule: + module = importlib.import_module(f'..module.{module_name}', __name__) + if not hasattr(module, 'MODULE_NAME'): + raise RuntimeError(f'MODULE_NAME not found in module {module}') + cl = getattr(module, getattr(module, 'MODULE_NAME')) + instance = cl(*args, **kwargs) + self.add_module(instance) + return instance + + def add_module(self, module: MqttModule): + self._modules.append(module) + if self._wrapper and self._wrapper._connected: + module.on_connect(self) + module.set_initialized() + + def subscribe_module(self, topic: str, module: MqttModule, qos: int = 1): + if not self._wrapper or not self._wrapper._connected: + raise RuntimeError('not connected') + + self._module_subscriptions[topic] = module + self._wrapper.subscribe(self.id, topic, qos) + + def publish(self, + topic: str, + payload: bytes, + qos: int = 1): + self._wrapper.publish(self.id, topic, payload, qos) + + def add_payload_callback(self, callback: callable): + self._payload_callbacks.append(callback) + + @property + def id(self) -> str: + return self._node_id + + @property + def secret(self) -> str: + return self._node_secret + + @secret.setter + def secret(self, secret: str) -> None: + self._node_secret = secret diff --git a/src/home/mqtt/_payload.py b/src/home/mqtt/_payload.py new file mode 100644 index 0000000..58eeae3 --- /dev/null +++ b/src/home/mqtt/_payload.py @@ -0,0 +1,145 @@ +import struct +import abc +import re + +from typing import Optional, Tuple + + +def pldstr(self) -> str: + attrs = [] + for field in self.__class__.__annotations__: + if hasattr(self, field): + attr = getattr(self, field) + attrs.append(f'{field}={attr}') + if attrs: + attrs_s = ' ' + attrs_s += ', '.join(attrs) + else: + attrs_s = '' + return f'<%s{attrs_s}>' % (self.__class__.__name__,) + + +class MqttPayload(abc.ABC): + FORMAT = '' + PACKER = {} + UNPACKER = {} + + def __init__(self, **kwargs): + for field in self.__class__.__annotations__: + setattr(self, field, kwargs[field]) + + def pack(self): + args = [] + bf_number = -1 + bf_arg = 0 + bf_progress = 0 + + for field, field_type in self.__class__.__annotations__.items(): + bfp = _bit_field_params(field_type) + if bfp: + n, s, b = bfp + if n != bf_number: + if bf_number != -1: + args.append(bf_arg) + bf_number = n + bf_progress = 0 + bf_arg = 0 + bf_arg |= (getattr(self, field) & (2 ** b - 1)) << bf_progress + bf_progress += b + + else: + if bf_number != -1: + args.append(bf_arg) + bf_number = -1 + bf_progress = 0 + bf_arg = 0 + + args.append(self._pack_field(field)) + + if bf_number != -1: + args.append(bf_arg) + + return struct.pack(self.FORMAT, *args) + + @classmethod + def unpack(cls, buf: bytes): + data = struct.unpack(cls.FORMAT, buf) + kwargs = {} + i = 0 + bf_number = -1 + bf_progress = 0 + + for field, field_type in cls.__annotations__.items(): + bfp = _bit_field_params(field_type) + if bfp: + n, s, b = bfp + if n != bf_number: + bf_number = n + bf_progress = 0 + kwargs[field] = (data[i] >> bf_progress) & (2 ** b - 1) + bf_progress += b + continue # don't increment i + + if bf_number != -1: + bf_number = -1 + i += 1 + + if issubclass(field_type, MqttPayloadCustomField): + kwargs[field] = field_type.unpack(data[i]) + else: + kwargs[field] = cls._unpack_field(field, data[i]) + i += 1 + + return cls(**kwargs) + + def _pack_field(self, name): + val = getattr(self, name) + if self.PACKER and name in self.PACKER: + return self.PACKER[name](val) + else: + return val + + @classmethod + def _unpack_field(cls, name, val): + if isinstance(val, MqttPayloadCustomField): + return + if cls.UNPACKER and name in cls.UNPACKER: + return cls.UNPACKER[name](val) + else: + return val + + def __str__(self): + return pldstr(self) + + +class MqttPayloadCustomField(abc.ABC): + def __init__(self, **kwargs): + for field in self.__class__.__annotations__: + setattr(self, field, kwargs[field]) + + @abc.abstractmethod + def __index__(self): + pass + + @classmethod + @abc.abstractmethod + def unpack(cls, *args, **kwargs): + pass + + def __str__(self): + return pldstr(self) + + +def bit_field(seq_no: int, total_bits: int, bits: int): + return type(f'MQTTPayloadBitField_{seq_no}_{total_bits}_{bits}', (object,), { + 'seq_no': seq_no, + 'total_bits': total_bits, + 'bits': bits + }) + + +def _bit_field_params(cl) -> Optional[Tuple[int, ...]]: + match = re.match(r'MQTTPayloadBitField_(\d+)_(\d+)_(\d)$', cl.__name__) + if match is not None: + return tuple([int(match.group(i)) for i in range(1, 4)]) + return None \ No newline at end of file diff --git a/src/home/mqtt/_util.py b/src/home/mqtt/_util.py new file mode 100644 index 0000000..390d463 --- /dev/null +++ b/src/home/mqtt/_util.py @@ -0,0 +1,15 @@ +import os +import re + +from typing import List + + +def get_modules() -> List[str]: + modules = [] + modules_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'module') + for name in os.listdir(modules_dir): + if os.path.isdir(os.path.join(modules_dir, name)): + continue + name = re.sub(r'\.py$', '', name) + modules.append(name) + return modules diff --git a/src/home/mqtt/_wrapper.py b/src/home/mqtt/_wrapper.py new file mode 100644 index 0000000..f858f88 --- /dev/null +++ b/src/home/mqtt/_wrapper.py @@ -0,0 +1,59 @@ +import paho.mqtt.client as mqtt + +from ._mqtt import Mqtt +from ._node import MqttNode +from ..config import config +from ..util import strgen + + +class MqttWrapper(Mqtt): + _nodes: list[MqttNode] + + def __init__(self, + client_id: str, + topic_prefix='hk', + randomize_client_id=False, + clean_session=True): + if randomize_client_id: + client_id += '_'+strgen(6) + super().__init__(clean_session=clean_session, + client_id=client_id) + self._nodes = [] + self._topic_prefix = topic_prefix + + def on_connect(self, client: mqtt.Client, userdata, flags, rc): + super().on_connect(client, userdata, flags, rc) + for node in self._nodes: + node.on_connect(self) + + def on_disconnect(self, client: mqtt.Client, userdata, rc): + super().on_disconnect(client, userdata, rc) + for node in self._nodes: + node.on_disconnect() + + def on_message(self, client: mqtt.Client, userdata, msg): + try: + topic = msg.topic + for node in self._nodes: + node.on_message(topic[len(f'{self._topic_prefix}/{node.id}/'):], msg.payload) + except Exception as e: + self._logger.exception(str(e)) + + def add_node(self, node: MqttNode): + self._nodes.append(node) + if self._connected: + node.on_connect(self) + + def subscribe(self, + node_id: str, + topic: str, + qos: int): + self._client.subscribe(f'{self._topic_prefix}/{node_id}/{topic}', qos) + + def publish(self, + node_id: str, + topic: str, + payload: bytes, + qos: int): + self._client.publish(f'{self._topic_prefix}/{node_id}/{topic}', payload, qos) + self._client.loop_write() diff --git a/src/home/mqtt/esp.py b/src/home/mqtt/esp.py deleted file mode 100644 index 56ced83..0000000 --- a/src/home/mqtt/esp.py +++ /dev/null @@ -1,106 +0,0 @@ -import re -import paho.mqtt.client as mqtt - -from .mqtt import MqttBase -from typing import Optional, Union -from .payload.esp import ( - OTAPayload, - OTAResultPayload, - DiagnosticsPayload, - InitialDiagnosticsPayload -) - - -class MqttEspDevice: - id: str - secret: Optional[str] - - def __init__(self, id: str, secret: Optional[str] = None): - self.id = id - self.secret = secret - - -class MqttEspBase(MqttBase): - _devices: list[MqttEspDevice] - _message_callback: Optional[callable] - _ota_publish_callback: Optional[callable] - - TOPIC_LEAF = 'esp' - - def __init__(self, - devices: Union[MqttEspDevice, list[MqttEspDevice]], - subscribe_to_updates=True): - super().__init__(clean_session=True) - if not isinstance(devices, list): - devices = [devices] - self._devices = devices - self._message_callback = None - self._ota_publish_callback = None - self._subscribe_to_updates = subscribe_to_updates - self._ota_mid = None - - def on_connect(self, client: mqtt.Client, userdata, flags, rc): - super().on_connect(client, userdata, flags, rc) - - if self._subscribe_to_updates: - for device in self._devices: - topic = f'hk/{device.id}/{self.TOPIC_LEAF}/#' - self._logger.debug(f"subscribing to {topic}") - client.subscribe(topic, qos=1) - - def on_publish(self, client: mqtt.Client, userdata, mid): - if self._ota_mid is not None and mid == self._ota_mid and self._ota_publish_callback: - self._ota_publish_callback() - - def set_message_callback(self, callback: callable): - self._message_callback = callback - - def on_message(self, client: mqtt.Client, userdata, msg): - try: - match = re.match(self.get_mqtt_topics(), msg.topic) - self._logger.debug(f'topic: {msg.topic}') - if not match: - return - - device_id = match.group(1) - subtopic = match.group(2) - - # try: - next(d for d in self._devices if d.id == device_id) - # except StopIteration:h - # return - - message = None - if subtopic == 'stat': - message = DiagnosticsPayload.unpack(msg.payload) - elif subtopic == 'stat1': - message = InitialDiagnosticsPayload.unpack(msg.payload) - elif subtopic == 'otares': - message = OTAResultPayload.unpack(msg.payload) - - if message and self._message_callback: - self._message_callback(device_id, message) - return True - - except Exception as e: - self._logger.exception(str(e)) - - def push_ota(self, - device_id, - filename: str, - publish_callback: callable, - qos: int): - device = next(d for d in self._devices if d.id == device_id) - assert device.secret is not None, 'device secret not specified' - - self._ota_publish_callback = publish_callback - payload = OTAPayload(secret=device.secret, filename=filename) - publish_result = self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/admin/ota', - payload=payload.pack(), - qos=qos) - self._ota_mid = publish_result.mid - self._client.loop_write() - - @classmethod - def get_mqtt_topics(cls, additional_topics: Optional[list[str]] = None): - return rf'^hk/(.*?)/{cls.TOPIC_LEAF}/(stat|stat1|otares'+('|'+('|'.join(additional_topics)) if additional_topics else '')+')$' \ No newline at end of file diff --git a/src/home/mqtt/module/diagnostics.py b/src/home/mqtt/module/diagnostics.py new file mode 100644 index 0000000..5db5e99 --- /dev/null +++ b/src/home/mqtt/module/diagnostics.py @@ -0,0 +1,64 @@ +from .._payload import MqttPayload, MqttPayloadCustomField +from .._node import MqttNode, MqttModule +from typing import Optional + +MODULE_NAME = 'MqttDiagnosticsModule' + + +class DiagnosticsFlags(MqttPayloadCustomField): + state: bool + config_changed_value_present: bool + config_changed: bool + + @staticmethod + def unpack(flags: int): + # _logger.debug(f'StatFlags.unpack: flags={flags}') + state = flags & 0x1 + ccvp = (flags >> 1) & 0x1 + cc = (flags >> 2) & 0x1 + # _logger.debug(f'StatFlags.unpack: state={state}') + return DiagnosticsFlags(state=(state == 1), + config_changed_value_present=(ccvp == 1), + config_changed=(cc == 1)) + + def __index__(self): + bits = 0 + bits |= (int(self.state) & 0x1) + bits |= (int(self.config_changed_value_present) & 0x1) << 1 + bits |= (int(self.config_changed) & 0x1) << 2 + return bits + + +class InitialDiagnosticsPayload(MqttPayload): + FORMAT = '=IBbIB' + + ip: int + fw_version: int + rssi: int + free_heap: int + flags: DiagnosticsFlags + + +class DiagnosticsPayload(MqttPayload): + FORMAT = '=bIB' + + rssi: int + free_heap: int + flags: DiagnosticsFlags + + +class MqttDiagnosticsModule(MqttModule): + def on_connect(self, mqtt: MqttNode): + super().on_connect(mqtt) + for topic in ('diag', 'd1ag', 'stat', 'stat1'): + mqtt.subscribe_module(topic, self) + + def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: + message = None + if topic in ('stat', 'diag'): + message = DiagnosticsPayload.unpack(payload) + elif topic in ('stat1', 'd1ag'): + message = InitialDiagnosticsPayload.unpack(payload) + if message: + self._logger.debug(message) + return message diff --git a/src/home/mqtt/module/inverter.py b/src/home/mqtt/module/inverter.py new file mode 100644 index 0000000..d927a06 --- /dev/null +++ b/src/home/mqtt/module/inverter.py @@ -0,0 +1,195 @@ +import time +import json +import datetime +try: + import inverterd +except: + pass + +from typing import Optional +from .._module import MqttModule +from .._node import MqttNode +from .._payload import MqttPayload, bit_field +try: + from home.database import InverterDatabase +except: + pass + +_mult_10 = lambda n: int(n*10) +_div_10 = lambda n: n/10 + + +MODULE_NAME = 'MqttInverterModule' + +STATUS_TOPIC = 'status' +GENERATION_TOPIC = 'generation' + + +class MqttInverterStatusPayload(MqttPayload): + # 46 bytes + FORMAT = 'IHHHHHHBHHHHHBHHHHHHHH' + + PACKER = { + 'grid_voltage': _mult_10, + 'grid_freq': _mult_10, + 'ac_output_voltage': _mult_10, + 'ac_output_freq': _mult_10, + 'battery_voltage': _mult_10, + 'battery_voltage_scc': _mult_10, + 'battery_voltage_scc2': _mult_10, + 'pv1_input_voltage': _mult_10, + 'pv2_input_voltage': _mult_10 + } + UNPACKER = { + 'grid_voltage': _div_10, + 'grid_freq': _div_10, + 'ac_output_voltage': _div_10, + 'ac_output_freq': _div_10, + 'battery_voltage': _div_10, + 'battery_voltage_scc': _div_10, + 'battery_voltage_scc2': _div_10, + 'pv1_input_voltage': _div_10, + 'pv2_input_voltage': _div_10 + } + + time: int + grid_voltage: float + grid_freq: float + ac_output_voltage: float + ac_output_freq: float + ac_output_apparent_power: int + ac_output_active_power: int + output_load_percent: int + battery_voltage: float + battery_voltage_scc: float + battery_voltage_scc2: float + battery_discharge_current: int + battery_charge_current: int + battery_capacity: int + inverter_heat_sink_temp: int + mppt1_charger_temp: int + mppt2_charger_temp: int + pv1_input_power: int + pv2_input_power: int + pv1_input_voltage: float + pv2_input_voltage: float + + # H + mppt1_charger_status: bit_field(0, 16, 2) + mppt2_charger_status: bit_field(0, 16, 2) + battery_power_direction: bit_field(0, 16, 2) + dc_ac_power_direction: bit_field(0, 16, 2) + line_power_direction: bit_field(0, 16, 2) + load_connected: bit_field(0, 16, 1) + + +class MqttInverterGenerationPayload(MqttPayload): + # 8 bytes + FORMAT = 'II' + + time: int + wh: int + + +class MqttInverterModule(MqttModule): + _status_poll_freq: int + _generation_poll_freq: int + _inverter: Optional[inverterd.Client] + _database: Optional[InverterDatabase] + _gen_prev: float + + def __init__(self, status_poll_freq=0, generation_poll_freq=0): + super().__init__(tick_interval=status_poll_freq) + self._status_poll_freq = status_poll_freq + self._generation_poll_freq = generation_poll_freq + + # this defines whether this is a publisher or a subscriber + if status_poll_freq > 0: + self._inverter = inverterd.Client() + self._inverter.connect() + self._inverter.format(inverterd.Format.SIMPLE_JSON) + self._database = None + else: + self._inverter = None + self._database = InverterDatabase() + + self._gen_prev = 0 + + def on_connect(self, mqtt: MqttNode): + super().on_connect(mqtt) + if not self._inverter: + mqtt.subscribe_module(STATUS_TOPIC, self) + mqtt.subscribe_module(GENERATION_TOPIC, self) + + def tick(self): + if not self._inverter: + return + + # read status + now = time.time() + try: + raw = self._inverter.exec('get-status') + except inverterd.InverterError as e: + self._logger.error(f'inverter error: {str(e)}') + # TODO send to server + return + + data = json.loads(raw)['data'] + status = MqttInverterStatusPayload(time=round(now), **data) + self._mqtt_node_ref.publish(STATUS_TOPIC, status.pack()) + + # read today's generation stat + now = time.time() + if self._gen_prev == 0 or now - self._gen_prev >= self._generation_poll_freq: + self._gen_prev = now + today = datetime.date.today() + try: + raw = self._inverter.exec('get-day-generated', (today.year, today.month, today.day)) + except inverterd.InverterError as e: + self._logger.error(f'inverter error: {str(e)}') + # TODO send to server + return + + data = json.loads(raw)['data'] + gen = MqttInverterGenerationPayload(time=round(now), wh=data['wh']) + self._mqtt_node_ref.publish(GENERATION_TOPIC, gen.pack()) + + def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: + home_id = 1 # legacy compat + + if topic == STATUS_TOPIC: + s = MqttInverterStatusPayload.unpack(payload) + self._database.add_status(home_id=home_id, + client_time=s.time, + grid_voltage=int(s.grid_voltage*10), + grid_freq=int(s.grid_freq * 10), + ac_output_voltage=int(s.ac_output_voltage * 10), + ac_output_freq=int(s.ac_output_freq * 10), + ac_output_apparent_power=s.ac_output_apparent_power, + ac_output_active_power=s.ac_output_active_power, + output_load_percent=s.output_load_percent, + battery_voltage=int(s.battery_voltage * 10), + battery_voltage_scc=int(s.battery_voltage_scc * 10), + battery_voltage_scc2=int(s.battery_voltage_scc2 * 10), + battery_discharge_current=s.battery_discharge_current, + battery_charge_current=s.battery_charge_current, + battery_capacity=s.battery_capacity, + inverter_heat_sink_temp=s.inverter_heat_sink_temp, + mppt1_charger_temp=s.mppt1_charger_temp, + mppt2_charger_temp=s.mppt2_charger_temp, + pv1_input_power=s.pv1_input_power, + pv2_input_power=s.pv2_input_power, + pv1_input_voltage=int(s.pv1_input_voltage * 10), + pv2_input_voltage=int(s.pv2_input_voltage * 10), + mppt1_charger_status=s.mppt1_charger_status, + mppt2_charger_status=s.mppt2_charger_status, + battery_power_direction=s.battery_power_direction, + dc_ac_power_direction=s.dc_ac_power_direction, + line_power_direction=s.line_power_direction, + load_connected=s.load_connected) + return s + + elif topic == GENERATION_TOPIC: + gen = MqttInverterGenerationPayload.unpack(payload) + self._database.add_generation(home_id, gen.time, gen.wh) + return gen diff --git a/src/home/mqtt/module/ota.py b/src/home/mqtt/module/ota.py new file mode 100644 index 0000000..cd34332 --- /dev/null +++ b/src/home/mqtt/module/ota.py @@ -0,0 +1,77 @@ +import hashlib + +from typing import Optional +from .._payload import MqttPayload +from .._node import MqttModule, MqttNode + +MODULE_NAME = 'MqttOtaModule' + + +class OtaResultPayload(MqttPayload): + FORMAT = '=BB' + result: int + error_code: int + + +class OtaPayload(MqttPayload): + secret: str + filename: str + + # structure of returned data: + # + # uint8_t[len(secret)] secret; + # uint8_t[16] md5; + # *uint8_t data + + def pack(self): + buf = bytearray(self.secret.encode()) + m = hashlib.md5() + with open(self.filename, 'rb') as fd: + content = fd.read() + m.update(content) + buf.extend(m.digest()) + buf.extend(content) + return buf + + def unpack(cls, buf: bytes): + raise RuntimeError(f'{cls.__class__.__name__}.unpack: not implemented') + # secret = buf[:12].decode() + # filename = buf[12:].decode() + # return OTAPayload(secret=secret, filename=filename) + + +class MqttOtaModule(MqttModule): + _ota_request: Optional[tuple[str, int]] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._ota_request = None + + def on_connect(self, mqtt: MqttNode): + super().on_connect(mqtt) + mqtt.subscribe_module("otares", self) + + if self._ota_request is not None: + filename, qos = self._ota_request + self._ota_request = None + self.do_push_ota(self._mqtt_node_ref.secret, filename, qos) + + def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: + if topic == 'otares': + message = OtaResultPayload.unpack(payload) + self._logger.debug(message) + return message + + def do_push_ota(self, secret: str, filename: str, qos: int): + payload = OtaPayload(secret=secret, filename=filename) + self._mqtt_node_ref.publish('ota', + payload=payload.pack(), + qos=qos) + + def push_ota(self, + filename: str, + qos: int): + if not self._initialized: + self._ota_request = (filename, qos) + else: + self.do_push_ota(filename, qos) diff --git a/src/home/mqtt/module/relay.py b/src/home/mqtt/module/relay.py new file mode 100644 index 0000000..e968031 --- /dev/null +++ b/src/home/mqtt/module/relay.py @@ -0,0 +1,92 @@ +import datetime + +from typing import Optional +from .. import MqttModule, MqttPayload, MqttNode + +MODULE_NAME = 'MqttRelayModule' + + +class MqttPowerSwitchPayload(MqttPayload): + FORMAT = '=12sB' + PACKER = { + 'state': lambda n: int(n), + 'secret': lambda s: s.encode('utf-8') + } + UNPACKER = { + 'state': lambda n: bool(n), + 'secret': lambda s: s.decode('utf-8') + } + + secret: str + state: bool + + +class MqttPowerStatusPayload(MqttPayload): + FORMAT = '=B' + PACKER = { + 'opened': lambda n: int(n), + } + UNPACKER = { + 'opened': lambda n: bool(n), + } + + opened: bool + + +class MqttRelayState: + enabled: bool + update_time: datetime.datetime + rssi: int + fw_version: int + ever_updated: bool + + def __init__(self): + self.ever_updated = False + self.enabled = False + self.rssi = 0 + + def update(self, + enabled: bool, + rssi: int, + fw_version=None): + self.ever_updated = True + self.enabled = enabled + self.rssi = rssi + self.update_time = datetime.datetime.now() + if fw_version: + self.fw_version = fw_version + + +class MqttRelayModule(MqttModule): + _legacy_topics: bool + + def __init__(self, legacy_topics=False, *args, **kwargs): + super().__init__(*args, **kwargs) + self._legacy_topics = legacy_topics + + def on_connect(self, mqtt: MqttNode): + super().on_connect(mqtt) + mqtt.subscribe_module(self._get_switch_topic(), self) + mqtt.subscribe_module('relay/status', self) + + def switchpower(self, + enable: bool): + payload = MqttPowerSwitchPayload(secret=self._mqtt_node_ref.secret, + state=enable) + self._mqtt_node_ref.publish(self._get_switch_topic(), + payload=payload.pack()) + + def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: + message = None + + if topic == self._get_switch_topic(): + message = MqttPowerSwitchPayload.unpack(payload) + elif topic == 'relay/status': + message = MqttPowerStatusPayload.unpack(payload) + + if message is not None: + self._logger.debug(message) + return message + + def _get_switch_topic(self) -> str: + return 'relay/power' if self._legacy_topics else 'relay/switch' diff --git a/src/home/mqtt/module/temphum.py b/src/home/mqtt/module/temphum.py new file mode 100644 index 0000000..fd02cca --- /dev/null +++ b/src/home/mqtt/module/temphum.py @@ -0,0 +1,82 @@ +from .._node import MqttNode +from .._module import MqttModule +from .._payload import MqttPayload +from typing import Optional +from ...temphum import BaseSensor + +two_digits_precision = lambda x: round(x, 2) + +MODULE_NAME = 'MqttTempHumModule' +DATA_TOPIC = 'temphum/data' + + +class MqttTemphumDataPayload(MqttPayload): + FORMAT = '=ddb' + UNPACKER = { + 'temp': two_digits_precision, + 'rh': two_digits_precision + } + + temp: float + rh: float + error: int + + +# class MqttTempHumNodes(HashableEnum): +# KBN_SH_HALL = auto() +# KBN_SH_BATHROOM = auto() +# KBN_SH_LIVINGROOM = auto() +# KBN_SH_BEDROOM = auto() +# +# KBN_BH_2FL = auto() +# KBN_BH_2FL_STREET = auto() +# KBN_BH_1FL_LIVINGROOM = auto() +# KBN_BH_1FL_BEDROOM = auto() +# KBN_BH_1FL_BATHROOM = auto() +# +# KBN_NH_1FL_INV = auto() +# KBN_NH_1FL_CENTER = auto() +# KBN_NH_1LF_KT = auto() +# KBN_NH_1FL_DS = auto() +# KBN_NH_1FS_EZ = auto() +# +# SPB_FLAT120_CABINET = auto() + + +class MqttTempHumModule(MqttModule): + def __init__(self, + sensor: Optional[BaseSensor] = None, + write_to_database=False, + *args, **kwargs): + if sensor is not None: + kwargs['tick_interval'] = 10 + super().__init__(*args, **kwargs) + self._sensor = sensor + + def on_connect(self, mqtt: MqttNode): + super().on_connect(mqtt) + mqtt.subscribe_module(DATA_TOPIC, self) + + def tick(self): + if not self._sensor: + return + + error = 0 + temp = 0 + rh = 0 + try: + temp = self._sensor.temperature() + rh = self._sensor.humidity() + except: + error = 1 + pld = MqttTemphumDataPayload(temp=temp, rh=rh, error=error) + self._mqtt_node_ref.publish(DATA_TOPIC, pld.pack()) + + def handle_payload(self, + mqtt: MqttNode, + topic: str, + payload: bytes) -> Optional[MqttPayload]: + if topic == DATA_TOPIC: + message = MqttTemphumDataPayload.unpack(payload) + self._logger.debug(message) + return message diff --git a/src/home/mqtt/mqtt.py b/src/home/mqtt/mqtt.py deleted file mode 100644 index 4acd4f6..0000000 --- a/src/home/mqtt/mqtt.py +++ /dev/null @@ -1,76 +0,0 @@ -import os.path -import paho.mqtt.client as mqtt -import ssl -import logging - -from typing import Tuple -from ..config import config - - -def username_and_password() -> Tuple[str, str]: - username = config['mqtt']['username'] if 'username' in config['mqtt'] else None - password = config['mqtt']['password'] if 'password' in config['mqtt'] else None - return username, password - - -class MqttBase: - def __init__(self, clean_session=True): - self._client = mqtt.Client(client_id=config['mqtt']['client_id'], - protocol=mqtt.MQTTv311, - clean_session=clean_session) - self._client.on_connect = self.on_connect - self._client.on_disconnect = self.on_disconnect - self._client.on_message = self.on_message - self._client.on_log = self.on_log - self._client.on_publish = self.on_publish - self._loop_started = False - - self._logger = logging.getLogger(self.__class__.__name__) - - username, password = username_and_password() - if username and password: - self._logger.debug(f'username={username} password={password}') - self._client.username_pw_set(username, password) - - def configure_tls(self): - ca_certs = os.path.realpath(os.path.join( - os.path.dirname(os.path.realpath(__file__)), - '..', - '..', - '..', - 'assets', - 'mqtt_ca.crt' - )) - self._client.tls_set(ca_certs=ca_certs, cert_reqs=ssl.CERT_REQUIRED, tls_version=ssl.PROTOCOL_TLSv1_2) - - def connect_and_loop(self, loop_forever=True): - host = config['mqtt']['host'] - port = config['mqtt']['port'] - - self._client.connect(host, port, 60) - if loop_forever: - self._client.loop_forever() - else: - self._client.loop_start() - self._loop_started = True - - def disconnect(self): - self._client.disconnect() - self._client.loop_write() - self._client.loop_stop() - - def on_connect(self, client: mqtt.Client, userdata, flags, rc): - self._logger.info("Connected with result code " + str(rc)) - - def on_disconnect(self, client: mqtt.Client, userdata, rc): - self._logger.info("Disconnected with result code " + str(rc)) - - def on_log(self, client: mqtt.Client, userdata, level, buf): - level = mqtt.LOGGING_LEVEL[level] if level in mqtt.LOGGING_LEVEL else logging.INFO - self._logger.log(level, f'MQTT: {buf}') - - def on_message(self, client: mqtt.Client, userdata, msg): - self._logger.debug(msg.topic + ": " + str(msg.payload)) - - def on_publish(self, client: mqtt.Client, userdata, mid): - self._logger.debug(f'publish done, mid={mid}') \ No newline at end of file diff --git a/src/home/mqtt/payload/__init__.py b/src/home/mqtt/payload/__init__.py deleted file mode 100644 index eee6709..0000000 --- a/src/home/mqtt/payload/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .base_payload import MqttPayload \ No newline at end of file diff --git a/src/home/mqtt/payload/base_payload.py b/src/home/mqtt/payload/base_payload.py deleted file mode 100644 index 1abd898..0000000 --- a/src/home/mqtt/payload/base_payload.py +++ /dev/null @@ -1,145 +0,0 @@ -import abc -import struct -import re - -from typing import Optional, Tuple - - -def pldstr(self) -> str: - attrs = [] - for field in self.__class__.__annotations__: - if hasattr(self, field): - attr = getattr(self, field) - attrs.append(f'{field}={attr}') - if attrs: - attrs_s = ' ' - attrs_s += ', '.join(attrs) - else: - attrs_s = '' - return f'<%s{attrs_s}>' % (self.__class__.__name__,) - - -class MqttPayload(abc.ABC): - FORMAT = '' - PACKER = {} - UNPACKER = {} - - def __init__(self, **kwargs): - for field in self.__class__.__annotations__: - setattr(self, field, kwargs[field]) - - def pack(self): - args = [] - bf_number = -1 - bf_arg = 0 - bf_progress = 0 - - for field, field_type in self.__class__.__annotations__.items(): - bfp = _bit_field_params(field_type) - if bfp: - n, s, b = bfp - if n != bf_number: - if bf_number != -1: - args.append(bf_arg) - bf_number = n - bf_progress = 0 - bf_arg = 0 - bf_arg |= (getattr(self, field) & (2 ** b - 1)) << bf_progress - bf_progress += b - - else: - if bf_number != -1: - args.append(bf_arg) - bf_number = -1 - bf_progress = 0 - bf_arg = 0 - - args.append(self._pack_field(field)) - - if bf_number != -1: - args.append(bf_arg) - - return struct.pack(self.FORMAT, *args) - - @classmethod - def unpack(cls, buf: bytes): - data = struct.unpack(cls.FORMAT, buf) - kwargs = {} - i = 0 - bf_number = -1 - bf_progress = 0 - - for field, field_type in cls.__annotations__.items(): - bfp = _bit_field_params(field_type) - if bfp: - n, s, b = bfp - if n != bf_number: - bf_number = n - bf_progress = 0 - kwargs[field] = (data[i] >> bf_progress) & (2 ** b - 1) - bf_progress += b - continue # don't increment i - - if bf_number != -1: - bf_number = -1 - i += 1 - - if issubclass(field_type, MqttPayloadCustomField): - kwargs[field] = field_type.unpack(data[i]) - else: - kwargs[field] = cls._unpack_field(field, data[i]) - i += 1 - - return cls(**kwargs) - - def _pack_field(self, name): - val = getattr(self, name) - if self.PACKER and name in self.PACKER: - return self.PACKER[name](val) - else: - return val - - @classmethod - def _unpack_field(cls, name, val): - if isinstance(val, MqttPayloadCustomField): - return - if cls.UNPACKER and name in cls.UNPACKER: - return cls.UNPACKER[name](val) - else: - return val - - def __str__(self): - return pldstr(self) - - -class MqttPayloadCustomField(abc.ABC): - def __init__(self, **kwargs): - for field in self.__class__.__annotations__: - setattr(self, field, kwargs[field]) - - @abc.abstractmethod - def __index__(self): - pass - - @classmethod - @abc.abstractmethod - def unpack(cls, *args, **kwargs): - pass - - def __str__(self): - return pldstr(self) - - -def bit_field(seq_no: int, total_bits: int, bits: int): - return type(f'MQTTPayloadBitField_{seq_no}_{total_bits}_{bits}', (object,), { - 'seq_no': seq_no, - 'total_bits': total_bits, - 'bits': bits - }) - - -def _bit_field_params(cl) -> Optional[Tuple[int, ...]]: - match = re.match(r'MQTTPayloadBitField_(\d+)_(\d+)_(\d)$', cl.__name__) - if match is not None: - return tuple([int(match.group(i)) for i in range(1, 4)]) - return None diff --git a/src/home/mqtt/payload/esp.py b/src/home/mqtt/payload/esp.py deleted file mode 100644 index 171cdb9..0000000 --- a/src/home/mqtt/payload/esp.py +++ /dev/null @@ -1,78 +0,0 @@ -import hashlib - -from .base_payload import MqttPayload, MqttPayloadCustomField - - -class OTAResultPayload(MqttPayload): - FORMAT = '=BB' - result: int - error_code: int - - -class OTAPayload(MqttPayload): - secret: str - filename: str - - # structure of returned data: - # - # uint8_t[len(secret)] secret; - # uint8_t[16] md5; - # *uint8_t data - - def pack(self): - buf = bytearray(self.secret.encode()) - m = hashlib.md5() - with open(self.filename, 'rb') as fd: - content = fd.read() - m.update(content) - buf.extend(m.digest()) - buf.extend(content) - return buf - - def unpack(cls, buf: bytes): - raise RuntimeError(f'{cls.__class__.__name__}.unpack: not implemented') - # secret = buf[:12].decode() - # filename = buf[12:].decode() - # return OTAPayload(secret=secret, filename=filename) - - -class DiagnosticsFlags(MqttPayloadCustomField): - state: bool - config_changed_value_present: bool - config_changed: bool - - @staticmethod - def unpack(flags: int): - # _logger.debug(f'StatFlags.unpack: flags={flags}') - state = flags & 0x1 - ccvp = (flags >> 1) & 0x1 - cc = (flags >> 2) & 0x1 - # _logger.debug(f'StatFlags.unpack: state={state}') - return DiagnosticsFlags(state=(state == 1), - config_changed_value_present=(ccvp == 1), - config_changed=(cc == 1)) - - def __index__(self): - bits = 0 - bits |= (int(self.state) & 0x1) - bits |= (int(self.config_changed_value_present) & 0x1) << 1 - bits |= (int(self.config_changed) & 0x1) << 2 - return bits - - -class InitialDiagnosticsPayload(MqttPayload): - FORMAT = '=IBbIB' - - ip: int - fw_version: int - rssi: int - free_heap: int - flags: DiagnosticsFlags - - -class DiagnosticsPayload(MqttPayload): - FORMAT = '=bIB' - - rssi: int - free_heap: int - flags: DiagnosticsFlags diff --git a/src/home/mqtt/payload/inverter.py b/src/home/mqtt/payload/inverter.py deleted file mode 100644 index 09388df..0000000 --- a/src/home/mqtt/payload/inverter.py +++ /dev/null @@ -1,73 +0,0 @@ -import struct - -from .base_payload import MqttPayload, bit_field -from typing import Tuple - -_mult_10 = lambda n: int(n*10) -_div_10 = lambda n: n/10 - - -class Status(MqttPayload): - # 46 bytes - FORMAT = 'IHHHHHHBHHHHHBHHHHHHHH' - - PACKER = { - 'grid_voltage': _mult_10, - 'grid_freq': _mult_10, - 'ac_output_voltage': _mult_10, - 'ac_output_freq': _mult_10, - 'battery_voltage': _mult_10, - 'battery_voltage_scc': _mult_10, - 'battery_voltage_scc2': _mult_10, - 'pv1_input_voltage': _mult_10, - 'pv2_input_voltage': _mult_10 - } - UNPACKER = { - 'grid_voltage': _div_10, - 'grid_freq': _div_10, - 'ac_output_voltage': _div_10, - 'ac_output_freq': _div_10, - 'battery_voltage': _div_10, - 'battery_voltage_scc': _div_10, - 'battery_voltage_scc2': _div_10, - 'pv1_input_voltage': _div_10, - 'pv2_input_voltage': _div_10 - } - - time: int - grid_voltage: float - grid_freq: float - ac_output_voltage: float - ac_output_freq: float - ac_output_apparent_power: int - ac_output_active_power: int - output_load_percent: int - battery_voltage: float - battery_voltage_scc: float - battery_voltage_scc2: float - battery_discharge_current: int - battery_charge_current: int - battery_capacity: int - inverter_heat_sink_temp: int - mppt1_charger_temp: int - mppt2_charger_temp: int - pv1_input_power: int - pv2_input_power: int - pv1_input_voltage: float - pv2_input_voltage: float - - # H - mppt1_charger_status: bit_field(0, 16, 2) - mppt2_charger_status: bit_field(0, 16, 2) - battery_power_direction: bit_field(0, 16, 2) - dc_ac_power_direction: bit_field(0, 16, 2) - line_power_direction: bit_field(0, 16, 2) - load_connected: bit_field(0, 16, 1) - - -class Generation(MqttPayload): - # 8 bytes - FORMAT = 'II' - - time: int - wh: int diff --git a/src/home/mqtt/payload/relay.py b/src/home/mqtt/payload/relay.py deleted file mode 100644 index 4902991..0000000 --- a/src/home/mqtt/payload/relay.py +++ /dev/null @@ -1,22 +0,0 @@ -from .base_payload import MqttPayload -from .esp import ( - OTAResultPayload, - OTAPayload, - InitialDiagnosticsPayload, - DiagnosticsPayload -) - - -class PowerPayload(MqttPayload): - FORMAT = '=12sB' - PACKER = { - 'state': lambda n: int(n), - 'secret': lambda s: s.encode('utf-8') - } - UNPACKER = { - 'state': lambda n: bool(n), - 'secret': lambda s: s.decode('utf-8') - } - - secret: str - state: bool diff --git a/src/home/mqtt/payload/sensors.py b/src/home/mqtt/payload/sensors.py deleted file mode 100644 index f99b307..0000000 --- a/src/home/mqtt/payload/sensors.py +++ /dev/null @@ -1,20 +0,0 @@ -from .base_payload import MqttPayload - -_mult_100 = lambda n: int(n*100) -_div_100 = lambda n: n/100 - - -class Temperature(MqttPayload): - FORMAT = 'IhH' - PACKER = { - 'temp': _mult_100, - 'rh': _mult_100, - } - UNPACKER = { - 'temp': _div_100, - 'rh': _div_100, - } - - time: int - temp: float - rh: float diff --git a/src/home/mqtt/payload/temphum.py b/src/home/mqtt/payload/temphum.py deleted file mode 100644 index c0b744e..0000000 --- a/src/home/mqtt/payload/temphum.py +++ /dev/null @@ -1,15 +0,0 @@ -from .base_payload import MqttPayload - -two_digits_precision = lambda x: round(x, 2) - - -class TempHumDataPayload(MqttPayload): - FORMAT = '=ddb' - UNPACKER = { - 'temp': two_digits_precision, - 'rh': two_digits_precision - } - - temp: float - rh: float - error: int diff --git a/src/home/mqtt/relay.py b/src/home/mqtt/relay.py deleted file mode 100644 index a90f19c..0000000 --- a/src/home/mqtt/relay.py +++ /dev/null @@ -1,71 +0,0 @@ -import paho.mqtt.client as mqtt -import re -import datetime - -from .payload.relay import ( - PowerPayload, -) -from .esp import MqttEspBase - - -class MqttRelay(MqttEspBase): - TOPIC_LEAF = 'relay' - - def set_power(self, device_id, enable: bool, secret=None): - device = next(d for d in self._devices if d.id == device_id) - secret = secret if secret else device.secret - - assert secret is not None, 'device secret not specified' - - payload = PowerPayload(secret=secret, - state=enable) - self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/power', - payload=payload.pack(), - qos=1) - self._client.loop_write() - - def on_message(self, client: mqtt.Client, userdata, msg): - if super().on_message(client, userdata, msg): - return - - try: - match = re.match(self.get_mqtt_topics(['power']), msg.topic) - if not match: - return - - device_id = match.group(1) - subtopic = match.group(2) - - message = None - if subtopic == 'power': - message = PowerPayload.unpack(msg.payload) - - if message and self._message_callback: - self._message_callback(device_id, message) - - except Exception as e: - self._logger.exception(str(e)) - - -class MqttRelayState: - enabled: bool - update_time: datetime.datetime - rssi: int - fw_version: int - ever_updated: bool - - def __init__(self): - self.ever_updated = False - self.enabled = False - self.rssi = 0 - - def update(self, - enabled: bool, - rssi: int, - fw_version=None): - self.ever_updated = True - self.enabled = enabled - self.rssi = rssi - self.update_time = datetime.datetime.now() - if fw_version: - self.fw_version = fw_version diff --git a/src/home/mqtt/temphum.py b/src/home/mqtt/temphum.py deleted file mode 100644 index 44810ef..0000000 --- a/src/home/mqtt/temphum.py +++ /dev/null @@ -1,54 +0,0 @@ -import paho.mqtt.client as mqtt -import re - -from enum import auto -from .payload.temphum import TempHumDataPayload -from .esp import MqttEspBase -from ..util import HashableEnum - - -class MqttTempHumNodes(HashableEnum): - KBN_SH_HALL = auto() - KBN_SH_BATHROOM = auto() - KBN_SH_LIVINGROOM = auto() - KBN_SH_BEDROOM = auto() - - KBN_BH_2FL = auto() - KBN_BH_2FL_STREET = auto() - KBN_BH_1FL_LIVINGROOM = auto() - KBN_BH_1FL_BEDROOM = auto() - KBN_BH_1FL_BATHROOM = auto() - - KBN_NH_1FL_INV = auto() - KBN_NH_1FL_CENTER = auto() - KBN_NH_1LF_KT = auto() - KBN_NH_1FL_DS = auto() - KBN_NH_1FS_EZ = auto() - - SPB_FLAT120_CABINET = auto() - - -class MqttTempHum(MqttEspBase): - TOPIC_LEAF = 'temphum' - - def on_message(self, client: mqtt.Client, userdata, msg): - if super().on_message(client, userdata, msg): - return - - try: - match = re.match(self.get_mqtt_topics(['data']), msg.topic) - if not match: - return - - device_id = match.group(1) - subtopic = match.group(2) - - message = None - if subtopic == 'data': - message = TempHumDataPayload.unpack(msg.payload) - - if message and self._message_callback: - self._message_callback(device_id, message) - - except Exception as e: - self._logger.exception(str(e)) diff --git a/src/home/mqtt/util.py b/src/home/mqtt/util.py deleted file mode 100644 index f71ffd8..0000000 --- a/src/home/mqtt/util.py +++ /dev/null @@ -1,8 +0,0 @@ -import time - - -def poll_tick(freq): - t = time.time() - while True: - t += freq - yield max(t - time.time(), 0) diff --git a/src/home/pio/products.py b/src/home/pio/products.py index 7649078..388da03 100644 --- a/src/home/pio/products.py +++ b/src/home/pio/products.py @@ -16,10 +16,6 @@ _products_dir = os.path.join( def get_products(): products = [] for f in os.listdir(_products_dir): - # temp hack - if f.endswith('-esp01'): - continue - # skip the common dir if f in ('common',): continue diff --git a/src/home/telegram/_botcontext.py b/src/home/telegram/_botcontext.py index f343eeb..a143bfe 100644 --- a/src/home/telegram/_botcontext.py +++ b/src/home/telegram/_botcontext.py @@ -1,6 +1,7 @@ from typing import Optional, List -from telegram import Update, ParseMode, User, CallbackQuery +from telegram import Update, User, CallbackQuery +from telegram.constants import ParseMode from telegram.ext import CallbackContext from ._botdb import BotDatabase @@ -26,25 +27,25 @@ class Context: self._store = store self._user_lang = None - def reply(self, text, markup=None): + async def reply(self, text, markup=None): if markup is None: markup = self._markup_getter(self) kwargs = dict(parse_mode=ParseMode.HTML) if not isinstance(markup, IgnoreMarkup): kwargs['reply_markup'] = markup - return self._update.message.reply_text(text, **kwargs) + return await self._update.message.reply_text(text, **kwargs) - def reply_exc(self, e: Exception) -> None: - self.reply(exc2text(e), markup=IgnoreMarkup()) + async def reply_exc(self, e: Exception) -> None: + await self.reply(exc2text(e), markup=IgnoreMarkup()) - def answer(self, text: str = None): - self.callback_query.answer(text) + async def answer(self, text: str = None): + await self.callback_query.answer(text) - def edit(self, text, markup=None): + async def edit(self, text, markup=None): kwargs = dict(parse_mode=ParseMode.HTML) if not isinstance(markup, IgnoreMarkup): kwargs['reply_markup'] = markup - self.callback_query.edit_message_text(text, **kwargs) + await self.callback_query.edit_message_text(text, **kwargs) @property def text(self) -> str: diff --git a/src/home/telegram/bot.py b/src/home/telegram/bot.py index 10bfe06..7e22263 100644 --- a/src/home/telegram/bot.py +++ b/src/home/telegram/bot.py @@ -5,19 +5,19 @@ import itertools from enum import Enum, auto from functools import wraps -from typing import Optional, Union, Tuple +from typing import Optional, Union, Tuple, Coroutine from telegram import Update, ReplyKeyboardMarkup from telegram.ext import ( - Updater, - Filters, - BaseFilter, + Application, + filters, CommandHandler, MessageHandler, CallbackQueryHandler, CallbackContext, ConversationHandler ) +from telegram.ext.filters import BaseFilter from telegram.error import TimedOut from home.config import config @@ -33,26 +33,26 @@ from ._botcontext import Context db: Optional[BotDatabase] = None _user_filter: Optional[BaseFilter] = None -_cancel_filter = Filters.text(lang.all('cancel')) -_back_filter = Filters.text(lang.all('back')) -_cancel_and_back_filter = Filters.text(lang.all('back') + lang.all('cancel')) +_cancel_filter = filters.Text(lang.all('cancel')) +_back_filter = filters.Text(lang.all('back')) +_cancel_and_back_filter = filters.Text(lang.all('back') + lang.all('cancel')) _logger = logging.getLogger(__name__) -_updater: Optional[Updater] = None +_application: Optional[Application] = None _reporting: Optional[ReportingHelper] = None -_exception_handler: Optional[callable] = None +_exception_handler: Optional[Coroutine] = None _dispatcher = None _markup_getter: Optional[callable] = None -_start_handler_ref: Optional[callable] = None +_start_handler_ref: Optional[Coroutine] = None def text_filter(*args): if not _user_filter: raise RuntimeError('user_filter is not initialized') - return Filters.text(args[0] if isinstance(args[0], list) else [*args]) & _user_filter + return filters.Text(args[0] if isinstance(args[0], list) else [*args]) & _user_filter -def _handler_of_handler(*args, **kwargs): +async def _handler_of_handler(*args, **kwargs): self = None context = None update = None @@ -99,7 +99,7 @@ def _handler_of_handler(*args, **kwargs): if self: _args.insert(0, self) - result = f(*_args, **kwargs) + result = await f(*_args, **kwargs) return result if not return_with_context else (result, ctx) except Exception as e: @@ -107,7 +107,7 @@ def _handler_of_handler(*args, **kwargs): if not _exception_handler(e, ctx) and not isinstance(e, TimedOut): _logger.exception(e) if not ctx.is_callback_context(): - ctx.reply_exc(e) + await ctx.reply_exc(e) else: notify_user(ctx.user_id, exc2text(e)) else: @@ -117,10 +117,10 @@ def _handler_of_handler(*args, **kwargs): def handler(**kwargs): def inner(f): @wraps(f) - def _handler(*args, **inner_kwargs): + async def _handler(*args, **inner_kwargs): if 'argument' in kwargs and kwargs['argument'] == 'message_key': inner_kwargs['argument'] = 'message_key' - return _handler_of_handler(f=f, *args, **inner_kwargs) + return await _handler_of_handler(f=f, *args, **inner_kwargs) messages = [] texts = [] @@ -139,43 +139,43 @@ def handler(**kwargs): new_messages = list(itertools.chain.from_iterable([lang.all(m) for m in messages])) texts += new_messages texts = list(set(texts)) - _updater.dispatcher.add_handler( + _application.add_handler( MessageHandler(text_filter(*texts), _handler), group=0 ) if 'command' in kwargs: - _updater.dispatcher.add_handler(CommandHandler(kwargs['command'], _handler), group=0) + _application.add_handler(CommandHandler(kwargs['command'], _handler), group=0) if 'callback' in kwargs: - _updater.dispatcher.add_handler(CallbackQueryHandler(_handler, pattern=kwargs['callback']), group=0) + _application.add_handler(CallbackQueryHandler(_handler, pattern=kwargs['callback']), group=0) return _handler return inner -def simplehandler(f: callable): +def simplehandler(f: Coroutine): @wraps(f) - def _handler(*args, **kwargs): - return _handler_of_handler(f=f, *args, **kwargs) + async def _handler(*args, **kwargs): + return await _handler_of_handler(f=f, *args, **kwargs) return _handler def callbackhandler(*args, **kwargs): def inner(f): @wraps(f) - def _handler(*args, **kwargs): - return _handler_of_handler(f=f, *args, **kwargs) + async def _handler(*args, **kwargs): + return await _handler_of_handler(f=f, *args, **kwargs) pattern_kwargs = {} if kwargs['callback'] != '*': pattern_kwargs['pattern'] = kwargs['callback'] - _updater.dispatcher.add_handler(CallbackQueryHandler(_handler, **pattern_kwargs), group=0) + _application.add_handler(CallbackQueryHandler(_handler, **pattern_kwargs), group=0) return _handler return inner -def exceptionhandler(f: callable): +async def exceptionhandler(f: callable): global _exception_handler if _exception_handler: _logger.warning('exception handler already set, we will overwrite it') @@ -198,10 +198,10 @@ def convinput(state, is_enter=False, **kwargs): ) @wraps(f) - def _impl(*args, **kwargs): - result, ctx = _handler_of_handler(f=f, *args, **kwargs, return_with_context=True) + async def _impl(*args, **kwargs): + result, ctx = await _handler_of_handler(f=f, *args, **kwargs, return_with_context=True) if result == conversation.END: - start(ctx) + await start(ctx) return result return _impl @@ -252,7 +252,7 @@ class conversation: handlers.append(MessageHandler(text_filter(lang.all(message) if 'messages_lang_completed' not in kwargs else message), self.make_invoker(target_state))) if 'regex' in kwargs: - handlers.append(MessageHandler(Filters.regex(kwargs['regex']) & _user_filter, f)) + handlers.append(MessageHandler(filters.Regex(kwargs['regex']) & _user_filter, f)) if 'command' in kwargs: handlers.append(CommandHandler(kwargs['command'], f, _user_filter)) @@ -327,21 +327,21 @@ class conversation: @staticmethod @simplehandler - def invalid(ctx: Context): - ctx.reply(ctx.lang('invalid_input'), markup=IgnoreMarkup()) + async def invalid(ctx: Context): + await ctx.reply(ctx.lang('invalid_input'), markup=IgnoreMarkup()) # return 0 # FIXME is this needed @simplehandler - def cancel(self, ctx: Context): - start(ctx) + async def cancel(self, ctx: Context): + await start(ctx) self.set_user_state(ctx.user_id, None) return conversation.END @simplehandler - def back(self, ctx: Context): + async def back(self, ctx: Context): cur_state = self.get_user_state(ctx.user_id) if cur_state is None: - start(ctx) + await start(ctx) self.set_user_state(ctx.user_id, None) return conversation.END @@ -411,7 +411,7 @@ class LangConversation(conversation): START, = range(1) @conventer(START, command='lang') - def entry(self, ctx: Context): + async def entry(self, ctx: Context): self._logger.debug(f'current language: {ctx.user_lang}') buttons = [] @@ -419,11 +419,11 @@ class LangConversation(conversation): buttons.append(name) markup = ReplyKeyboardMarkup([buttons, [ctx.lang('cancel')]], one_time_keyboard=False) - ctx.reply(ctx.lang('select_language'), markup=markup) + await ctx.reply(ctx.lang('select_language'), markup=markup) return self.START @convinput(START, messages=lang.languages) - def input(self, ctx: Context): + async def input(self, ctx: Context): selected_lang = None for key, value in languages.items(): if value == ctx.text: @@ -434,30 +434,34 @@ class LangConversation(conversation): raise ValueError('could not find the language') db.set_user_lang(ctx.user_id, selected_lang) - ctx.reply(ctx.lang('saved'), markup=IgnoreMarkup()) + await ctx.reply(ctx.lang('saved'), markup=IgnoreMarkup()) return self.END def initialize(): global _user_filter - global _updater + global _application + # global _updater global _dispatcher # init user_filter - if 'users' in config['bot']: - _logger.info('allowed users: ' + str(config['bot']['users'])) - _user_filter = Filters.user(config['bot']['users']) + _user_ids = config.app_config.get_user_ids() + if len(_user_ids) > 0: + _logger.info('allowed users: ' + str(_user_ids)) + _user_filter = filters.User(_user_ids) else: - _user_filter = Filters.all # not sure if this is correct + _user_filter = filters.ALL # not sure if this is correct - # init updater - _updater = Updater(config['bot']['token'], - request_kwargs={'read_timeout': 6, 'connect_timeout': 7}) + _application = Application.builder()\ + .token(config.app_config.get('bot.token'))\ + .connect_timeout(7)\ + .read_timeout(6)\ + .build() # transparently log all messages - _updater.dispatcher.add_handler(MessageHandler(Filters.all & _user_filter, _logging_message_handler), group=10) - _updater.dispatcher.add_handler(CallbackQueryHandler(_logging_callback_handler), group=10) + # _application.dispatcher.add_handler(MessageHandler(filters.ALL & _user_filter, _logging_message_handler), group=10) + # _application.dispatcher.add_handler(CallbackQueryHandler(_logging_callback_handler), group=10) def run(start_handler=None, any_handler=None): @@ -473,37 +477,38 @@ def run(start_handler=None, any_handler=None): _start_handler_ref = start_handler - _updater.dispatcher.add_handler(LangConversation().get_handler(), group=0) - _updater.dispatcher.add_handler(CommandHandler('start', simplehandler(start_handler), _user_filter)) - _updater.dispatcher.add_handler(MessageHandler(Filters.all & _user_filter, any_handler)) + _application.add_handler(LangConversation().get_handler(), group=0) + _application.add_handler(CommandHandler('start', + callback=simplehandler(start_handler), + filters=_user_filter)) + _application.add_handler(MessageHandler(filters.ALL & _user_filter, any_handler)) - _updater.start_polling() - _updater.idle() + _application.run_polling() def add_conversation(conv: conversation) -> None: - _updater.dispatcher.add_handler(conv.get_handler(), group=0) + _application.add_handler(conv.get_handler(), group=0) def add_handler(h): - _updater.dispatcher.add_handler(h, group=0) + _application.add_handler(h, group=0) -def start(ctx: Context): - return _start_handler_ref(ctx) +async def start(ctx: Context): + return await _start_handler_ref(ctx) -def _default_start_handler(ctx: Context): +async def _default_start_handler(ctx: Context): if 'start_message' not in lang: - return ctx.reply('Please define start_message or override start()') - ctx.reply(ctx.lang('start_message')) + return await ctx.reply('Please define start_message or override start()') + await ctx.reply(ctx.lang('start_message')) @simplehandler -def _default_any_handler(ctx: Context): +async def _default_any_handler(ctx: Context): if 'invalid_command' not in lang: - return ctx.reply('Please define invalid_command or override any()') - ctx.reply(ctx.lang('invalid_command')) + return await ctx.reply('Please define invalid_command or override any()') + await ctx.reply(ctx.lang('invalid_command')) def _logging_message_handler(update: Update, context: CallbackContext): @@ -535,7 +540,7 @@ def notify_all(text_getter: callable, continue text = text_getter(db.get_user_lang(user_id)) - _updater.bot.send_message(chat_id=user_id, + _application.bot.send_message(chat_id=user_id, text=text, parse_mode='HTML') @@ -543,33 +548,33 @@ def notify_all(text_getter: callable, def notify_user(user_id: int, text: Union[str, Exception], **kwargs) -> None: if isinstance(text, Exception): text = exc2text(text) - _updater.bot.send_message(chat_id=user_id, + _application.bot.send_message(chat_id=user_id, text=text, parse_mode='HTML', **kwargs) def send_photo(user_id, **kwargs): - _updater.bot.send_photo(chat_id=user_id, **kwargs) + _application.bot.send_photo(chat_id=user_id, **kwargs) def send_audio(user_id, **kwargs): - _updater.bot.send_audio(chat_id=user_id, **kwargs) + _application.bot.send_audio(chat_id=user_id, **kwargs) def send_file(user_id, **kwargs): - _updater.bot.send_document(chat_id=user_id, **kwargs) + _application.bot.send_document(chat_id=user_id, **kwargs) def edit_message_text(user_id, message_id, *args, **kwargs): - _updater.bot.edit_message_text(chat_id=user_id, + _application.bot.edit_message_text(chat_id=user_id, message_id=message_id, parse_mode='HTML', *args, **kwargs) def delete_message(user_id, message_id): - _updater.bot.delete_message(chat_id=user_id, message_id=message_id) + _application.bot.delete_message(chat_id=user_id, message_id=message_id) def set_database(_db: BotDatabase): diff --git a/src/home/telegram/config.py b/src/home/telegram/config.py new file mode 100644 index 0000000..7a46087 --- /dev/null +++ b/src/home/telegram/config.py @@ -0,0 +1,75 @@ +from ..config import ConfigUnit +from typing import Optional, Union +from abc import ABC +from enum import Enum + + +class TelegramUserListType(Enum): + USERS = 'users' + NOTIFY = 'notify_users' + + +class TelegramUserIdsConfig(ConfigUnit): + NAME = 'telegram_user_ids' + + @staticmethod + def schema() -> Optional[dict]: + return { + 'roottype': 'dict', + 'type': 'integer' + } + + +_user_ids_config = TelegramUserIdsConfig() + + +def _user_id_mapper(user: Union[str, int]) -> int: + if isinstance(user, int): + return user + return _user_ids_config[user] + + +class TelegramChatsConfig(ConfigUnit): + NAME = 'telegram_chats' + + @staticmethod + def schema() -> Optional[dict]: + return { + 'type': 'dict', + 'schema': { + 'id': {'type': 'string', 'required': True}, + 'token': {'type': 'string', 'required': True}, + } + } + + +class TelegramBotConfig(ConfigUnit, ABC): + @staticmethod + def schema() -> Optional[dict]: + return { + 'bot': { + 'type': 'dict', + 'schema': { + 'token': {'type': 'string', 'required': True}, + TelegramUserListType.USERS: {**TelegramBotConfig._userlist_schema(), 'required': True}, + TelegramUserListType.NOTIFY: TelegramBotConfig._userlist_schema(), + } + } + } + + @staticmethod + def _userlist_schema() -> dict: + return {'type': 'list', 'schema': {'type': ['string', 'int']}} + + @staticmethod + def custom_validator(data): + for ult in TelegramUserListType: + users = data['bot'][ult.value] + for user in users: + if isinstance(user, str): + if user not in _user_ids_config: + raise ValueError(f'user {user} not found in {TelegramUserIdsConfig.NAME}') + + def get_user_ids(self, + ult: TelegramUserListType = TelegramUserListType.USERS) -> list[int]: + return list(map(_user_id_mapper, self['bot'][ult.value])) \ No newline at end of file diff --git a/src/home/temphum/__init__.py b/src/home/temphum/__init__.py index 55a7e1f..46d14e6 100644 --- a/src/home/temphum/__init__.py +++ b/src/home/temphum/__init__.py @@ -1,18 +1 @@ -from .base import SensorType, TempHumSensor -from .si7021 import Si7021 -from .dht12 import DHT12 - -__all__ = [ - 'SensorType', - 'TempHumSensor', - 'create_sensor' -] - - -def create_sensor(type: SensorType, bus: int) -> TempHumSensor: - if type == SensorType.Si7021: - return Si7021(bus) - elif type == SensorType.DHT12: - return DHT12(bus) - else: - raise ValueError('unexpected sensor type') +from .base import SensorType, BaseSensor diff --git a/src/home/temphum/base.py b/src/home/temphum/base.py index e774433..602cab7 100644 --- a/src/home/temphum/base.py +++ b/src/home/temphum/base.py @@ -1,25 +1,19 @@ -import smbus - -from abc import abstractmethod, ABC +from abc import ABC from enum import Enum -class TempHumSensor: - @abstractmethod +class BaseSensor(ABC): + def __init__(self, bus: int): + super().__init__() + self.bus = smbus.SMBus(bus) + def humidity(self) -> float: pass - @abstractmethod def temperature(self) -> float: pass -class I2CTempHumSensor(TempHumSensor, ABC): - def __init__(self, bus: int): - super().__init__() - self.bus = smbus.SMBus(bus) - - class SensorType(Enum): Si7021 = 'si7021' - DHT12 = 'dht12' + DHT12 = 'dht12' \ No newline at end of file diff --git a/src/home/temphum/dht12.py b/src/home/temphum/dht12.py deleted file mode 100644 index d495766..0000000 --- a/src/home/temphum/dht12.py +++ /dev/null @@ -1,22 +0,0 @@ -from .base import I2CTempHumSensor - - -class DHT12(I2CTempHumSensor): - i2c_addr = 0x5C - - def _measure(self): - raw = self.bus.read_i2c_block_data(self.i2c_addr, 0, 5) - if (raw[0] + raw[1] + raw[2] + raw[3]) & 0xff != raw[4]: - raise ValueError("checksum error") - return raw - - def temperature(self) -> float: - raw = self._measure() - temp = raw[2] + (raw[3] & 0x7f) * 0.1 - if raw[3] & 0x80: - temp *= -1 - return temp - - def humidity(self) -> float: - raw = self._measure() - return raw[0] + raw[1] * 0.1 diff --git a/src/home/temphum/i2c.py b/src/home/temphum/i2c.py new file mode 100644 index 0000000..7d8e2e3 --- /dev/null +++ b/src/home/temphum/i2c.py @@ -0,0 +1,52 @@ +import abc +import smbus + +from .base import BaseSensor, SensorType + + +class I2CSensor(BaseSensor, abc.ABC): + def __init__(self, bus: int): + super().__init__() + self.bus = smbus.SMBus(bus) + + +class DHT12(I2CSensor): + i2c_addr = 0x5C + + def _measure(self): + raw = self.bus.read_i2c_block_data(self.i2c_addr, 0, 5) + if (raw[0] + raw[1] + raw[2] + raw[3]) & 0xff != raw[4]: + raise ValueError("checksum error") + return raw + + def temperature(self) -> float: + raw = self._measure() + temp = raw[2] + (raw[3] & 0x7f) * 0.1 + if raw[3] & 0x80: + temp *= -1 + return temp + + def humidity(self) -> float: + raw = self._measure() + return raw[0] + raw[1] * 0.1 + + +class Si7021(I2CSensor): + i2c_addr = 0x40 + + def temperature(self) -> float: + raw = self.bus.read_i2c_block_data(self.i2c_addr, 0xE3, 2) + return 175.72 * (raw[0] << 8 | raw[1]) / 65536.0 - 46.85 + + def humidity(self) -> float: + raw = self.bus.read_i2c_block_data(self.i2c_addr, 0xE5, 2) + return 125.0 * (raw[0] << 8 | raw[1]) / 65536.0 - 6.0 + + +def create_sensor(type: SensorType, bus: int) -> BaseSensor: + if type == SensorType.Si7021: + return Si7021(bus) + elif type == SensorType.DHT12: + return DHT12(bus) + else: + raise ValueError('unexpected sensor type') diff --git a/src/home/temphum/si7021.py b/src/home/temphum/si7021.py deleted file mode 100644 index 6289e15..0000000 --- a/src/home/temphum/si7021.py +++ /dev/null @@ -1,13 +0,0 @@ -from .base import I2CTempHumSensor - - -class Si7021(I2CTempHumSensor): - i2c_addr = 0x40 - - def temperature(self) -> float: - raw = self.bus.read_i2c_block_data(self.i2c_addr, 0xE3, 2) - return 175.72 * (raw[0] << 8 | raw[1]) / 65536.0 - 46.85 - - def humidity(self) -> float: - raw = self.bus.read_i2c_block_data(self.i2c_addr, 0xE5, 2) - return 125.0 * (raw[0] << 8 | raw[1]) / 65536.0 - 6.0 diff --git a/src/home/util.py b/src/home/util.py index 93a9d8f..35505bc 100644 --- a/src/home/util.py +++ b/src/home/util.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import socket import time @@ -6,17 +8,57 @@ import traceback import logging import string import random +import re from enum import Enum from datetime import datetime from typing import Tuple, Optional, List from zlib import adler32 -Addr = Tuple[str, int] # network address type (host, port) - logger = logging.getLogger(__name__) +def validate_ipv4_or_hostname(address: str, raise_exception: bool = False) -> bool: + if re.match(r'^(\d{1,3}\.){3}\d{1,3}$', address): + parts = address.split('.') + if all(0 <= int(part) < 256 for part in parts): + return True + else: + if raise_exception: + raise ValueError(f"invalid IPv4 address: {address}") + return False + + if re.match(r'^[a-zA-Z0-9.-]+$', address): + return True + else: + if raise_exception: + raise ValueError(f"invalid hostname: {address}") + return False + + +class Addr: + host: str + port: int + + def __init__(self, host: str, port: int): + self.host = host + self.port = port + + @staticmethod + def fromstring(addr: str) -> Addr: + if addr.count(':') != 1: + raise ValueError('invalid host:port format') + + host, port = addr.split(':') + validate_ipv4_or_hostname(host, raise_exception=True) + + port = int(port) + if not 0 <= port <= 65535: + raise ValueError(f'invalid port {port}') + + return Addr(host, port) + + # https://stackoverflow.com/questions/312443/how-do-you-split-a-list-into-evenly-sized-chunks def chunks(lst, n): """Yield successive n-sized chunks from lst.""" @@ -45,21 +87,6 @@ def ipv4_valid(ip: str) -> bool: return False -def parse_addr(addr: str) -> Addr: - if addr.count(':') != 1: - raise ValueError('invalid host:port format') - - host, port = addr.split(':') - if not ipv4_valid(host): - raise ValueError('invalid ipv4 address') - - port = int(port) - if not 0 <= port <= 65535: - raise ValueError('invalid port') - - return host, port - - def strgen(n: int): return ''.join(random.choices(string.ascii_letters + string.digits, k=n)) @@ -193,4 +220,11 @@ def filesize_fmt(num, suffix="B") -> str: class HashableEnum(Enum): def hash(self) -> int: - return adler32(self.name.encode()) \ No newline at end of file + return adler32(self.name.encode()) + + +def next_tick_gen(freq): + t = time.time() + while True: + t += freq + yield max(t - time.time(), 0) \ No newline at end of file -- cgit v1.2.3 From 327a5298359027099631c3c9967b7585928cd367 Mon Sep 17 00:00:00 2001 From: Evgeny Zinoviev Date: Sat, 10 Jun 2023 21:54:56 +0300 Subject: port relay_mqtt_http_proxy to new config scheme; config: support addr types & normalization --- src/home/config/_configs.py | 8 +++--- src/home/config/config.py | 61 +++++++++++++++++++++++++++++---------------- src/home/inverter/config.py | 4 +-- src/home/mqtt/_config.py | 8 +++--- src/home/mqtt/_wrapper.py | 5 ++-- src/home/telegram/config.py | 12 ++++----- src/home/util.py | 33 ++++++++++++++++++------ 7 files changed, 84 insertions(+), 47 deletions(-) (limited to 'src/home') diff --git a/src/home/config/_configs.py b/src/home/config/_configs.py index 3a1aae5..1628cba 100644 --- a/src/home/config/_configs.py +++ b/src/home/config/_configs.py @@ -5,8 +5,8 @@ from typing import Optional class ServicesListConfig(ConfigUnit): NAME = 'services_list' - @staticmethod - def schema() -> Optional[dict]: + @classmethod + def schema(cls) -> Optional[dict]: return { 'type': 'list', 'empty': False, @@ -19,8 +19,8 @@ class ServicesListConfig(ConfigUnit): class LinuxBoardsConfig(ConfigUnit): NAME = 'linux_boards' - @staticmethod - def schema() -> Optional[dict]: + @classmethod + def schema(cls) -> Optional[dict]: return { 'type': 'dict', 'schema': { diff --git a/src/home/config/config.py b/src/home/config/config.py index aef9ee7..dc00d2e 100644 --- a/src/home/config/config.py +++ b/src/home/config/config.py @@ -1,10 +1,10 @@ import yaml import logging import os -import pprint +import cerberus +import cerberus.errors from abc import ABC -from cerberus import Validator, DocumentError from typing import Optional, Any, MutableMapping, Union from argparse import ArgumentParser from enum import Enum, auto @@ -12,11 +12,20 @@ from os.path import join, isdir, isfile from ..util import Addr +class MyValidator(cerberus.Validator): + def _normalize_coerce_addr(self, value): + return Addr.fromstring(value) + + +MyValidator.types_mapping['addr'] = cerberus.TypeDefinition('Addr', (Addr,), ()) + + CONFIG_DIRECTORIES = ( join(os.environ['HOME'], '.config', 'homekit'), '/etc/homekit' ) + class RootSchemaType(Enum): DEFAULT = auto() DICT = auto() @@ -95,10 +104,19 @@ class ConfigUnit(BaseConfigUnit): raise IOError(f'\'{name}.yaml\' not found') - @staticmethod - def schema() -> Optional[dict]: + @classmethod + def schema(cls) -> Optional[dict]: return None + @classmethod + def _addr_schema(cls, required=False, **kwargs): + return { + 'type': 'addr', + 'coerce': Addr.fromstring, + 'required': required, + **kwargs + } + def validate(self): schema = self.schema() if not schema: @@ -109,7 +127,7 @@ class ConfigUnit(BaseConfigUnit): schema['logging'] = { 'type': 'dict', 'schema': { - 'logging': {'type': 'bool'} + 'logging': {'type': 'boolean'} } } @@ -125,27 +143,27 @@ class ConfigUnit(BaseConfigUnit): except KeyError: pass + v = MyValidator() + if rst == RootSchemaType.DICT: - v = Validator({'document': { - 'type': 'dict', - 'keysrules': {'type': 'string'}, - 'valuesrules': schema - }}) - result = v.validate({'document': self._data}) + normalized = v.validated({'document': self._data}, + {'document': { + 'type': 'dict', + 'keysrules': {'type': 'string'}, + 'valuesrules': schema + }})['document'] elif rst == RootSchemaType.LIST: - v = Validator({'document': schema}) - result = v.validate({'document': self._data}) + v = MyValidator() + normalized = v.validated({'document': self._data}, {'document': schema})['document'] else: - v = Validator(schema) - result = v.validate(self._data) - # pprint.pprint(self._data) - if not result: - # pprint.pprint(v.errors) - raise DocumentError(f'{self.__class__.__name__}: failed to validate data:\n{pprint.pformat(v.errors)}') + normalized = v.validated(self._data, schema) + + self._data = normalized + try: self.custom_validator(self._data) except Exception as e: - raise DocumentError(f'{self.__class__.__name__}: {str(e)}') + raise cerberus.DocumentError(f'{self.__class__.__name__}: {str(e)}') @staticmethod def custom_validator(data): @@ -238,7 +256,7 @@ class Config: no_config=False): global app_config - if issubclass(name, AppConfigUnit) or name == AppConfigUnit: + if not isinstance(name, str) and not isinstance(name, bool) and issubclass(name, AppConfigUnit) or name == AppConfigUnit: self.app_name = name.NAME self.app_config = name() app_config = self.app_config @@ -278,6 +296,7 @@ class Config: if not no_config: self.app_config.load_from(path) + self.app_config.validate() setup_logging(self.app_config.logging_is_verbose(), self.app_config.logging_get_file(), diff --git a/src/home/inverter/config.py b/src/home/inverter/config.py index 62b8859..e284dfe 100644 --- a/src/home/inverter/config.py +++ b/src/home/inverter/config.py @@ -5,8 +5,8 @@ from typing import Optional class InverterdConfig(ConfigUnit): NAME = 'inverterd' - @staticmethod - def schema() -> Optional[dict]: + @classmethod + def schema(cls) -> Optional[dict]: return { 'remote_addr': {'type': 'string'}, 'local_addr': {'type': 'string'}, diff --git a/src/home/mqtt/_config.py b/src/home/mqtt/_config.py index f9047b4..9ba9443 100644 --- a/src/home/mqtt/_config.py +++ b/src/home/mqtt/_config.py @@ -9,8 +9,8 @@ MqttCreds = namedtuple('MqttCreds', 'username, password') class MqttConfig(ConfigUnit): NAME = 'mqtt' - @staticmethod - def schema() -> Optional[dict]: + @classmethod + def schema(cls) -> Optional[dict]: addr_schema = { 'type': 'dict', 'required': True, @@ -64,8 +64,8 @@ class MqttConfig(ConfigUnit): class MqttNodesConfig(ConfigUnit): NAME = 'mqtt_nodes' - @staticmethod - def schema() -> Optional[dict]: + @classmethod + def schema(cls) -> Optional[dict]: return { 'common': { 'type': 'dict', diff --git a/src/home/mqtt/_wrapper.py b/src/home/mqtt/_wrapper.py index f858f88..3c2774c 100644 --- a/src/home/mqtt/_wrapper.py +++ b/src/home/mqtt/_wrapper.py @@ -2,7 +2,6 @@ import paho.mqtt.client as mqtt from ._mqtt import Mqtt from ._node import MqttNode -from ..config import config from ..util import strgen @@ -34,8 +33,10 @@ class MqttWrapper(Mqtt): def on_message(self, client: mqtt.Client, userdata, msg): try: topic = msg.topic + topic_node = topic[len(self._topic_prefix)+1:topic.find('/', len(self._topic_prefix)+1)] for node in self._nodes: - node.on_message(topic[len(f'{self._topic_prefix}/{node.id}/'):], msg.payload) + if node.id in ('+', topic_node): + node.on_message(topic[len(f'{self._topic_prefix}/{node.id}/'):], msg.payload) except Exception as e: self._logger.exception(str(e)) diff --git a/src/home/telegram/config.py b/src/home/telegram/config.py index 7a46087..4c7d74b 100644 --- a/src/home/telegram/config.py +++ b/src/home/telegram/config.py @@ -12,8 +12,8 @@ class TelegramUserListType(Enum): class TelegramUserIdsConfig(ConfigUnit): NAME = 'telegram_user_ids' - @staticmethod - def schema() -> Optional[dict]: + @classmethod + def schema(cls) -> Optional[dict]: return { 'roottype': 'dict', 'type': 'integer' @@ -32,8 +32,8 @@ def _user_id_mapper(user: Union[str, int]) -> int: class TelegramChatsConfig(ConfigUnit): NAME = 'telegram_chats' - @staticmethod - def schema() -> Optional[dict]: + @classmethod + def schema(cls) -> Optional[dict]: return { 'type': 'dict', 'schema': { @@ -44,8 +44,8 @@ class TelegramChatsConfig(ConfigUnit): class TelegramBotConfig(ConfigUnit, ABC): - @staticmethod - def schema() -> Optional[dict]: + @classmethod + def schema(cls) -> Optional[dict]: return { 'bot': { 'type': 'dict', diff --git a/src/home/util.py b/src/home/util.py index 35505bc..1e12243 100644 --- a/src/home/util.py +++ b/src/home/util.py @@ -12,7 +12,7 @@ import re from enum import Enum from datetime import datetime -from typing import Tuple, Optional, List +from typing import Optional, List from zlib import adler32 logger = logging.getLogger(__name__) @@ -38,26 +38,43 @@ def validate_ipv4_or_hostname(address: str, raise_exception: bool = False) -> bo class Addr: host: str - port: int + port: Optional[int] - def __init__(self, host: str, port: int): + def __init__(self, host: str, port: Optional[int] = None): self.host = host self.port = port @staticmethod def fromstring(addr: str) -> Addr: - if addr.count(':') != 1: + colons = addr.count(':') + if colons != 1: raise ValueError('invalid host:port format') - host, port = addr.split(':') + if not colons: + host = addr + port= None + else: + host, port = addr.split(':') + validate_ipv4_or_hostname(host, raise_exception=True) - port = int(port) - if not 0 <= port <= 65535: - raise ValueError(f'invalid port {port}') + if port is not None: + port = int(port) + if not 0 <= port <= 65535: + raise ValueError(f'invalid port {port}') return Addr(host, port) + def __str__(self): + buf = self.host + if self.port is not None: + buf += ':'+str(self.port) + return buf + + def __iter__(self): + yield self.host + yield self.port + # https://stackoverflow.com/questions/312443/how-do-you-split-a-list-into-evenly-sized-chunks def chunks(lst, n): -- cgit v1.2.3 From 2631c58961c2f5ec90be560a8f5152fe27339a90 Mon Sep 17 00:00:00 2001 From: Evgeny Zinoviev Date: Sat, 10 Jun 2023 22:11:41 +0300 Subject: fix mqtt_node_util --- src/home/config/config.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'src/home') diff --git a/src/home/config/config.py b/src/home/config/config.py index dc00d2e..7344386 100644 --- a/src/home/config/config.py +++ b/src/home/config/config.py @@ -256,7 +256,10 @@ class Config: no_config=False): global app_config - if not isinstance(name, str) and not isinstance(name, bool) and issubclass(name, AppConfigUnit) or name == AppConfigUnit: + if not no_config \ + and not isinstance(name, str) \ + and not isinstance(name, bool) \ + and issubclass(name, AppConfigUnit) or name == AppConfigUnit: self.app_name = name.NAME self.app_config = name() app_config = self.app_config -- cgit v1.2.3 From 3790c2205396cf860738f297e6ddc49cd2b2a03f Mon Sep 17 00:00:00 2001 From: Evgeny Zinoviev Date: Sat, 10 Jun 2023 22:29:24 +0300 Subject: new config: port openwrt_logger and webapiclient --- src/home/api/__init__.py | 12 ++++++++++-- src/home/api/__init__.pyi | 3 ++- src/home/api/config.py | 15 +++++++++++++++ src/home/api/web_api_client.py | 32 +++++++++++++++++--------------- src/home/database/_base.py | 9 +++++++++ src/home/database/simple_state.py | 14 ++++++++------ src/home/database/sqlite.py | 6 ++---- src/home/telegram/_botutil.py | 2 +- src/home/telegram/bot.py | 4 ++-- 9 files changed, 66 insertions(+), 31 deletions(-) create mode 100644 src/home/api/config.py create mode 100644 src/home/database/_base.py (limited to 'src/home') diff --git a/src/home/api/__init__.py b/src/home/api/__init__.py index 782a61e..d641f62 100644 --- a/src/home/api/__init__.py +++ b/src/home/api/__init__.py @@ -1,11 +1,19 @@ import importlib -__all__ = ['WebAPIClient', 'RequestParams'] +__all__ = [ + # web_api_client.py + 'WebApiClient', + 'RequestParams', + + # config.py + 'WebApiConfig' +] def __getattr__(name): if name in __all__: - module = importlib.import_module(f'.web_api_client', __name__) + file = 'config' if name == 'WebApiConfig' else 'web_api_client' + module = importlib.import_module(f'.{file}', __name__) return getattr(module, name) raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/home/api/__init__.pyi b/src/home/api/__init__.pyi index 1b812d6..5b98161 100644 --- a/src/home/api/__init__.pyi +++ b/src/home/api/__init__.pyi @@ -1,4 +1,5 @@ from .web_api_client import ( RequestParams as RequestParams, - WebAPIClient as WebAPIClient + WebApiClient as WebApiClient ) +from .config import WebApiConfig as WebApiConfig diff --git a/src/home/api/config.py b/src/home/api/config.py new file mode 100644 index 0000000..00c1097 --- /dev/null +++ b/src/home/api/config.py @@ -0,0 +1,15 @@ +from ..config import ConfigUnit +from typing import Optional, Union + + +class WebApiConfig(ConfigUnit): + NAME = 'web_api' + + @classmethod + def schema(cls) -> Optional[dict]: + return { + 'listen_addr': cls._addr_schema(required=True), + 'host': cls._addr_schema(required=True), + 'token': dict(type='string', required=True), + 'recordings_dir': dict(type='string', required=True) + } \ No newline at end of file diff --git a/src/home/api/web_api_client.py b/src/home/api/web_api_client.py index 6677182..15c1915 100644 --- a/src/home/api/web_api_client.py +++ b/src/home/api/web_api_client.py @@ -9,13 +9,15 @@ from enum import Enum, auto from typing import Optional, Callable, Union, List, Tuple, Dict from requests.auth import HTTPBasicAuth +from .config import WebApiConfig from .errors import ApiResponseError from .types import * from ..config import config from ..util import stringify from ..media import RecordFile, MediaNodeClient -logger = logging.getLogger(__name__) +_logger = logging.getLogger(__name__) +_config = WebApiConfig() RequestParams = namedtuple('RequestParams', 'params, files, method') @@ -26,7 +28,7 @@ class HTTPMethod(Enum): POST = auto() -class WebAPIClient: +class WebApiClient: token: str timeout: Union[float, Tuple[float, float]] basic_auth: Optional[HTTPBasicAuth] @@ -35,22 +37,22 @@ class WebAPIClient: async_success_handler: Optional[Callable] def __init__(self, timeout: Union[float, Tuple[float, float]] = 5): - self.token = config['api']['token'] + self.token = config['token'] self.timeout = timeout self.basic_auth = None self.do_async = False self.async_error_handler = None self.async_success_handler = None - if 'basic_auth' in config['api']: - ba = config['api']['basic_auth'] - col = ba.index(':') - - user = ba[:col] - pw = ba[col+1:] - - logger.debug(f'enabling basic auth: {user}:{pw}') - self.basic_auth = HTTPBasicAuth(user, pw) + # if 'basic_auth' in config['api']: + # ba = config['api']['basic_auth'] + # col = ba.index(':') + # + # user = ba[:col] + # pw = ba[col+1:] + # + # _logger.debug(f'enabling basic auth: {user}:{pw}') + # self.basic_auth = HTTPBasicAuth(user, pw) # api methods # ----------- @@ -152,7 +154,7 @@ class WebAPIClient: params: dict, method: HTTPMethod = HTTPMethod.GET, files: Optional[Dict[str, str]] = None) -> Optional[any]: - domain = config['api']['host'] + domain = config['host'] kwargs = {} if self.basic_auth is not None: @@ -196,7 +198,7 @@ class WebAPIClient: try: f.close() except Exception as exc: - logger.exception(exc) + _logger.exception(exc) pass def _make_request_in_thread(self, name, params, method, files): @@ -204,7 +206,7 @@ class WebAPIClient: result = self._make_request(name, params, method, files) self._report_async_success(result, name, RequestParams(params=params, method=method, files=files)) except Exception as e: - logger.exception(e) + _logger.exception(e) self._report_async_error(e, name, RequestParams(params=params, method=method, files=files)) def enable_async(self, diff --git a/src/home/database/_base.py b/src/home/database/_base.py new file mode 100644 index 0000000..c01e62b --- /dev/null +++ b/src/home/database/_base.py @@ -0,0 +1,9 @@ +import os + + +def get_data_root_directory(name: str) -> str: + return os.path.join( + os.environ['HOME'], + '.config', + 'homekit', + 'data') \ No newline at end of file diff --git a/src/home/database/simple_state.py b/src/home/database/simple_state.py index cada9c8..2b8ebe7 100644 --- a/src/home/database/simple_state.py +++ b/src/home/database/simple_state.py @@ -2,24 +2,26 @@ import os import json import atexit +from ._base import get_data_root_directory + class SimpleState: def __init__(self, - file: str, - default: dict = None, - **kwargs): + name: str, + default: dict = None): if default is None: default = {} elif type(default) is not dict: raise TypeError('default must be dictionary') - if not os.path.exists(file): + path = os.path.join(get_data_root_directory(), name) + if not os.path.exists(path): self._data = default else: - with open(file, 'r') as f: + with open(path, 'r') as f: self._data = json.loads(f.read()) - self._file = file + self._file = path atexit.register(self.__cleanup) def __cleanup(self): diff --git a/src/home/database/sqlite.py b/src/home/database/sqlite.py index 8c6145c..0af1f54 100644 --- a/src/home/database/sqlite.py +++ b/src/home/database/sqlite.py @@ -2,15 +2,13 @@ import sqlite3 import os.path import logging +from ._base import get_data_root_directory from ..config import config, is_development_mode def _get_database_path(name: str) -> str: return os.path.join( - os.environ['HOME'], - '.config', - 'homekit', - 'data', + get_data_root_directory(), f'{name}.db') diff --git a/src/home/telegram/_botutil.py b/src/home/telegram/_botutil.py index 6d1ee8f..b551a55 100644 --- a/src/home/telegram/_botutil.py +++ b/src/home/telegram/_botutil.py @@ -3,7 +3,7 @@ import traceback from html import escape from telegram import User -from home.api import WebAPIClient as APIClient +from home.api import WebApiClient as APIClient from home.api.types import BotType from home.api.errors import ApiResponseError diff --git a/src/home/telegram/bot.py b/src/home/telegram/bot.py index 7e22263..e6ebc6e 100644 --- a/src/home/telegram/bot.py +++ b/src/home/telegram/bot.py @@ -21,7 +21,7 @@ from telegram.ext.filters import BaseFilter from telegram.error import TimedOut from home.config import config -from home.api import WebAPIClient +from home.api import WebApiClient from home.api.types import BotType from ._botlang import lang, languages @@ -522,7 +522,7 @@ def _logging_callback_handler(update: Update, context: CallbackContext): def enable_logging(bot_type: BotType): - api = WebAPIClient(timeout=3) + api = WebApiClient(timeout=3) api.enable_async() global _reporting -- cgit v1.2.3 From f3b9d50496257d87757802dfb472b5ffae11962c Mon Sep 17 00:00:00 2001 From: Evgeny Zinoviev Date: Sat, 10 Jun 2023 22:44:31 +0300 Subject: new config: port openwrt_log_analyzer --- src/home/telegram/telegram.py | 28 +++++++++++++++------------- src/home/util.py | 8 ++++++++ 2 files changed, 23 insertions(+), 13 deletions(-) (limited to 'src/home') diff --git a/src/home/telegram/telegram.py b/src/home/telegram/telegram.py index 2f94f93..f42363e 100644 --- a/src/home/telegram/telegram.py +++ b/src/home/telegram/telegram.py @@ -2,25 +2,27 @@ import requests import logging from typing import Tuple -from ..config import config - +from .config import TelegramChatsConfig +_chats = TelegramChatsConfig() _logger = logging.getLogger(__name__) def send_message(text: str, - parse_mode: str = None, - disable_web_page_preview: bool = False): - data, token = _send_telegram_data(text, parse_mode, disable_web_page_preview) + chat: str, + parse_mode: str = 'HTML', + disable_web_page_preview: bool = False,): + data, token = _send_telegram_data(text, chat, parse_mode, disable_web_page_preview) req = requests.post('https://api.telegram.org/bot%s/sendMessage' % token, data=data) return req.json() -def send_photo(filename: str): +def send_photo(filename: str, chat: str): + chat_data = _chats[chat] data = { - 'chat_id': config['telegram']['chat_id'], + 'chat_id': chat_data['id'], } - token = config['telegram']['token'] + token = chat_data['token'] url = f'https://api.telegram.org/bot{token}/sendPhoto' with open(filename, "rb") as fd: @@ -29,19 +31,19 @@ def send_photo(filename: str): def _send_telegram_data(text: str, + chat: str, parse_mode: str = None, disable_web_page_preview: bool = False) -> Tuple[dict, str]: + chat_data = _chats[chat] data = { - 'chat_id': config['telegram']['chat_id'], + 'chat_id': chat_data['id'], 'text': text } if parse_mode is not None: data['parse_mode'] = parse_mode - elif 'parse_mode' in config['telegram']: - data['parse_mode'] = config['telegram']['parse_mode'] - if disable_web_page_preview or 'disable_web_page_preview' in config['telegram']: + if disable_web_page_preview: data['disable_web_page_preview'] = 1 - return data, config['telegram']['token'] + return data, chat_data['token'] diff --git a/src/home/util.py b/src/home/util.py index 1e12243..11e7116 100644 --- a/src/home/util.py +++ b/src/home/util.py @@ -36,6 +36,14 @@ def validate_ipv4_or_hostname(address: str, raise_exception: bool = False) -> bo return False +def validate_mac_address(mac_address: str) -> bool: + mac_pattern = r'^([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})$' + if re.match(mac_pattern, mac_address): + return True + else: + return False + + class Addr: host: str port: Optional[int] -- cgit v1.2.3 From b0bf43e6a272d42a55158e657bd937cb82fc3d8d Mon Sep 17 00:00:00 2001 From: Evgeny Zinoviev Date: Sat, 10 Jun 2023 23:02:34 +0300 Subject: move files, rename home package to homekit --- src/home/__init__.py | 0 src/home/api/__init__.py | 19 - src/home/api/__init__.pyi | 5 - src/home/api/config.py | 15 - src/home/api/errors/__init__.py | 1 - src/home/api/errors/api_response_error.py | 28 -- src/home/api/types/__init__.py | 6 - src/home/api/types/types.py | 33 -- src/home/api/web_api_client.py | 227 ----------- src/home/audio/__init__.py | 0 src/home/audio/amixer.py | 91 ----- src/home/camera/__init__.py | 1 - src/home/camera/esp32.py | 226 ----------- src/home/camera/types.py | 5 - src/home/camera/util.py | 107 ------ src/home/config/__init__.py | 13 - src/home/config/_configs.py | 55 --- src/home/config/config.py | 387 ------------------- src/home/database/__init__.py | 29 -- src/home/database/__init__.pyi | 11 - src/home/database/_base.py | 9 - src/home/database/bots.py | 106 ------ src/home/database/clickhouse.py | 39 -- src/home/database/inverter.py | 212 ----------- src/home/database/inverter_time_formats.py | 2 - src/home/database/mysql.py | 47 --- src/home/database/sensors.py | 69 ---- src/home/database/simple_state.py | 48 --- src/home/database/sqlite.py | 67 ---- src/home/http/__init__.py | 2 - src/home/http/http.py | 106 ------ src/home/inverter/__init__.py | 3 - src/home/inverter/config.py | 13 - src/home/inverter/emulator.py | 556 --------------------------- src/home/inverter/inverter_wrapper.py | 48 --- src/home/inverter/monitor.py | 499 ------------------------ src/home/inverter/types.py | 64 ---- src/home/inverter/util.py | 8 - src/home/media/__init__.py | 22 -- src/home/media/__init__.pyi | 27 -- src/home/media/node_client.py | 119 ------ src/home/media/node_server.py | 86 ----- src/home/media/record.py | 461 ----------------------- src/home/media/record_client.py | 166 -------- src/home/media/storage.py | 210 ----------- src/home/media/types.py | 13 - src/home/mqtt/__init__.py | 7 - src/home/mqtt/_config.py | 165 -------- src/home/mqtt/_module.py | 70 ---- src/home/mqtt/_mqtt.py | 86 ----- src/home/mqtt/_node.py | 92 ----- src/home/mqtt/_payload.py | 145 ------- src/home/mqtt/_util.py | 15 - src/home/mqtt/_wrapper.py | 60 --- src/home/mqtt/module/diagnostics.py | 64 ---- src/home/mqtt/module/inverter.py | 195 ---------- src/home/mqtt/module/ota.py | 77 ---- src/home/mqtt/module/relay.py | 92 ----- src/home/mqtt/module/temphum.py | 82 ---- src/home/pio/__init__.py | 1 - src/home/pio/exceptions.py | 2 - src/home/pio/products.py | 113 ------ src/home/relay/__init__.py | 16 - src/home/relay/__init__.pyi | 2 - src/home/relay/sunxi_h3_client.py | 39 -- src/home/relay/sunxi_h3_server.py | 82 ---- src/home/soundsensor/__init__.py | 22 -- src/home/soundsensor/__init__.pyi | 8 - src/home/soundsensor/node.py | 75 ---- src/home/soundsensor/server.py | 128 ------- src/home/soundsensor/server_client.py | 38 -- src/home/telegram/__init__.py | 1 - src/home/telegram/_botcontext.py | 86 ----- src/home/telegram/_botdb.py | 32 -- src/home/telegram/_botlang.py | 120 ------ src/home/telegram/_botutil.py | 47 --- src/home/telegram/aio.py | 18 - src/home/telegram/bot.py | 583 ----------------------------- src/home/telegram/config.py | 75 ---- src/home/telegram/telegram.py | 49 --- src/home/temphum/__init__.py | 1 - src/home/temphum/base.py | 19 - src/home/temphum/i2c.py | 52 --- src/home/util.py | 255 ------------- 84 files changed, 7275 deletions(-) delete mode 100644 src/home/__init__.py delete mode 100644 src/home/api/__init__.py delete mode 100644 src/home/api/__init__.pyi delete mode 100644 src/home/api/config.py delete mode 100644 src/home/api/errors/__init__.py delete mode 100644 src/home/api/errors/api_response_error.py delete mode 100644 src/home/api/types/__init__.py delete mode 100644 src/home/api/types/types.py delete mode 100644 src/home/api/web_api_client.py delete mode 100644 src/home/audio/__init__.py delete mode 100644 src/home/audio/amixer.py delete mode 100644 src/home/camera/__init__.py delete mode 100644 src/home/camera/esp32.py delete mode 100644 src/home/camera/types.py delete mode 100644 src/home/camera/util.py delete mode 100644 src/home/config/__init__.py delete mode 100644 src/home/config/_configs.py delete mode 100644 src/home/config/config.py delete mode 100644 src/home/database/__init__.py delete mode 100644 src/home/database/__init__.pyi delete mode 100644 src/home/database/_base.py delete mode 100644 src/home/database/bots.py delete mode 100644 src/home/database/clickhouse.py delete mode 100644 src/home/database/inverter.py delete mode 100644 src/home/database/inverter_time_formats.py delete mode 100644 src/home/database/mysql.py delete mode 100644 src/home/database/sensors.py delete mode 100644 src/home/database/simple_state.py delete mode 100644 src/home/database/sqlite.py delete mode 100644 src/home/http/__init__.py delete mode 100644 src/home/http/http.py delete mode 100644 src/home/inverter/__init__.py delete mode 100644 src/home/inverter/config.py delete mode 100644 src/home/inverter/emulator.py delete mode 100644 src/home/inverter/inverter_wrapper.py delete mode 100644 src/home/inverter/monitor.py delete mode 100644 src/home/inverter/types.py delete mode 100644 src/home/inverter/util.py delete mode 100644 src/home/media/__init__.py delete mode 100644 src/home/media/__init__.pyi delete mode 100644 src/home/media/node_client.py delete mode 100644 src/home/media/node_server.py delete mode 100644 src/home/media/record.py delete mode 100644 src/home/media/record_client.py delete mode 100644 src/home/media/storage.py delete mode 100644 src/home/media/types.py delete mode 100644 src/home/mqtt/__init__.py delete mode 100644 src/home/mqtt/_config.py delete mode 100644 src/home/mqtt/_module.py delete mode 100644 src/home/mqtt/_mqtt.py delete mode 100644 src/home/mqtt/_node.py delete mode 100644 src/home/mqtt/_payload.py delete mode 100644 src/home/mqtt/_util.py delete mode 100644 src/home/mqtt/_wrapper.py delete mode 100644 src/home/mqtt/module/diagnostics.py delete mode 100644 src/home/mqtt/module/inverter.py delete mode 100644 src/home/mqtt/module/ota.py delete mode 100644 src/home/mqtt/module/relay.py delete mode 100644 src/home/mqtt/module/temphum.py delete mode 100644 src/home/pio/__init__.py delete mode 100644 src/home/pio/exceptions.py delete mode 100644 src/home/pio/products.py delete mode 100644 src/home/relay/__init__.py delete mode 100644 src/home/relay/__init__.pyi delete mode 100644 src/home/relay/sunxi_h3_client.py delete mode 100644 src/home/relay/sunxi_h3_server.py delete mode 100644 src/home/soundsensor/__init__.py delete mode 100644 src/home/soundsensor/__init__.pyi delete mode 100644 src/home/soundsensor/node.py delete mode 100644 src/home/soundsensor/server.py delete mode 100644 src/home/soundsensor/server_client.py delete mode 100644 src/home/telegram/__init__.py delete mode 100644 src/home/telegram/_botcontext.py delete mode 100644 src/home/telegram/_botdb.py delete mode 100644 src/home/telegram/_botlang.py delete mode 100644 src/home/telegram/_botutil.py delete mode 100644 src/home/telegram/aio.py delete mode 100644 src/home/telegram/bot.py delete mode 100644 src/home/telegram/config.py delete mode 100644 src/home/telegram/telegram.py delete mode 100644 src/home/temphum/__init__.py delete mode 100644 src/home/temphum/base.py delete mode 100644 src/home/temphum/i2c.py delete mode 100644 src/home/util.py (limited to 'src/home') diff --git a/src/home/__init__.py b/src/home/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/home/api/__init__.py b/src/home/api/__init__.py deleted file mode 100644 index d641f62..0000000 --- a/src/home/api/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -import importlib - -__all__ = [ - # web_api_client.py - 'WebApiClient', - 'RequestParams', - - # config.py - 'WebApiConfig' -] - - -def __getattr__(name): - if name in __all__: - file = 'config' if name == 'WebApiConfig' else 'web_api_client' - module = importlib.import_module(f'.{file}', __name__) - return getattr(module, name) - - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/home/api/__init__.pyi b/src/home/api/__init__.pyi deleted file mode 100644 index 5b98161..0000000 --- a/src/home/api/__init__.pyi +++ /dev/null @@ -1,5 +0,0 @@ -from .web_api_client import ( - RequestParams as RequestParams, - WebApiClient as WebApiClient -) -from .config import WebApiConfig as WebApiConfig diff --git a/src/home/api/config.py b/src/home/api/config.py deleted file mode 100644 index 00c1097..0000000 --- a/src/home/api/config.py +++ /dev/null @@ -1,15 +0,0 @@ -from ..config import ConfigUnit -from typing import Optional, Union - - -class WebApiConfig(ConfigUnit): - NAME = 'web_api' - - @classmethod - def schema(cls) -> Optional[dict]: - return { - 'listen_addr': cls._addr_schema(required=True), - 'host': cls._addr_schema(required=True), - 'token': dict(type='string', required=True), - 'recordings_dir': dict(type='string', required=True) - } \ No newline at end of file diff --git a/src/home/api/errors/__init__.py b/src/home/api/errors/__init__.py deleted file mode 100644 index efb06aa..0000000 --- a/src/home/api/errors/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .api_response_error import ApiResponseError diff --git a/src/home/api/errors/api_response_error.py b/src/home/api/errors/api_response_error.py deleted file mode 100644 index 85d788b..0000000 --- a/src/home/api/errors/api_response_error.py +++ /dev/null @@ -1,28 +0,0 @@ -from typing import Optional, List - - -class ApiResponseError(Exception): - def __init__(self, - status_code: int, - error_type: str, - error_message: str, - error_stacktrace: Optional[List[str]] = None): - super().__init__() - self.status_code = status_code - self.error_message = error_message - self.error_type = error_type - self.error_stacktrace = error_stacktrace - - def __str__(self): - def st_formatter(line: str): - return f'Remote| {line}' - - s = f'{self.error_type}: {self.error_message} (HTTP {self.status_code})' - if self.error_stacktrace is not None: - st = [] - for st_line in self.error_stacktrace: - st.append('\n'.join(st_formatter(st_subline) for st_subline in st_line.split('\n'))) - s += '\nRemote stacktrace:\n' - s += '\n'.join(st) - - return s diff --git a/src/home/api/types/__init__.py b/src/home/api/types/__init__.py deleted file mode 100644 index 9f27ff6..0000000 --- a/src/home/api/types/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .types import ( - BotType, - TemperatureSensorDataType, - TemperatureSensorLocation, - SoundSensorLocation -) diff --git a/src/home/api/types/types.py b/src/home/api/types/types.py deleted file mode 100644 index 981e798..0000000 --- a/src/home/api/types/types.py +++ /dev/null @@ -1,33 +0,0 @@ -from enum import Enum, auto - - -class BotType(Enum): - INVERTER = auto() - PUMP = auto() - SENSORS = auto() - ADMIN = auto() - SOUND = auto() - POLARIS_KETTLE = auto() - PUMP_MQTT = auto() - RELAY_MQTT = auto() - - -class TemperatureSensorLocation(Enum): - BIG_HOUSE_1 = auto() - BIG_HOUSE_2 = auto() - BIG_HOUSE_ROOM = auto() - STREET = auto() - DIANA = auto() - SPB1 = auto() - - -class TemperatureSensorDataType(Enum): - TEMPERATURE = auto() - RELATIVE_HUMIDITY = auto() - - -class SoundSensorLocation(Enum): - DIANA = auto() - BIG_HOUSE = auto() - SPB1 = auto() - diff --git a/src/home/api/web_api_client.py b/src/home/api/web_api_client.py deleted file mode 100644 index 15c1915..0000000 --- a/src/home/api/web_api_client.py +++ /dev/null @@ -1,227 +0,0 @@ -import requests -import json -import threading -import logging - -from collections import namedtuple -from datetime import datetime -from enum import Enum, auto -from typing import Optional, Callable, Union, List, Tuple, Dict -from requests.auth import HTTPBasicAuth - -from .config import WebApiConfig -from .errors import ApiResponseError -from .types import * -from ..config import config -from ..util import stringify -from ..media import RecordFile, MediaNodeClient - -_logger = logging.getLogger(__name__) -_config = WebApiConfig() - - -RequestParams = namedtuple('RequestParams', 'params, files, method') - - -class HTTPMethod(Enum): - GET = auto() - POST = auto() - - -class WebApiClient: - token: str - timeout: Union[float, Tuple[float, float]] - basic_auth: Optional[HTTPBasicAuth] - do_async: bool - async_error_handler: Optional[Callable] - async_success_handler: Optional[Callable] - - def __init__(self, timeout: Union[float, Tuple[float, float]] = 5): - self.token = config['token'] - self.timeout = timeout - self.basic_auth = None - self.do_async = False - self.async_error_handler = None - self.async_success_handler = None - - # if 'basic_auth' in config['api']: - # ba = config['api']['basic_auth'] - # col = ba.index(':') - # - # user = ba[:col] - # pw = ba[col+1:] - # - # _logger.debug(f'enabling basic auth: {user}:{pw}') - # self.basic_auth = HTTPBasicAuth(user, pw) - - # api methods - # ----------- - - def log_bot_request(self, - bot: BotType, - user_id: int, - message: str): - return self._post('log/bot_request/', { - 'bot': bot.value, - 'user_id': str(user_id), - 'message': message - }) - - def log_openwrt(self, - lines: List[Tuple[int, str]], - access_point: int): - return self._post('log/openwrt/', { - 'logs': stringify(lines), - 'ap': access_point - }) - - def get_sensors_data(self, - sensor: TemperatureSensorLocation, - hours: int): - data = self._get('sensors/data/', { - 'sensor': sensor.value, - 'hours': hours - }) - return [(datetime.fromtimestamp(date), temp, hum) for date, temp, hum in data] - - def add_sound_sensor_hits(self, - hits: List[Tuple[str, int]]): - return self._post('sound_sensors/hits/', { - 'hits': stringify(hits) - }) - - def get_sound_sensor_hits(self, - location: SoundSensorLocation, - after: datetime) -> List[dict]: - return self._process_sound_sensor_hits_data(self._get('sound_sensors/hits/', { - 'after': int(after.timestamp()), - 'location': location.value - })) - - def get_last_sound_sensor_hits(self, location: SoundSensorLocation, last: int): - return self._process_sound_sensor_hits_data(self._get('sound_sensors/hits/', { - 'last': last, - 'location': location.value - })) - - def recordings_list(self, extended=False, as_objects=False) -> Union[List[str], List[dict], List[RecordFile]]: - files = self._get('recordings/list/', {'extended': int(extended)})['data'] - if as_objects: - return MediaNodeClient.record_list_from_serialized(files) - return files - - def inverter_get_consumed_energy(self, s_from: str, s_to: str): - return self._get('inverter/consumed_energy/', { - 'from': s_from, - 'to': s_to - }) - - def inverter_get_grid_consumed_energy(self, s_from: str, s_to: str): - return self._get('inverter/grid_consumed_energy/', { - 'from': s_from, - 'to': s_to - }) - - @staticmethod - def _process_sound_sensor_hits_data(data: List[dict]) -> List[dict]: - for item in data: - item['time'] = datetime.fromtimestamp(item['time']) - return data - - # internal methods - # ---------------- - - def _get(self, *args, **kwargs): - return self._call(method=HTTPMethod.GET, *args, **kwargs) - - def _post(self, *args, **kwargs): - return self._call(method=HTTPMethod.POST, *args, **kwargs) - - def _call(self, - name: str, - params: dict, - method: HTTPMethod, - files: Optional[Dict[str, str]] = None): - if not self.do_async: - return self._make_request(name, params, method, files) - else: - t = threading.Thread(target=self._make_request_in_thread, args=(name, params, method, files)) - t.start() - return None - - def _make_request(self, - name: str, - params: dict, - method: HTTPMethod = HTTPMethod.GET, - files: Optional[Dict[str, str]] = None) -> Optional[any]: - domain = config['host'] - kwargs = {} - - if self.basic_auth is not None: - kwargs['auth'] = self.basic_auth - - if method == HTTPMethod.GET: - if files: - raise RuntimeError('can\'t upload files using GET, please use me properly') - kwargs['params'] = params - f = requests.get - else: - kwargs['data'] = params - f = requests.post - - fd = {} - if files: - for fname, fpath in files.items(): - fd[fname] = open(fpath, 'rb') - kwargs['files'] = fd - - try: - r = f(f'https://{domain}/{name}', - headers={'X-Token': self.token}, - timeout=self.timeout, - **kwargs) - - if not r.headers['content-type'].startswith('application/json'): - raise ApiResponseError(r.status_code, 'TypeError', 'content-type is not application/json') - - data = json.loads(r.text) - if r.status_code != 200: - raise ApiResponseError(r.status_code, - data['error'], - data['message'], - data['stacktrace'] if 'stacktrace' in data['error'] else None) - - return data['response'] if 'response' in data else True - finally: - for fname, f in fd.items(): - # logger.debug(f'closing file {fname} (fd={f})') - try: - f.close() - except Exception as exc: - _logger.exception(exc) - pass - - def _make_request_in_thread(self, name, params, method, files): - try: - result = self._make_request(name, params, method, files) - self._report_async_success(result, name, RequestParams(params=params, method=method, files=files)) - except Exception as e: - _logger.exception(e) - self._report_async_error(e, name, RequestParams(params=params, method=method, files=files)) - - def enable_async(self, - success_handler: Optional[Callable] = None, - error_handler: Optional[Callable] = None): - self.do_async = True - if error_handler: - self.async_error_handler = error_handler - if success_handler: - self.async_success_handler = success_handler - - def _report_async_error(self, *args): - if self.async_error_handler: - self.async_error_handler(*args) - - def _report_async_success(self, *args): - if self.async_success_handler: - self.async_success_handler(*args) \ No newline at end of file diff --git a/src/home/audio/__init__.py b/src/home/audio/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/home/audio/amixer.py b/src/home/audio/amixer.py deleted file mode 100644 index 5133c97..0000000 --- a/src/home/audio/amixer.py +++ /dev/null @@ -1,91 +0,0 @@ -import subprocess - -from ..config import app_config as config -from threading import Lock -from typing import Union, List - - -_lock = Lock() -_default_step = 5 - - -def has_control(s: str) -> bool: - for control in config['amixer']['controls']: - if control['name'] == s: - return True - return False - - -def get_caps(s: str) -> List[str]: - for control in config['amixer']['controls']: - if control['name'] == s: - return control['caps'] - raise KeyError(f'control {s} not found') - - -def get_all() -> list: - controls = [] - for control in config['amixer']['controls']: - controls.append({ - 'name': control['name'], - 'info': get(control['name']), - 'caps': control['caps'] - }) - return controls - - -def get(control: str): - return call('get', control) - - -def mute(control): - return call('set', control, 'mute') - - -def unmute(control): - return call('set', control, 'unmute') - - -def cap(control): - return call('set', control, 'cap') - - -def nocap(control): - return call('set', control, 'nocap') - - -def _get_default_step() -> int: - if 'step' in config['amixer']: - return int(config['amixer']['step']) - - return _default_step - - -def incr(control, step=None): - if step is None: - step = _get_default_step() - return call('set', control, f'{step}%+') - - -def decr(control, step=None): - if step is None: - step = _get_default_step() - return call('set', control, f'{step}%-') - - -def call(*args, return_code=False) -> Union[int, str]: - with _lock: - result = subprocess.run([config['amixer']['bin'], *args], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - if return_code: - return result.returncode - - if result.returncode != 0: - raise AmixerError(result.stderr.decode().strip()) - - return result.stdout.decode().strip() - - -class AmixerError(OSError): - pass diff --git a/src/home/camera/__init__.py b/src/home/camera/__init__.py deleted file mode 100644 index 626930b..0000000 --- a/src/home/camera/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .types import CameraType \ No newline at end of file diff --git a/src/home/camera/esp32.py b/src/home/camera/esp32.py deleted file mode 100644 index fe6de0e..0000000 --- a/src/home/camera/esp32.py +++ /dev/null @@ -1,226 +0,0 @@ -import logging -import requests -import json -import asyncio -import aioshutil - -from io import BytesIO -from functools import partial -from typing import Union, Optional -from enum import Enum -from ..api.errors import ApiResponseError -from ..util import Addr - - -class FrameSize(Enum): - UXGA_1600x1200 = 13 - SXGA_1280x1024 = 12 - HD_1280x720 = 11 - XGA_1024x768 = 10 - SVGA_800x600 = 9 - VGA_640x480 = 8 - HVGA_480x320 = 7 - CIF_400x296 = 6 - QVGA_320x240 = 5 - N_240x240 = 4 - HQVGA_240x176 = 3 - QCIF_176x144 = 2 - QQVGA_160x120 = 1 - N_96x96 = 0 - - -class WBMode(Enum): - AUTO = 0 - SUNNY = 1 - CLOUDY = 2 - OFFICE = 3 - HOME = 4 - - -def _assert_bounds(n: int, min: int, max: int): - if not min <= n <= max: - raise ValueError(f'value must be between {min} and {max}') - - -class WebClient: - def __init__(self, - addr: Addr): - self.endpoint = f'http://{addr[0]}:{addr[1]}' - self.logger = logging.getLogger(self.__class__.__name__) - self.delay = 0 - self.isfirstrequest = True - - async def syncsettings(self, settings) -> bool: - status = await self.getstatus() - self.logger.debug(f'syncsettings: status={status}') - - changed_anything = False - - for name, value in settings.items(): - server_name = name - if name == 'aec_dsp': - server_name = 'aec2' - - if server_name not in status: - # legacy compatibility - if server_name != 'vflip': - self.logger.warning(f'syncsettings: field `{server_name}` not found in camera status') - continue - - try: - # server returns 0 or 1 for bool values - if type(value) is bool: - value = int(value) - - if status[server_name] == value: - continue - except KeyError as exc: - if name != 'vflip': - self.logger.error(exc) - - try: - # fix for cases like when field is called raw_gma, but method is setrawgma() - name = name.replace('_', '') - - func = getattr(self, f'set{name}') - self.logger.debug(f'syncsettings: calling set{name}({value})') - - await func(value) - - changed_anything = True - except AttributeError as exc: - self.logger.exception(exc) - self.logger.error(f'syncsettings: method set{name}() not found') - - return changed_anything - - def setdelay(self, delay: int): - self.delay = delay - - async def capture(self, output: Optional[str] = None) -> Union[BytesIO, bool]: - kw = {} - if output: - kw['save_to'] = output - else: - kw['as_bytes'] = True - return await self._call('capture', **kw) - - async def getstatus(self): - return json.loads(await self._call('status')) - - async def setflash(self, enable: bool): - await self._control('flash', int(enable)) - - async def setframesize(self, fs: Union[int, FrameSize]): - if type(fs) is int: - fs = FrameSize(fs) - await self._control('framesize', fs.value) - - async def sethmirror(self, enable: bool): - await self._control('hmirror', int(enable)) - - async def setvflip(self, enable: bool): - await self._control('vflip', int(enable)) - - async def setawb(self, enable: bool): - await self._control('awb', int(enable)) - - async def setawbgain(self, enable: bool): - await self._control('awb_gain', int(enable)) - - async def setwbmode(self, mode: WBMode): - await self._control('wb_mode', mode.value) - - async def setaecsensor(self, enable: bool): - await self._control('aec', int(enable)) - - async def setaecdsp(self, enable: bool): - await self._control('aec2', int(enable)) - - async def setagc(self, enable: bool): - await self._control('agc', int(enable)) - - async def setagcgain(self, gain: int): - _assert_bounds(gain, 1, 31) - await self._control('agc_gain', gain) - - async def setgainceiling(self, gainceiling: int): - _assert_bounds(gainceiling, 2, 128) - await self._control('gainceiling', gainceiling) - - async def setbpc(self, enable: bool): - await self._control('bpc', int(enable)) - - async def setwpc(self, enable: bool): - await self._control('wpc', int(enable)) - - async def setrawgma(self, enable: bool): - await self._control('raw_gma', int(enable)) - - async def setlenscorrection(self, enable: bool): - await self._control('lenc', int(enable)) - - async def setdcw(self, enable: bool): - await self._control('dcw', int(enable)) - - async def setcolorbar(self, enable: bool): - await self._control('colorbar', int(enable)) - - async def setquality(self, q: int): - _assert_bounds(q, 4, 63) - await self._control('quality', q) - - async def setbrightness(self, brightness: int): - _assert_bounds(brightness, -2, -2) - await self._control('brightness', brightness) - - async def setcontrast(self, contrast: int): - _assert_bounds(contrast, -2, 2) - await self._control('contrast', contrast) - - async def setsaturation(self, saturation: int): - _assert_bounds(saturation, -2, 2) - await self._control('saturation', saturation) - - async def _control(self, var: str, value: Union[int, str]): - return await self._call('control', params={'var': var, 'val': value}) - - async def _call(self, - method: str, - params: Optional[dict] = None, - save_to: Optional[str] = None, - as_bytes=False) -> Union[str, bool, BytesIO]: - loop = asyncio.get_event_loop() - - if not self.isfirstrequest and self.delay > 0: - sleeptime = self.delay / 1000 - self.logger.debug(f'sleeping for {sleeptime}') - - await asyncio.sleep(sleeptime) - - self.isfirstrequest = False - - url = f'{self.endpoint}/{method}' - self.logger.debug(f'calling {url}, params: {params}') - - kwargs = {} - if params: - kwargs['params'] = params - if save_to: - kwargs['stream'] = True - - r = await loop.run_in_executor(None, - partial(requests.get, url, **kwargs)) - if r.status_code != 200: - raise ApiResponseError(status_code=r.status_code) - - if as_bytes: - return BytesIO(r.content) - - if save_to: - r.raise_for_status() - with open(save_to, 'wb') as f: - await aioshutil.copyfileobj(r.raw, f) - return True - - return r.text diff --git a/src/home/camera/types.py b/src/home/camera/types.py deleted file mode 100644 index de59022..0000000 --- a/src/home/camera/types.py +++ /dev/null @@ -1,5 +0,0 @@ -from enum import Enum - - -class CameraType(Enum): - ESP32 = 'esp32' diff --git a/src/home/camera/util.py b/src/home/camera/util.py deleted file mode 100644 index 97f35aa..0000000 --- a/src/home/camera/util.py +++ /dev/null @@ -1,107 +0,0 @@ -import asyncio -import os.path -import logging -import psutil - -from typing import List, Tuple -from ..util import chunks -from ..config import config - -_logger = logging.getLogger(__name__) -_temporary_fixing = '.temporary_fixing.mp4' - - -def _get_ffmpeg_path() -> str: - return 'ffmpeg' if 'ffmpeg' not in config else config['ffmpeg']['path'] - - -def time2seconds(time: str) -> int: - time, frac = time.split('.') - frac = int(frac) - - h, m, s = [int(i) for i in time.split(':')] - - return round(s + m*60 + h*3600 + frac/1000) - - -async def ffmpeg_recreate(filename: str): - filedir = os.path.dirname(filename) - tempname = os.path.join(filedir, _temporary_fixing) - mtime = os.path.getmtime(filename) - - args = [_get_ffmpeg_path(), '-nostats', '-loglevel', 'error', '-i', filename, '-c', 'copy', '-y', tempname] - proc = await asyncio.create_subprocess_exec(*args, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE) - stdout, stderr = await proc.communicate() - if proc.returncode != 0: - _logger.error(f'fix_timestamps({filename}): ffmpeg returned {proc.returncode}, stderr: {stderr.decode().strip()}') - - if os.path.isfile(tempname): - os.unlink(filename) - os.rename(tempname, filename) - os.utime(filename, (mtime, mtime)) - _logger.info(f'fix_timestamps({filename}): OK') - else: - _logger.error(f'fix_timestamps({filename}): temp file \'{tempname}\' does not exists, fix failed') - - -async def ffmpeg_cut(input: str, - output: str, - start_pos: int, - duration: int): - args = [_get_ffmpeg_path(), '-nostats', '-loglevel', 'error', '-i', input, - '-ss', str(start_pos), '-t', str(duration), - '-c', 'copy', '-y', output] - proc = await asyncio.create_subprocess_exec(*args, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE) - stdout, stderr = await proc.communicate() - if proc.returncode != 0: - _logger.error(f'ffmpeg_cut({input}, start_pos={start_pos}, duration={duration}): ffmpeg returned {proc.returncode}, stderr: {stderr.decode().strip()}') - else: - _logger.info(f'ffmpeg_cut({input}): OK') - - -def dvr_scan_timecodes(timecodes: str) -> List[Tuple[int, int]]: - tc_backup = timecodes - - timecodes = timecodes.split(',') - if len(timecodes) % 2 != 0: - raise DVRScanInvalidTimecodes(f'invalid number of timecodes. input: {tc_backup}') - - timecodes = list(map(time2seconds, timecodes)) - timecodes = list(chunks(timecodes, 2)) - - # sort out invalid fragments (dvr-scan returns them sometimes, idk why...) - timecodes = list(filter(lambda f: f[0] < f[1], timecodes)) - if not timecodes: - raise DVRScanInvalidTimecodes(f'no valid timecodes. input: {tc_backup}') - - # https://stackoverflow.com/a/43600953 - timecodes.sort(key=lambda interval: interval[0]) - merged = [timecodes[0]] - for current in timecodes: - previous = merged[-1] - if current[0] <= previous[1]: - previous[1] = max(previous[1], current[1]) - else: - merged.append(current) - - return merged - - -class DVRScanInvalidTimecodes(Exception): - pass - - -def has_handle(fpath): - for proc in psutil.process_iter(): - try: - for item in proc.open_files(): - if fpath == item.path: - return True - except Exception: - pass - - return False \ No newline at end of file diff --git a/src/home/config/__init__.py b/src/home/config/__init__.py deleted file mode 100644 index 2fa5214..0000000 --- a/src/home/config/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from .config import ( - Config, - ConfigUnit, - AppConfigUnit, - Translation, - config, - is_development_mode, - setup_logging -) -from ._configs import ( - LinuxBoardsConfig, - ServicesListConfig -) \ No newline at end of file diff --git a/src/home/config/_configs.py b/src/home/config/_configs.py deleted file mode 100644 index 1628cba..0000000 --- a/src/home/config/_configs.py +++ /dev/null @@ -1,55 +0,0 @@ -from .config import ConfigUnit -from typing import Optional - - -class ServicesListConfig(ConfigUnit): - NAME = 'services_list' - - @classmethod - def schema(cls) -> Optional[dict]: - return { - 'type': 'list', - 'empty': False, - 'schema': { - 'type': 'string' - } - } - - -class LinuxBoardsConfig(ConfigUnit): - NAME = 'linux_boards' - - @classmethod - def schema(cls) -> Optional[dict]: - return { - 'type': 'dict', - 'schema': { - 'mdns': {'type': 'string', 'required': True}, - 'board': {'type': 'string', 'required': True}, - 'network': { - 'type': 'list', - 'required': True, - 'empty': False, - 'allowed': ['wifi', 'ethernet'] - }, - 'ram': {'type': 'integer', 'required': True}, - 'online': {'type': 'boolean', 'required': True}, - - # optional - 'services': { - 'type': 'list', - 'empty': False, - 'allowed': ServicesListConfig().get() - }, - 'ext_hdd': { - 'type': 'list', - 'schema': { - 'type': 'dict', - 'schema': { - 'mountpoint': {'type': 'string', 'required': True}, - 'size': {'type': 'integer', 'required': True} - } - }, - }, - } - } diff --git a/src/home/config/config.py b/src/home/config/config.py deleted file mode 100644 index 7344386..0000000 --- a/src/home/config/config.py +++ /dev/null @@ -1,387 +0,0 @@ -import yaml -import logging -import os -import cerberus -import cerberus.errors - -from abc import ABC -from typing import Optional, Any, MutableMapping, Union -from argparse import ArgumentParser -from enum import Enum, auto -from os.path import join, isdir, isfile -from ..util import Addr - - -class MyValidator(cerberus.Validator): - def _normalize_coerce_addr(self, value): - return Addr.fromstring(value) - - -MyValidator.types_mapping['addr'] = cerberus.TypeDefinition('Addr', (Addr,), ()) - - -CONFIG_DIRECTORIES = ( - join(os.environ['HOME'], '.config', 'homekit'), - '/etc/homekit' -) - - -class RootSchemaType(Enum): - DEFAULT = auto() - DICT = auto() - LIST = auto() - - -class BaseConfigUnit(ABC): - _data: MutableMapping[str, Any] - _logger: logging.Logger - - def __init__(self): - self._data = {} - self._logger = logging.getLogger(self.__class__.__name__) - - def __getitem__(self, key): - return self._data[key] - - def __setitem__(self, key, value): - raise NotImplementedError('overwriting config values is prohibited') - - def __contains__(self, key): - return key in self._data - - def load_from(self, path: str): - with open(path, 'r') as fd: - self._data = yaml.safe_load(fd) - - def get(self, - key: Optional[str] = None, - default=None): - if key is None: - return self._data - - cur = self._data - pts = key.split('.') - for i in range(len(pts)): - k = pts[i] - if i < len(pts)-1: - if k not in cur: - raise KeyError(f'key {k} not found') - else: - return cur[k] if k in cur else default - cur = self._data[k] - - raise KeyError(f'option {key} not found') - - -class ConfigUnit(BaseConfigUnit): - NAME = 'dumb' - - def __init__(self, name=None, load=True): - super().__init__() - - self._data = {} - self._logger = logging.getLogger(self.__class__.__name__) - - if self.NAME != 'dumb' and load: - self.load_from(self.get_config_path()) - self.validate() - - elif name is not None: - self.NAME = name - - @classmethod - def get_config_path(cls, name=None) -> str: - if name is None: - name = cls.NAME - if name is None: - raise ValueError('get_config_path: name is none') - - for dirname in CONFIG_DIRECTORIES: - if isdir(dirname): - filename = join(dirname, f'{name}.yaml') - if isfile(filename): - return filename - - raise IOError(f'\'{name}.yaml\' not found') - - @classmethod - def schema(cls) -> Optional[dict]: - return None - - @classmethod - def _addr_schema(cls, required=False, **kwargs): - return { - 'type': 'addr', - 'coerce': Addr.fromstring, - 'required': required, - **kwargs - } - - def validate(self): - schema = self.schema() - if not schema: - self._logger.warning('validate: no schema') - return - - if isinstance(self, AppConfigUnit): - schema['logging'] = { - 'type': 'dict', - 'schema': { - 'logging': {'type': 'boolean'} - } - } - - rst = RootSchemaType.DEFAULT - try: - if schema['type'] == 'dict': - rst = RootSchemaType.DICT - elif schema['type'] == 'list': - rst = RootSchemaType.LIST - elif schema['roottype'] == 'dict': - del schema['roottype'] - rst = RootSchemaType.DICT - except KeyError: - pass - - v = MyValidator() - - if rst == RootSchemaType.DICT: - normalized = v.validated({'document': self._data}, - {'document': { - 'type': 'dict', - 'keysrules': {'type': 'string'}, - 'valuesrules': schema - }})['document'] - elif rst == RootSchemaType.LIST: - v = MyValidator() - normalized = v.validated({'document': self._data}, {'document': schema})['document'] - else: - normalized = v.validated(self._data, schema) - - self._data = normalized - - try: - self.custom_validator(self._data) - except Exception as e: - raise cerberus.DocumentError(f'{self.__class__.__name__}: {str(e)}') - - @staticmethod - def custom_validator(data): - pass - - def get_addr(self, key: str): - return Addr.fromstring(self.get(key)) - - -class AppConfigUnit(ConfigUnit): - _logging_verbose: bool - _logging_fmt: Optional[str] - _logging_file: Optional[str] - - def __init__(self, *args, **kwargs): - super().__init__(load=False, *args, **kwargs) - self._logging_verbose = False - self._logging_fmt = None - self._logging_file = None - - def logging_set_fmt(self, fmt: str) -> None: - self._logging_fmt = fmt - - def logging_get_fmt(self) -> Optional[str]: - try: - return self['logging']['default_fmt'] - except KeyError: - return self._logging_fmt - - def logging_set_file(self, file: str) -> None: - self._logging_file = file - - def logging_get_file(self) -> Optional[str]: - try: - return self['logging']['file'] - except KeyError: - return self._logging_file - - def logging_set_verbose(self): - self._logging_verbose = True - - def logging_is_verbose(self) -> bool: - try: - return bool(self['logging']['verbose']) - except KeyError: - return self._logging_verbose - - -class TranslationUnit(BaseConfigUnit): - pass - - -class Translation: - LANGUAGES = ('en', 'ru') - _langs: dict[str, TranslationUnit] - - def __init__(self, name: str): - super().__init__() - self._langs = {} - for lang in self.LANGUAGES: - for dirname in CONFIG_DIRECTORIES: - if isdir(dirname): - filename = join(dirname, f'i18n-{lang}', f'{name}.yaml') - if lang in self._langs: - raise RuntimeError(f'{name}: translation unit for lang \'{lang}\' already loaded') - self._langs[lang] = TranslationUnit() - self._langs[lang].load_from(filename) - diff = set() - for data in self._langs.values(): - diff ^= data.get().keys() - if len(diff) > 0: - raise RuntimeError(f'{name}: translation units have difference in keys: ' + ', '.join(diff)) - - def get(self, lang: str) -> TranslationUnit: - return self._langs[lang] - - -class Config: - app_name: Optional[str] - app_config: AppConfigUnit - - def __init__(self): - self.app_name = None - self.app_config = AppConfigUnit() - - def load_app(self, - name: Optional[Union[str, AppConfigUnit, bool]] = None, - use_cli=True, - parser: ArgumentParser = None, - no_config=False): - global app_config - - if not no_config \ - and not isinstance(name, str) \ - and not isinstance(name, bool) \ - and issubclass(name, AppConfigUnit) or name == AppConfigUnit: - self.app_name = name.NAME - self.app_config = name() - app_config = self.app_config - else: - self.app_name = name if isinstance(name, str) else None - - if self.app_name is None and not use_cli: - raise RuntimeError('either config name must be none or use_cli must be True') - - no_config = name is False or no_config - path = None - - if use_cli: - if parser is None: - parser = ArgumentParser() - if not no_config: - parser.add_argument('-c', '--config', type=str, required=name is None, - help='Path to the config in TOML or YAML format') - parser.add_argument('-V', '--verbose', action='store_true') - parser.add_argument('--log-file', type=str) - parser.add_argument('--log-default-fmt', action='store_true') - args = parser.parse_args() - - if not no_config and args.config: - path = args.config - - if args.verbose: - self.app_config.logging_set_verbose() - if args.log_file: - self.app_config.logging_set_file(args.log_file) - if args.log_default_fmt: - self.app_config.logging_set_fmt(args.log_default_fmt) - - if not isinstance(name, ConfigUnit): - if not no_config and path is None: - path = ConfigUnit.get_config_path(name=self.app_name) - - if not no_config: - self.app_config.load_from(path) - self.app_config.validate() - - setup_logging(self.app_config.logging_is_verbose(), - self.app_config.logging_get_file(), - self.app_config.logging_get_fmt()) - - if use_cli: - return args - - -config = Config() - - -def is_development_mode() -> bool: - if 'HK_MODE' in os.environ and os.environ['HK_MODE'] == 'dev': - return True - - return ('logging' in config.app_config) and ('verbose' in config.app_config['logging']) and (config.app_config['logging']['verbose'] is True) - - -def setup_logging(verbose=False, log_file=None, default_fmt=None): - logging_level = logging.INFO - if is_development_mode() or verbose: - logging_level = logging.DEBUG - _add_logging_level('TRACE', logging.DEBUG-5) - - log_config = {'level': logging_level} - if not default_fmt: - log_config['format'] = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' - - if log_file is not None: - log_config['filename'] = log_file - log_config['encoding'] = 'utf-8' - - logging.basicConfig(**log_config) - - -# https://stackoverflow.com/questions/2183233/how-to-add-a-custom-loglevel-to-pythons-logging-facility/35804945#35804945 -def _add_logging_level(levelName, levelNum, methodName=None): - """ - Comprehensively adds a new logging level to the `logging` module and the - currently configured logging class. - - `levelName` becomes an attribute of the `logging` module with the value - `levelNum`. `methodName` becomes a convenience method for both `logging` - itself and the class returned by `logging.getLoggerClass()` (usually just - `logging.Logger`). If `methodName` is not specified, `levelName.lower()` is - used. - - To avoid accidental clobberings of existing attributes, this method will - raise an `AttributeError` if the level name is already an attribute of the - `logging` module or if the method name is already present - - Example - ------- - >>> addLoggingLevel('TRACE', logging.DEBUG - 5) - >>> logging.getLogger(__name__).setLevel("TRACE") - >>> logging.getLogger(__name__).trace('that worked') - >>> logging.trace('so did this') - >>> logging.TRACE - 5 - - """ - if not methodName: - methodName = levelName.lower() - - if hasattr(logging, levelName): - raise AttributeError('{} already defined in logging module'.format(levelName)) - if hasattr(logging, methodName): - raise AttributeError('{} already defined in logging module'.format(methodName)) - if hasattr(logging.getLoggerClass(), methodName): - raise AttributeError('{} already defined in logger class'.format(methodName)) - - # This method was inspired by the answers to Stack Overflow post - # http://stackoverflow.com/q/2183233/2988730, especially - # http://stackoverflow.com/a/13638084/2988730 - def logForLevel(self, message, *args, **kwargs): - if self.isEnabledFor(levelNum): - self._log(levelNum, message, args, **kwargs) - def logToRoot(message, *args, **kwargs): - logging.log(levelNum, message, *args, **kwargs) - - logging.addLevelName(levelNum, levelName) - setattr(logging, levelName, levelNum) - setattr(logging.getLoggerClass(), methodName, logForLevel) - setattr(logging, methodName, logToRoot) \ No newline at end of file diff --git a/src/home/database/__init__.py b/src/home/database/__init__.py deleted file mode 100644 index b50cbce..0000000 --- a/src/home/database/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -import importlib - -__all__ = [ - 'get_mysql', - 'mysql_now', - 'get_clickhouse', - 'SimpleState', - - 'SensorsDatabase', - 'InverterDatabase', - 'BotsDatabase' -] - - -def __getattr__(name: str): - if name in __all__: - if name.endswith('Database'): - file = name[:-8].lower() - elif 'mysql' in name: - file = 'mysql' - elif 'clickhouse' in name: - file = 'clickhouse' - else: - file = 'simple_state' - - module = importlib.import_module(f'.{file}', __name__) - return getattr(module, name) - - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/home/database/__init__.pyi b/src/home/database/__init__.pyi deleted file mode 100644 index 31aae5d..0000000 --- a/src/home/database/__init__.pyi +++ /dev/null @@ -1,11 +0,0 @@ -from .mysql import ( - get_mysql as get_mysql, - mysql_now as mysql_now -) -from .clickhouse import get_clickhouse as get_clickhouse - -from simple_state import SimpleState as SimpleState - -from .sensors import SensorsDatabase as SensorsDatabase -from .inverter import InverterDatabase as InverterDatabase -from .bots import BotsDatabase as BotsDatabase diff --git a/src/home/database/_base.py b/src/home/database/_base.py deleted file mode 100644 index c01e62b..0000000 --- a/src/home/database/_base.py +++ /dev/null @@ -1,9 +0,0 @@ -import os - - -def get_data_root_directory(name: str) -> str: - return os.path.join( - os.environ['HOME'], - '.config', - 'homekit', - 'data') \ No newline at end of file diff --git a/src/home/database/bots.py b/src/home/database/bots.py deleted file mode 100644 index cde48b9..0000000 --- a/src/home/database/bots.py +++ /dev/null @@ -1,106 +0,0 @@ -import pytz - -from .mysql import mysql_now, MySQLDatabase, datetime_fmt -from ..api.types import ( - BotType, - SoundSensorLocation -) -from typing import Optional, List, Tuple -from datetime import datetime -from html import escape - - -class OpenwrtLogRecord: - id: int - log_time: datetime - received_time: datetime - text: str - - def __init__(self, id, text, log_time, received_time): - self.id = id - self.text = text - self.log_time = log_time - self.received_time = received_time - - def __repr__(self): - return f"{self.log_time.strftime('%H:%M:%S')} {escape(self.text)}" - - -class BotsDatabase(MySQLDatabase): - def add_request(self, - bot: BotType, - user_id: int, - message: str): - with self.cursor() as cursor: - cursor.execute("INSERT INTO requests_log (user_id, message, bot, time) VALUES (%s, %s, %s, %s)", - (user_id, message, bot.name.lower(), mysql_now())) - self.commit() - - def add_openwrt_logs(self, - lines: List[Tuple[datetime, str]], - access_point: int): - now = datetime.now() - with self.cursor() as cursor: - for line in lines: - time, text = line - cursor.execute("INSERT INTO openwrt (log_time, received_time, text, ap) VALUES (%s, %s, %s, %s)", - (time.strftime(datetime_fmt), now.strftime(datetime_fmt), text, access_point)) - self.commit() - - def add_sound_hits(self, - hits: List[Tuple[SoundSensorLocation, int]], - time: datetime): - with self.cursor() as cursor: - for loc, count in hits: - cursor.execute("INSERT INTO sound_hits (location, `time`, hits) VALUES (%s, %s, %s)", - (loc.name.lower(), time.strftime(datetime_fmt), count)) - self.commit() - - def get_sound_hits(self, - location: SoundSensorLocation, - after: Optional[datetime] = None, - last: Optional[int] = None) -> List[dict]: - with self.cursor(dictionary=True) as cursor: - sql = "SELECT `time`, hits FROM sound_hits WHERE location=%s" - args = [location.name.lower()] - - if after: - sql += ' AND `time` >= %s ORDER BY time DESC' - args.append(after) - elif last: - sql += ' ORDER BY time DESC LIMIT 0, %s' - args.append(last) - else: - raise ValueError('no `after`, no `last`, what do you expect?') - - cursor.execute(sql, tuple(args)) - data = [] - for row in cursor.fetchall(): - data.append({ - 'time': row['time'], - 'hits': row['hits'] - }) - return data - - def get_openwrt_logs(self, - filter_text: str, - min_id: int, - access_point: int, - limit: int = None) -> List[OpenwrtLogRecord]: - tz = pytz.timezone('Europe/Moscow') - with self.cursor(dictionary=True) as cursor: - sql = "SELECT * FROM openwrt WHERE ap=%s AND text LIKE %s AND id > %s" - if limit is not None: - sql += f" LIMIT {limit}" - - cursor.execute(sql, (access_point, f'%{filter_text}%', min_id)) - data = [] - for row in cursor.fetchall(): - data.append(OpenwrtLogRecord( - id=int(row['id']), - text=row['text'], - log_time=row['log_time'].astimezone(tz), - received_time=row['received_time'].astimezone(tz) - )) - - return data diff --git a/src/home/database/clickhouse.py b/src/home/database/clickhouse.py deleted file mode 100644 index d0ec283..0000000 --- a/src/home/database/clickhouse.py +++ /dev/null @@ -1,39 +0,0 @@ -import logging - -from zoneinfo import ZoneInfo -from datetime import datetime -from clickhouse_driver import Client as ClickhouseClient -from ..config import is_development_mode - -_links = {} - - -def get_clickhouse(db: str) -> ClickhouseClient: - if db not in _links: - _links[db] = ClickhouseClient.from_url(f'clickhouse://localhost/{db}') - - return _links[db] - - -class ClickhouseDatabase: - def __init__(self, db: str): - self.db = get_clickhouse(db) - - self.server_timezone = self.db.execute('SELECT timezone()')[0][0] - self.logger = logging.getLogger(self.__class__.__name__) - - def query(self, *args, **kwargs): - settings = {'use_client_time_zone': True} - kwargs['settings'] = settings - - if 'no_tz_fix' not in kwargs and len(args) > 1 and isinstance(args[1], dict): - for k, v in args[1].items(): - if isinstance(v, datetime): - args[1][k] = v.astimezone(tz=ZoneInfo(self.server_timezone)) - - result = self.db.execute(*args, **kwargs) - - if is_development_mode(): - self.logger.debug(args[0] if len(args) == 1 else args[0] % args[1]) - - return result diff --git a/src/home/database/inverter.py b/src/home/database/inverter.py deleted file mode 100644 index fc3f74f..0000000 --- a/src/home/database/inverter.py +++ /dev/null @@ -1,212 +0,0 @@ -from time import time -from datetime import datetime, timedelta -from typing import Optional -from collections import namedtuple - -from .clickhouse import ClickhouseDatabase - - -IntervalList = list[list[Optional[datetime]]] - - -class InverterDatabase(ClickhouseDatabase): - def __init__(self): - super().__init__('solarmon') - - def add_generation(self, home_id: int, client_time: int, watts: int) -> None: - self.db.execute( - 'INSERT INTO generation (ClientTime, ReceivedTime, HomeID, Watts) VALUES', - [[client_time, round(time()), home_id, watts]] - ) - - def add_status(self, home_id: int, - client_time: int, - grid_voltage: int, - grid_freq: int, - ac_output_voltage: int, - ac_output_freq: int, - ac_output_apparent_power: int, - ac_output_active_power: int, - output_load_percent: int, - battery_voltage: int, - battery_voltage_scc: int, - battery_voltage_scc2: int, - battery_discharge_current: int, - battery_charge_current: int, - battery_capacity: int, - inverter_heat_sink_temp: int, - mppt1_charger_temp: int, - mppt2_charger_temp: int, - pv1_input_power: int, - pv2_input_power: int, - pv1_input_voltage: int, - pv2_input_voltage: int, - mppt1_charger_status: int, - mppt2_charger_status: int, - battery_power_direction: int, - dc_ac_power_direction: int, - line_power_direction: int, - load_connected: int) -> None: - self.db.execute("""INSERT INTO status ( - ClientTime, - ReceivedTime, - HomeID, - GridVoltage, - GridFrequency, - ACOutputVoltage, - ACOutputFrequency, - ACOutputApparentPower, - ACOutputActivePower, - OutputLoadPercent, - BatteryVoltage, - BatteryVoltageSCC, - BatteryVoltageSCC2, - BatteryDischargingCurrent, - BatteryChargingCurrent, - BatteryCapacity, - HeatSinkTemp, - MPPT1ChargerTemp, - MPPT2ChargerTemp, - PV1InputPower, - PV2InputPower, - PV1InputVoltage, - PV2InputVoltage, - MPPT1ChargerStatus, - MPPT2ChargerStatus, - BatteryPowerDirection, - DCACPowerDirection, - LinePowerDirection, - LoadConnected) VALUES""", [[ - client_time, - round(time()), - home_id, - grid_voltage, - grid_freq, - ac_output_voltage, - ac_output_freq, - ac_output_apparent_power, - ac_output_active_power, - output_load_percent, - battery_voltage, - battery_voltage_scc, - battery_voltage_scc2, - battery_discharge_current, - battery_charge_current, - battery_capacity, - inverter_heat_sink_temp, - mppt1_charger_temp, - mppt2_charger_temp, - pv1_input_power, - pv2_input_power, - pv1_input_voltage, - pv2_input_voltage, - mppt1_charger_status, - mppt2_charger_status, - battery_power_direction, - dc_ac_power_direction, - line_power_direction, - load_connected - ]]) - - def get_consumed_energy(self, dt_from: datetime, dt_to: datetime) -> float: - rows = self.query('SELECT ClientTime, ACOutputActivePower FROM status' - ' WHERE ClientTime >= %(from)s AND ClientTime <= %(to)s' - ' ORDER BY ClientTime', {'from': dt_from, 'to': dt_to}) - prev_time = None - prev_wh = 0 - - ws = 0 # watt-seconds - for t, wh in rows: - if prev_time is not None: - n = (t - prev_time).total_seconds() - ws += prev_wh * n - - prev_time = t - prev_wh = wh - - return ws / 3600 # convert to watt-hours - - def get_intervals_by_condition(self, - dt_from: datetime, - dt_to: datetime, - cond_start: str, - cond_end: str) -> IntervalList: - rows = None - ranges = [[None, None]] - - while rows is None or len(rows) > 0: - if ranges[len(ranges)-1][0] is None: - condition = cond_start - range_idx = 0 - else: - condition = cond_end - range_idx = 1 - - rows = self.query('SELECT ClientTime FROM status ' - f'WHERE ClientTime > %(from)s AND ClientTime <= %(to)s AND {condition}' - ' ORDER BY ClientTime LIMIT 1', - {'from': dt_from, 'to': dt_to}) - if not rows: - break - - row = rows[0] - - ranges[len(ranges) - 1][range_idx] = row[0] - if range_idx == 1: - ranges.append([None, None]) - - dt_from = row[0] - - if ranges[len(ranges)-1][0] is None: - ranges.pop() - elif ranges[len(ranges)-1][1] is None: - ranges[len(ranges)-1][1] = dt_to - timedelta(seconds=1) - - return ranges - - def get_grid_connected_intervals(self, dt_from: datetime, dt_to: datetime) -> IntervalList: - return self.get_intervals_by_condition(dt_from, dt_to, 'GridFrequency > 0', 'GridFrequency = 0') - - def get_grid_used_intervals(self, dt_from: datetime, dt_to: datetime) -> IntervalList: - return self.get_intervals_by_condition(dt_from, - dt_to, - "LinePowerDirection = 'Input'", - "LinePowerDirection != 'Input'") - - def get_grid_consumed_energy(self, dt_from: datetime, dt_to: datetime) -> float: - PrevData = namedtuple('PrevData', 'time, pd, bat_chg, bat_dis, wh') - - ws = 0 # watt-seconds - amps = 0 # amper-seconds - - intervals = self.get_grid_used_intervals(dt_from, dt_to) - for dt_start, dt_end in intervals: - fields = ', '.join([ - 'ClientTime', - 'DCACPowerDirection', - 'BatteryChargingCurrent', - 'BatteryDischargingCurrent', - 'ACOutputActivePower' - ]) - rows = self.query(f'SELECT {fields} FROM status' - ' WHERE ClientTime >= %(from)s AND ClientTime < %(to)s ORDER BY ClientTime', - {'from': dt_start, 'to': dt_end}) - - prev = PrevData(time=None, pd=None, bat_chg=None, bat_dis=None, wh=None) - for ct, pd, bat_chg, bat_dis, wh in rows: - if prev.time is not None: - n = (ct-prev.time).total_seconds() - ws += prev.wh * n - - if pd == 'DC/AC': - amps -= prev.bat_dis * n - elif pd == 'AC/DC': - amps += prev.bat_chg * n - - prev = PrevData(time=ct, pd=pd, bat_chg=bat_chg, bat_dis=bat_dis, wh=wh) - - amps /= 3600 - wh = ws / 3600 - wh += amps*48 - - return wh diff --git a/src/home/database/inverter_time_formats.py b/src/home/database/inverter_time_formats.py deleted file mode 100644 index 7c37d30..0000000 --- a/src/home/database/inverter_time_formats.py +++ /dev/null @@ -1,2 +0,0 @@ -FormatTime = '%Y-%m-%d %H:%M:%S' -FormatDate = '%Y-%m-%d' diff --git a/src/home/database/mysql.py b/src/home/database/mysql.py deleted file mode 100644 index fe97cd4..0000000 --- a/src/home/database/mysql.py +++ /dev/null @@ -1,47 +0,0 @@ -import time -import logging - -from mysql.connector import connect, MySQLConnection, Error -from typing import Optional -from ..config import config - -link: Optional[MySQLConnection] = None -logger = logging.getLogger(__name__) - -datetime_fmt = '%Y-%m-%d %H:%M:%S' - - -def get_mysql() -> MySQLConnection: - global link - - if link is not None: - return link - - link = connect( - host=config['mysql']['host'], - user=config['mysql']['user'], - password=config['mysql']['password'], - database=config['mysql']['database'], - ) - link.time_zone = '+01:00' - return link - - -def mysql_now() -> str: - return time.strftime('%Y-%m-%d %H:%M:%S') - - -class MySQLDatabase: - def __init__(self): - self.db = get_mysql() - - def cursor(self, **kwargs): - try: - self.db.ping(reconnect=True, attempts=2) - except Error as e: - logger.exception(e) - self.db = get_mysql() - return self.db.cursor(**kwargs) - - def commit(self): - self.db.commit() diff --git a/src/home/database/sensors.py b/src/home/database/sensors.py deleted file mode 100644 index 8155108..0000000 --- a/src/home/database/sensors.py +++ /dev/null @@ -1,69 +0,0 @@ -from time import time -from datetime import datetime -from typing import Tuple, List -from .clickhouse import ClickhouseDatabase -from ..api.types import TemperatureSensorLocation - - -def get_temperature_table(sensor: TemperatureSensorLocation) -> str: - if sensor == TemperatureSensorLocation.DIANA: - return 'temp_diana' - - elif sensor == TemperatureSensorLocation.STREET: - return 'temp_street' - - elif sensor == TemperatureSensorLocation.BIG_HOUSE_1: - return 'temp' - - elif sensor == TemperatureSensorLocation.BIG_HOUSE_2: - return 'temp_roof' - - elif sensor == TemperatureSensorLocation.BIG_HOUSE_ROOM: - return 'temp_room' - - elif sensor == TemperatureSensorLocation.SPB1: - return 'temp_spb1' - - -class SensorsDatabase(ClickhouseDatabase): - def __init__(self): - super().__init__('home') - - def add_temperature(self, - home_id: int, - client_time: int, - sensor: TemperatureSensorLocation, - temp: int, - rh: int): - table = get_temperature_table(sensor) - sql = """INSERT INTO """ + table + """ ( - ClientTime, - ReceivedTime, - HomeID, - Temperature, - RelativeHumidity - ) VALUES""" - self.db.execute(sql, [[ - client_time, - int(time()), - home_id, - temp, - rh - ]]) - - def get_temperature_recordings(self, - sensor: TemperatureSensorLocation, - time_range: Tuple[datetime, datetime], - home_id=1) -> List[tuple]: - table = get_temperature_table(sensor) - sql = f"""SELECT ClientTime, Temperature, RelativeHumidity - FROM {table} - WHERE ClientTime >= %(from)s AND ClientTime <= %(to)s - ORDER BY ClientTime""" - dt_from, dt_to = time_range - - data = self.query(sql, { - 'from': dt_from, - 'to': dt_to - }) - return [(date, temp/100, humidity/100) for date, temp, humidity in data] diff --git a/src/home/database/simple_state.py b/src/home/database/simple_state.py deleted file mode 100644 index 2b8ebe7..0000000 --- a/src/home/database/simple_state.py +++ /dev/null @@ -1,48 +0,0 @@ -import os -import json -import atexit - -from ._base import get_data_root_directory - - -class SimpleState: - def __init__(self, - name: str, - default: dict = None): - if default is None: - default = {} - elif type(default) is not dict: - raise TypeError('default must be dictionary') - - path = os.path.join(get_data_root_directory(), name) - if not os.path.exists(path): - self._data = default - else: - with open(path, 'r') as f: - self._data = json.loads(f.read()) - - self._file = path - atexit.register(self.__cleanup) - - def __cleanup(self): - if hasattr(self, '_file'): - with open(self._file, 'w') as f: - f.write(json.dumps(self._data)) - atexit.unregister(self.__cleanup) - - def __del__(self): - if 'open' in __builtins__: - self.__cleanup() - - def __getitem__(self, key): - return self._data[key] - - def __setitem__(self, key, value): - self._data[key] = value - - def __contains__(self, key): - return key in self._data - - def __delitem__(self, key): - if key in self._data: - del self._data[key] diff --git a/src/home/database/sqlite.py b/src/home/database/sqlite.py deleted file mode 100644 index 0af1f54..0000000 --- a/src/home/database/sqlite.py +++ /dev/null @@ -1,67 +0,0 @@ -import sqlite3 -import os.path -import logging - -from ._base import get_data_root_directory -from ..config import config, is_development_mode - - -def _get_database_path(name: str) -> str: - return os.path.join( - get_data_root_directory(), - f'{name}.db') - - -class SQLiteBase: - SCHEMA = 1 - - def __init__(self, name=None, check_same_thread=False): - if name is None: - name = config.app_config['database_name'] - database_path = _get_database_path(name) - if not os.path.exists(os.path.dirname(database_path)): - os.makedirs(os.path.dirname(database_path)) - - self.logger = logging.getLogger(self.__class__.__name__) - self.sqlite = sqlite3.connect(database_path, check_same_thread=check_same_thread) - - if is_development_mode(): - self.sql_logger = logging.getLogger(self.__class__.__name__) - self.sql_logger.setLevel('TRACE') - self.sqlite.set_trace_callback(self.sql_logger.trace) - - sqlite_version = self._get_sqlite_version() - self.logger.debug(f'SQLite version: {sqlite_version}') - - schema_version = self.schema_get_version() - self.logger.debug(f'Schema version: {schema_version}') - - self.schema_init(schema_version) - self.schema_set_version(self.SCHEMA) - - def __del__(self): - if self.sqlite: - self.sqlite.commit() - self.sqlite.close() - - def _get_sqlite_version(self) -> str: - cursor = self.sqlite.cursor() - cursor.execute("SELECT sqlite_version()") - return cursor.fetchone()[0] - - def schema_get_version(self) -> int: - cursor = self.sqlite.execute('PRAGMA user_version') - return int(cursor.fetchone()[0]) - - def schema_set_version(self, v) -> None: - self.sqlite.execute('PRAGMA user_version={:d}'.format(v)) - self.logger.info(f'Schema set to {v}') - - def cursor(self) -> sqlite3.Cursor: - return self.sqlite.cursor() - - def commit(self) -> None: - return self.sqlite.commit() - - def schema_init(self, version: int) -> None: - raise ValueError(f'{self.__class__.__name__}: must override schema_init') diff --git a/src/home/http/__init__.py b/src/home/http/__init__.py deleted file mode 100644 index 6030e95..0000000 --- a/src/home/http/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .http import serve, ok, routes, HTTPServer -from aiohttp.web import FileResponse, StreamResponse, Request, Response diff --git a/src/home/http/http.py b/src/home/http/http.py deleted file mode 100644 index 3e70751..0000000 --- a/src/home/http/http.py +++ /dev/null @@ -1,106 +0,0 @@ -import logging -import asyncio - -from aiohttp import web -from aiohttp.web import Response -from aiohttp.web_exceptions import HTTPNotFound - -from ..util import stringify, format_tb, Addr - - -_logger = logging.getLogger(__name__) - - -@web.middleware -async def errors_handler_middleware(request, handler): - try: - response = await handler(request) - return response - - except HTTPNotFound: - return web.json_response({'error': 'not found'}, status=404) - - except Exception as exc: - _logger.exception(exc) - data = { - 'error': exc.__class__.__name__, - 'message': exc.message if hasattr(exc, 'message') else str(exc) - } - tb = format_tb(exc) - if tb: - data['stacktrace'] = tb - - return web.json_response(data, status=500) - - -def serve(addr: Addr, route_table: web.RouteTableDef, handle_signals: bool = True): - app = web.Application() - app.add_routes(route_table) - app.middlewares.append(errors_handler_middleware) - - host, port = addr - - web.run_app(app, - host=host, - port=port, - handle_signals=handle_signals) - - -def routes() -> web.RouteTableDef: - return web.RouteTableDef() - - -def ok(data=None): - if data is None: - data = 1 - response = {'response': data} - return web.json_response(response, dumps=stringify) - - -class HTTPServer: - def __init__(self, addr: Addr, handle_errors=True): - self.addr = addr - self.app = web.Application() - self.logger = logging.getLogger(self.__class__.__name__) - - if handle_errors: - self.app.middlewares.append(errors_handler_middleware) - - def _add_route(self, - method: str, - path: str, - handler: callable): - self.app.router.add_routes([getattr(web, method)(path, handler)]) - - def get(self, path, handler): - self._add_route('get', path, handler) - - def post(self, path, handler): - self._add_route('post', path, handler) - - def put(self, path, handler): - self._add_route('put', path, handler) - - def delete(self, path, handler): - self._add_route('delete', path, handler) - - def run(self, event_loop=None, handle_signals=True): - if not event_loop: - event_loop = asyncio.get_event_loop() - - runner = web.AppRunner(self.app, handle_signals=handle_signals) - event_loop.run_until_complete(runner.setup()) - - host, port = self.addr - site = web.TCPSite(runner, host=host, port=port) - event_loop.run_until_complete(site.start()) - - self.logger.info(f'Server started at http://{host}:{port}') - - event_loop.run_forever() - - def ok(self, data=None): - return ok(data) - - def plain(self, text: str): - return Response(text=text, content_type='text/plain') diff --git a/src/home/inverter/__init__.py b/src/home/inverter/__init__.py deleted file mode 100644 index 8831ef3..0000000 --- a/src/home/inverter/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .monitor import InverterMonitor -from .inverter_wrapper import wrapper_instance -from .util import beautify_table diff --git a/src/home/inverter/config.py b/src/home/inverter/config.py deleted file mode 100644 index e284dfe..0000000 --- a/src/home/inverter/config.py +++ /dev/null @@ -1,13 +0,0 @@ -from ..config import ConfigUnit -from typing import Optional - - -class InverterdConfig(ConfigUnit): - NAME = 'inverterd' - - @classmethod - def schema(cls) -> Optional[dict]: - return { - 'remote_addr': {'type': 'string'}, - 'local_addr': {'type': 'string'}, - } \ No newline at end of file diff --git a/src/home/inverter/emulator.py b/src/home/inverter/emulator.py deleted file mode 100644 index e86b8bb..0000000 --- a/src/home/inverter/emulator.py +++ /dev/null @@ -1,556 +0,0 @@ -import asyncio -import logging - -from inverterd import Format - -from typing import Union -from enum import Enum -from ..util import Addr, stringify - - -class InverterEnum(Enum): - def as_text(self) -> str: - raise RuntimeError('abstract method') - - -class BatteryType(InverterEnum): - AGM = 0 - Flooded = 1 - User = 2 - - def as_text(self) -> str: - return ('AGM', 'Flooded', 'User')[self.value] - - -class InputVoltageRange(InverterEnum): - Appliance = 0 - USP = 1 - - def as_text(self) -> str: - return ('Appliance', 'USP')[self.value] - - -class OutputSourcePriority(InverterEnum): - SolarUtilityBattery = 0 - SolarBatteryUtility = 1 - - def as_text(self) -> str: - return ('Solar-Utility-Battery', 'Solar-Battery-Utility')[self.value] - - -class ChargeSourcePriority(InverterEnum): - SolarFirst = 0 - SolarAndUtility = 1 - SolarOnly = 2 - - def as_text(self) -> str: - return ('Solar-First', 'Solar-and-Utility', 'Solar-only')[self.value] - - -class MachineType(InverterEnum): - OffGridTie = 0 - GridTie = 1 - - def as_text(self) -> str: - return ('Off-Grid-Tie', 'Grid-Tie')[self.value] - - -class Topology(InverterEnum): - TransformerLess = 0 - Transformer = 1 - - def as_text(self) -> str: - return ('Transformer-less', 'Transformer')[self.value] - - -class OutputMode(InverterEnum): - SingleOutput = 0 - ParallelOutput = 1 - Phase_1_of_3 = 2 - Phase_2_of_3 = 3 - Phase_3_of_3 = 4 - - def as_text(self) -> str: - return ( - 'Single output', - 'Parallel output', - 'Phase 1 of 3-phase output', - 'Phase 2 of 3-phase output', - 'Phase 3 of 3-phase' - )[self.value] - - -class SolarPowerPriority(InverterEnum): - BatteryLoadUtility = 0 - LoadBatteryUtility = 1 - - def as_text(self) -> str: - return ('Battery-Load-Utility', 'Load-Battery-Utility')[self.value] - - -class MPPTChargerStatus(InverterEnum): - Abnormal = 0 - NotCharging = 1 - Charging = 2 - - def as_text(self) -> str: - return ('Abnormal', 'Not charging', 'Charging')[self.value] - - -class BatteryPowerDirection(InverterEnum): - DoNothing = 0 - Charge = 1 - Discharge = 2 - - def as_text(self) -> str: - return ('Do nothing', 'Charge', 'Discharge')[self.value] - - -class DC_AC_PowerDirection(InverterEnum): - DoNothing = 0 - AC_DC = 1 - DC_AC = 2 - - def as_text(self) -> str: - return ('Do nothing', 'AC/DC', 'DC/AC')[self.value] - - -class LinePowerDirection(InverterEnum): - DoNothing = 0 - Input = 1 - Output = 2 - - def as_text(self) -> str: - return ('Do nothing', 'Input', 'Output')[self.value] - - -class WorkingMode(InverterEnum): - PowerOnMode = 0 - StandbyMode = 1 - BypassMode = 2 - BatteryMode = 3 - FaultMode = 4 - HybridMode = 5 - - def as_text(self) -> str: - return ( - 'Power on mode', - 'Standby mode', - 'Bypass mode', - 'Battery mode', - 'Fault mode', - 'Hybrid mode' - )[self.value] - - -class ParallelConnectionStatus(InverterEnum): - NotExistent = 0 - Existent = 1 - - def as_text(self) -> str: - return ('Non-existent', 'Existent')[self.value] - - -class LoadConnectionStatus(InverterEnum): - Disconnected = 0 - Connected = 1 - - def as_text(self) -> str: - return ('Disconnected', 'Connected')[self.value] - - -class ConfigurationStatus(InverterEnum): - Default = 0 - Changed = 1 - - def as_text(self) -> str: - return ('Default', 'Changed')[self.value] - - -_g_human_readable = {"grid_voltage": "Grid voltage", - "grid_freq": "Grid frequency", - "ac_output_voltage": "AC output voltage", - "ac_output_freq": "AC output frequency", - "ac_output_apparent_power": "AC output apparent power", - "ac_output_active_power": "AC output active power", - "output_load_percent": "Output load percent", - "battery_voltage": "Battery voltage", - "battery_voltage_scc": "Battery voltage from SCC", - "battery_voltage_scc2": "Battery voltage from SCC2", - "battery_discharge_current": "Battery discharge current", - "battery_charge_current": "Battery charge current", - "battery_capacity": "Battery capacity", - "inverter_heat_sink_temp": "Inverter heat sink temperature", - "mppt1_charger_temp": "MPPT1 charger temperature", - "mppt2_charger_temp": "MPPT2 charger temperature", - "pv1_input_power": "PV1 input power", - "pv2_input_power": "PV2 input power", - "pv1_input_voltage": "PV1 input voltage", - "pv2_input_voltage": "PV2 input voltage", - "configuration_status": "Configuration state", - "mppt1_charger_status": "MPPT1 charger status", - "mppt2_charger_status": "MPPT2 charger status", - "load_connected": "Load connection", - "battery_power_direction": "Battery power direction", - "dc_ac_power_direction": "DC/AC power direction", - "line_power_direction": "Line power direction", - "local_parallel_id": "Local parallel ID", - "ac_input_rating_voltage": "AC input rating voltage", - "ac_input_rating_current": "AC input rating current", - "ac_output_rating_voltage": "AC output rating voltage", - "ac_output_rating_freq": "AC output rating frequency", - "ac_output_rating_current": "AC output rating current", - "ac_output_rating_apparent_power": "AC output rating apparent power", - "ac_output_rating_active_power": "AC output rating active power", - "battery_rating_voltage": "Battery rating voltage", - "battery_recharge_voltage": "Battery re-charge voltage", - "battery_redischarge_voltage": "Battery re-discharge voltage", - "battery_under_voltage": "Battery under voltage", - "battery_bulk_voltage": "Battery bulk voltage", - "battery_float_voltage": "Battery float voltage", - "battery_type": "Battery type", - "max_charge_current": "Max charge current", - "max_ac_charge_current": "Max AC charge current", - "input_voltage_range": "Input voltage range", - "output_source_priority": "Output source priority", - "charge_source_priority": "Charge source priority", - "parallel_max_num": "Parallel max num", - "machine_type": "Machine type", - "topology": "Topology", - "output_mode": "Output mode", - "solar_power_priority": "Solar power priority", - "mppt": "MPPT string", - "fault_code": "Fault code", - "line_fail": "Line fail", - "output_circuit_short": "Output circuit short", - "inverter_over_temperature": "Inverter over temperature", - "fan_lock": "Fan lock", - "battery_voltage_high": "Battery voltage high", - "battery_low": "Battery low", - "battery_under": "Battery under", - "over_load": "Over load", - "eeprom_fail": "EEPROM fail", - "power_limit": "Power limit", - "pv1_voltage_high": "PV1 voltage high", - "pv2_voltage_high": "PV2 voltage high", - "mppt1_overload_warning": "MPPT1 overload warning", - "mppt2_overload_warning": "MPPT2 overload warning", - "battery_too_low_to_charge_for_scc1": "Battery too low to charge for SCC1", - "battery_too_low_to_charge_for_scc2": "Battery too low to charge for SCC2", - "buzzer": "Buzzer", - "overload_bypass": "Overload bypass function", - "escape_to_default_screen_after_1min_timeout": "Escape to default screen after 1min timeout", - "overload_restart": "Overload restart", - "over_temp_restart": "Over temperature restart", - "backlight_on": "Backlight on", - "alarm_on_on_primary_source_interrupt": "Alarm on on primary source interrupt", - "fault_code_record": "Fault code record", - "wh": "Wh"} - - -class InverterEmulator: - def __init__(self, addr: Addr, wait=True): - self.status = {"grid_voltage": {"unit": "V", "value": 236.3}, - "grid_freq": {"unit": "Hz", "value": 50.0}, - "ac_output_voltage": {"unit": "V", "value": 229.9}, - "ac_output_freq": {"unit": "Hz", "value": 50.0}, - "ac_output_apparent_power": {"unit": "VA", "value": 207}, - "ac_output_active_power": {"unit": "Wh", "value": 146}, - "output_load_percent": {"unit": "%", "value": 4}, - "battery_voltage": {"unit": "V", "value": 49.1}, - "battery_voltage_scc": {"unit": "V", "value": 0.0}, - "battery_voltage_scc2": {"unit": "V", "value": 0.0}, - "battery_discharge_current": {"unit": "A", "value": 3}, - "battery_charge_current": {"unit": "A", "value": 0}, - "battery_capacity": {"unit": "%", "value": 69}, - "inverter_heat_sink_temp": {"unit": "°C", "value": 17}, - "mppt1_charger_temp": {"unit": "°C", "value": 0}, - "mppt2_charger_temp": {"unit": "°C", "value": 0}, - "pv1_input_power": {"unit": "Wh", "value": 0}, - "pv2_input_power": {"unit": "Wh", "value": 0}, - "pv1_input_voltage": {"unit": "V", "value": 0.0}, - "pv2_input_voltage": {"unit": "V", "value": 0.0}, - "configuration_status": ConfigurationStatus.Default, - "mppt1_charger_status": MPPTChargerStatus.Abnormal, - "mppt2_charger_status": MPPTChargerStatus.Abnormal, - "load_connected": LoadConnectionStatus.Connected, - "battery_power_direction": BatteryPowerDirection.Discharge, - "dc_ac_power_direction": DC_AC_PowerDirection.DC_AC, - "line_power_direction": LinePowerDirection.DoNothing, - "local_parallel_id": 0} - - self.rated = {"ac_input_rating_voltage": {"unit": "V", "value": 230.0}, - "ac_input_rating_current": {"unit": "A", "value": 21.7}, - "ac_output_rating_voltage": {"unit": "V", "value": 230.0}, - "ac_output_rating_freq": {"unit": "Hz", "value": 50.0}, - "ac_output_rating_current": {"unit": "A", "value": 21.7}, - "ac_output_rating_apparent_power": {"unit": "VA", "value": 5000}, - "ac_output_rating_active_power": {"unit": "Wh", "value": 5000}, - "battery_rating_voltage": {"unit": "V", "value": 48.0}, - "battery_recharge_voltage": {"unit": "V", "value": 48.0}, - "battery_redischarge_voltage": {"unit": "V", "value": 55.0}, - "battery_under_voltage": {"unit": "V", "value": 42.0}, - "battery_bulk_voltage": {"unit": "V", "value": 57.6}, - "battery_float_voltage": {"unit": "V", "value": 54.0}, - "battery_type": BatteryType.User, - "max_charge_current": {"unit": "A", "value": 60}, - "max_ac_charge_current": {"unit": "A", "value": 30}, - "input_voltage_range": InputVoltageRange.Appliance, - "output_source_priority": OutputSourcePriority.SolarBatteryUtility, - "charge_source_priority": ChargeSourcePriority.SolarAndUtility, - "parallel_max_num": 6, - "machine_type": MachineType.OffGridTie, - "topology": Topology.TransformerLess, - "output_mode": OutputMode.SingleOutput, - "solar_power_priority": SolarPowerPriority.LoadBatteryUtility, - "mppt": "2"} - - self.errors = {"fault_code": 0, - "line_fail": False, - "output_circuit_short": False, - "inverter_over_temperature": False, - "fan_lock": False, - "battery_voltage_high": False, - "battery_low": False, - "battery_under": False, - "over_load": False, - "eeprom_fail": False, - "power_limit": False, - "pv1_voltage_high": False, - "pv2_voltage_high": False, - "mppt1_overload_warning": False, - "mppt2_overload_warning": False, - "battery_too_low_to_charge_for_scc1": False, - "battery_too_low_to_charge_for_scc2": False} - - self.flags = {"buzzer": False, - "overload_bypass": True, - "escape_to_default_screen_after_1min_timeout": False, - "overload_restart": True, - "over_temp_restart": True, - "backlight_on": False, - "alarm_on_on_primary_source_interrupt": True, - "fault_code_record": False} - - self.day_generated = 1000 - - self.logger = logging.getLogger(self.__class__.__name__) - - host, port = addr - asyncio.run(self.run_server(host, port, wait)) - # self.max_ac_charge_current = 30 - # self.max_charge_current = 60 - # self.charge_thresholds = [48, 54] - - async def run_server(self, host, port, wait: bool): - server = await asyncio.start_server(self.client_handler, host, port) - async with server: - self.logger.info(f'listening on {host}:{port}') - if wait: - await server.serve_forever() - else: - asyncio.ensure_future(server.serve_forever()) - - async def client_handler(self, reader, writer): - client_fmt = Format.JSON - - def w(s: str): - writer.write(s.encode('utf-8')) - - def return_error(message=None): - w('err\r\n') - if message: - if client_fmt in (Format.JSON, Format.SIMPLE_JSON): - w(stringify({ - 'result': 'error', - 'message': message - })) - elif client_fmt in (Format.TABLE, Format.SIMPLE_TABLE): - w(f'error: {message}') - w('\r\n') - w('\r\n') - - def return_ok(data=None): - w('ok\r\n') - if client_fmt in (Format.JSON, Format.SIMPLE_JSON): - jdata = { - 'result': 'ok' - } - if data: - jdata['data'] = data - w(stringify(jdata)) - w('\r\n') - elif data: - w(data) - w('\r\n') - w('\r\n') - - request = None - while request != 'quit': - try: - request = await reader.read(255) - if request == b'\x04': - break - request = request.decode('utf-8').strip() - except Exception: - break - - if request.startswith('format '): - requested_format = request[7:] - try: - client_fmt = Format(requested_format) - except ValueError: - return_error('invalid format') - - return_ok() - - elif request.startswith('exec '): - buf = request[5:].split(' ') - command = buf[0] - args = buf[1:] - - try: - return_ok(self.process_command(client_fmt, command, *args)) - except ValueError as e: - return_error(str(e)) - - else: - return_error(f'invalid token: {request}') - - try: - await writer.drain() - except ConnectionResetError as e: - # self.logger.exception(e) - pass - - writer.close() - - def process_command(self, fmt: Format, c: str, *args) -> Union[dict, str, list[int], None]: - ac_charge_currents = [2, 10, 20, 30, 40, 50, 60] - - if c == 'get-status': - return self.format_dict(self.status, fmt) - - elif c == 'get-rated': - return self.format_dict(self.rated, fmt) - - elif c == 'get-errors': - return self.format_dict(self.errors, fmt) - - elif c == 'get-flags': - return self.format_dict(self.flags, fmt) - - elif c == 'get-day-generated': - return self.format_dict({'wh': 1000}, fmt) - - elif c == 'get-allowed-ac-charge-currents': - return self.format_list(ac_charge_currents, fmt) - - elif c == 'set-max-ac-charge-current': - if int(args[0]) != 0: - raise ValueError(f'invalid machine id: {args[0]}') - amps = int(args[1]) - if amps not in ac_charge_currents: - raise ValueError(f'invalid value: {amps}') - self.rated['max_ac_charge_current']['value'] = amps - - elif c == 'set-charge-thresholds': - self.rated['battery_recharge_voltage']['value'] = float(args[0]) - self.rated['battery_redischarge_voltage']['value'] = float(args[1]) - - elif c == 'set-output-source-priority': - self.rated['output_source_priority'] = OutputSourcePriority.SolarBatteryUtility if args[0] == 'SBU' else OutputSourcePriority.SolarUtilityBattery - - elif c == 'set-battery-cutoff-voltage': - self.rated['battery_under_voltage']['value'] = float(args[0]) - - elif c == 'set-flag': - flag = args[0] - val = bool(int(args[1])) - - if flag == 'BUZZ': - k = 'buzzer' - elif flag == 'OLBP': - k = 'overload_bypass' - elif flag == 'LCDE': - k = 'escape_to_default_screen_after_1min_timeout' - elif flag == 'OLRS': - k = 'overload_restart' - elif flag == 'OTRS': - k = 'over_temp_restart' - elif flag == 'BLON': - k = 'backlight_on' - elif flag == 'ALRM': - k = 'alarm_on_on_primary_source_interrupt' - elif flag == 'FTCR': - k = 'fault_code_record' - else: - raise ValueError('invalid flag') - - self.flags[k] = val - - else: - raise ValueError(f'{c}: unsupported command') - - @staticmethod - def format_list(values: list, fmt: Format) -> Union[str, list]: - if fmt in (Format.JSON, Format.SIMPLE_JSON): - return values - return '\n'.join(map(lambda v: str(v), values)) - - @staticmethod - def format_dict(data: dict, fmt: Format) -> Union[str, dict]: - new_data = {} - for k, v in data.items(): - new_val = None - if fmt in (Format.JSON, Format.TABLE, Format.SIMPLE_TABLE): - if isinstance(v, dict): - new_val = v - elif isinstance(v, InverterEnum): - new_val = v.as_text() - else: - new_val = v - elif fmt == Format.SIMPLE_JSON: - if isinstance(v, dict): - new_val = v['value'] - elif isinstance(v, InverterEnum): - new_val = v.value - else: - new_val = str(v) - new_data[k] = new_val - - if fmt in (Format.JSON, Format.SIMPLE_JSON): - return new_data - - lines = [] - - if fmt == Format.SIMPLE_TABLE: - for k, v in new_data.items(): - buf = k - if isinstance(v, dict): - buf += ' ' + str(v['value']) + ' ' + v['unit'] - elif isinstance(v, InverterEnum): - buf += ' ' + v.as_text() - else: - buf += ' ' + str(v) - lines.append(buf) - - elif fmt == Format.TABLE: - max_k_len = 0 - for k in new_data.keys(): - if len(_g_human_readable[k]) > max_k_len: - max_k_len = len(_g_human_readable[k]) - for k, v in new_data.items(): - buf = _g_human_readable[k] + ':' - buf += ' ' * (max_k_len - len(_g_human_readable[k]) + 1) - if isinstance(v, dict): - buf += str(v['value']) + ' ' + v['unit'] - elif isinstance(v, InverterEnum): - buf += v.as_text() - elif isinstance(v, bool): - buf += str(int(v)) - else: - buf += str(v) - lines.append(buf) - - return '\n'.join(lines) diff --git a/src/home/inverter/inverter_wrapper.py b/src/home/inverter/inverter_wrapper.py deleted file mode 100644 index df2c2fc..0000000 --- a/src/home/inverter/inverter_wrapper.py +++ /dev/null @@ -1,48 +0,0 @@ -import json - -from threading import Lock -from inverterd import ( - Format, - Client as InverterClient, - InverterError -) - -_lock = Lock() - - -class InverterClientWrapper: - def __init__(self): - self._inverter = None - self._host = None - self._port = None - - def init(self, host: str, port: int): - self._host = host - self._port = port - self.create() - - def create(self): - self._inverter = InverterClient(host=self._host, port=self._port) - self._inverter.connect() - - def exec(self, command: str, arguments: tuple = (), format=Format.JSON): - with _lock: - try: - self._inverter.format(format) - response = self._inverter.exec(command, arguments) - if format == Format.JSON: - response = json.loads(response) - return response - except InverterError as e: - raise e - except Exception as e: - # silently try to reconnect - try: - self.create() - except Exception: - pass - raise e - - -wrapper_instance = InverterClientWrapper() - diff --git a/src/home/inverter/monitor.py b/src/home/inverter/monitor.py deleted file mode 100644 index 86f75ac..0000000 --- a/src/home/inverter/monitor.py +++ /dev/null @@ -1,499 +0,0 @@ -import logging -import time - -from .types import * -from threading import Thread -from typing import Callable, Optional -from .inverter_wrapper import wrapper_instance as inverter -from inverterd import InverterError -from ..util import Stopwatch, StopwatchError -from ..config import config - -logger = logging.getLogger(__name__) - - -def _pd_from_string(pd: str) -> BatteryPowerDirection: - if pd == 'Discharge': - return BatteryPowerDirection.DISCHARGING - elif pd == 'Charge': - return BatteryPowerDirection.CHARGING - elif pd == 'Do nothing': - return BatteryPowerDirection.DO_NOTHING - else: - raise ValueError(f'invalid power direction: {pd}') - - -class MonitorConfig: - def __getattr__(self, item): - return config['monitor'][item] - - -cfg = MonitorConfig() - - -""" -TODO: -- поддержать возможность ручного (через бота) переключения тока заряда вверх и вниз -- поддержать возможность бесшовного перезапуска бота, когда монитор понимает, что зарядка уже идет, и он - не запускает программу с начала, а продолжает с уже существующей позиции. Уведомления при этом можно не - присылать совсем, либо прислать какое-то одно приложение, в духе "программа была перезапущена" -""" - - -class InverterMonitor(Thread): - charging_event_handler: Optional[Callable] - battery_event_handler: Optional[Callable] - util_event_handler: Optional[Callable] - error_handler: Optional[Callable] - osp_change_cb: Optional[Callable] - osp: Optional[OutputSourcePriority] - - def __init__(self): - super().__init__() - self.setName('InverterMonitor') - - self.interrupted = False - self.min_allowed_current = 0 - self.ac_mode = None - self.osp = None - - # Event handlers for the bot. - self.charging_event_handler = None - self.battery_event_handler = None - self.util_event_handler = None - self.error_handler = None - self.osp_change_cb = None - - # Currents list, defined in the bot config. - self.currents = cfg.gen_currents - self.currents.sort() - - # We start charging at lowest possible current, then increase it once per minute (or so) to the maximum level. - # This is done so that the load on the generator increases smoothly, not abruptly. Generator will thank us. - self.current_change_direction = CurrentChangeDirection.UP - self.next_current_enter_time = 0 - self.active_current_idx = -1 - - self.battery_state = BatteryState.NORMAL - self.charging_state = ChargingState.NOT_CHARGING - - # 'Mostly-charged' means that we've already lowered the charging current to the level - # at which batteries are charging pretty slow. So instead of burning gasoline and shaking the air, - # we can just turn the generator off at this point. - self.mostly_charged = False - - # The stopwatch is used to measure how long does the battery voltage exceeds the float voltage level. - # We don't want to damage our batteries, right? - self.floating_stopwatch = Stopwatch() - - # State variables for utilities charging program - self.util_ac_present = None - self.util_pd = None - self.util_solar = None - - @property - def active_current(self) -> Optional[int]: - try: - if self.active_current_idx < 0: - return None - return self.currents[self.active_current_idx] - except IndexError: - return None - - def run(self): - # Check allowed currents and validate the config. - allowed_currents = list(inverter.exec('get-allowed-ac-charge-currents')['data']) - allowed_currents.sort() - - for a in self.currents: - if a not in allowed_currents: - raise ValueError(f'invalid value {a} in gen_currents list') - - self.min_allowed_current = min(allowed_currents) - - # Reading rated configuration - rated = inverter.exec('get-rated')['data'] - self.osp = OutputSourcePriority.from_text(rated['output_source_priority']) - - # Read data and run implemented programs every 2 seconds. - while not self.interrupted: - try: - response = inverter.exec('get-status') - if response['result'] != 'ok': - logger.error('get-status failed:', response) - else: - gs = response['data'] - - ac = gs['grid_voltage']['value'] > 0 or gs['grid_freq']['value'] > 0 - solar = gs['pv1_input_voltage']['value'] > 0 or gs['pv2_input_voltage']['value'] > 0 - solar_input = gs['pv1_input_power']['value'] - v = float(gs['battery_voltage']['value']) - load_watts = int(gs['ac_output_active_power']['value']) - pd = _pd_from_string(gs['battery_power_direction']) - - logger.debug(f'got status: ac={ac}, solar={solar}, v={v}, pd={pd}') - - if self.ac_mode == ACMode.GENERATOR: - self.gen_charging_program(ac, solar, v, pd) - - elif self.ac_mode == ACMode.UTILITIES: - self.utilities_monitoring_program(ac, solar, v, load_watts, solar_input, pd) - - if not ac or pd != BatteryPowerDirection.CHARGING: - # if AC is disconnected or not charging, run the low voltage checking program - self.low_voltage_program(v, load_watts) - - elif self.battery_state != BatteryState.NORMAL: - # AC is connected and the battery is charging, assume battery level is normal - self.battery_state = BatteryState.NORMAL - - except InverterError as e: - logger.exception(e) - - time.sleep(2) - - def utilities_monitoring_program(self, - ac: bool, # whether AC is connected - solar: bool, # whether MPPT is active - v: float, # battery voltage - load_watts: int, # load, wh - solar_input: int, # input from solar panels, wh - pd: BatteryPowerDirection # current power direction - ): - pd_event_send = False - if self.util_solar is None or solar != self.util_solar: - self.util_solar = solar - if solar and self.util_ac_present and self.util_pd == BatteryPowerDirection.CHARGING: - self.charging_event_handler(ChargingEvent.UTIL_CHARGING_STOPPED_SOLAR) - pd_event_send = True - - if solar: - if v <= 48 and self.osp == OutputSourcePriority.SolarBatteryUtility: - self.osp_change_cb(OutputSourcePriority.SolarUtilityBattery, solar_input=solar_input, v=v) - self.osp = OutputSourcePriority.SolarUtilityBattery - - if self.osp == OutputSourcePriority.SolarUtilityBattery and solar_input >= 900: - self.osp_change_cb(OutputSourcePriority.SolarBatteryUtility, solar_input=solar_input, v=v) - self.osp = OutputSourcePriority.SolarBatteryUtility - - if self.util_ac_present is None or ac != self.util_ac_present: - self.util_event_handler(ACPresentEvent.CONNECTED if ac else ACPresentEvent.DISCONNECTED) - self.util_ac_present = ac - - if self.util_pd is None or self.util_pd != pd: - self.util_pd = pd - if not pd_event_send and not solar: - if pd == BatteryPowerDirection.CHARGING: - self.charging_event_handler(ChargingEvent.UTIL_CHARGING_STARTED) - - elif pd == BatteryPowerDirection.DISCHARGING: - self.charging_event_handler(ChargingEvent.UTIL_CHARGING_STOPPED) - - def gen_charging_program(self, - ac: bool, # whether AC is connected - solar: bool, # whether MPPT is active - v: float, # current battery voltage - pd: BatteryPowerDirection # current power direction - ): - if self.charging_state == ChargingState.NOT_CHARGING: - if ac and solar: - # Not charging because MPPT is active (solar line is connected). - # Notify users about it and change the current state. - self.charging_state = ChargingState.AC_BUT_SOLAR - self.charging_event_handler(ChargingEvent.AC_CHARGING_UNAVAILABLE_BECAUSE_SOLAR) - logger.info('entering AC_BUT_SOLAR state') - elif ac: - # Not charging, but AC is connected and ready to use. - # Start the charging program. - self.gen_start(pd) - - elif self.charging_state == ChargingState.AC_BUT_SOLAR: - if not ac: - # AC charger has been disconnected. Since the state is AC_BUT_SOLAR, - # charging probably never even started. Stop the charging program. - self.gen_stop(ChargingState.NOT_CHARGING) - elif not solar: - # MPPT has been disconnected, and, since AC is still connected, we can - # try to start the charging program. - self.gen_start(pd) - - elif self.charging_state in (ChargingState.AC_OK, ChargingState.AC_WAITING): - if not ac: - # Charging was in progress, but AC has been suddenly disconnected. - # Sad, but what can we do? Stop the charging program and return. - self.gen_stop(ChargingState.NOT_CHARGING) - return - - if solar: - # Charging was in progress, but MPPT has been detected. Inverter doesn't charge - # batteries from AC when MPPT is active, so we have to pause our program. - self.charging_state = ChargingState.AC_BUT_SOLAR - self.charging_event_handler(ChargingEvent.AC_CHARGING_UNAVAILABLE_BECAUSE_SOLAR) - try: - self.floating_stopwatch.pause() - except StopwatchError: - msg = 'gen_charging_program: floating_stopwatch.pause() failed at (1)' - logger.warning(msg) - # self.error_handler(msg) - logger.info('solar power connected during charging, entering AC_BUT_SOLAR state') - return - - # No surprises at this point, just check the values and make decisions based on them. - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # We've reached the 'mostly-charged' point, the voltage level is not float, - # but inverter decided to stop charging (or somebody used a kettle, lol). - # Anyway, assume that charging is complete, stop the program, notify users and return. - if self.mostly_charged and v > (cfg.gen_floating_v - 1) and pd != BatteryPowerDirection.CHARGING: - self.gen_stop(ChargingState.AC_DONE) - return - - # Monitor inverter power direction and notify users when it changes. - state = ChargingState.AC_OK if pd == BatteryPowerDirection.CHARGING else ChargingState.AC_WAITING - if state != self.charging_state: - self.charging_state = state - - evt = ChargingEvent.AC_CHARGING_STARTED if state == ChargingState.AC_OK else ChargingEvent.AC_NOT_CHARGING - self.charging_event_handler(evt) - - if self.floating_stopwatch.get_elapsed_time() >= cfg.gen_floating_time_max: - # We've been at a bulk voltage level too long, so we have to stop charging. - # Set the minimum current possible. - - if self.current_change_direction == CurrentChangeDirection.UP: - # This shouldn't happen, obviously an error. - msg = 'gen_charging_program:' - msg += ' been at bulk voltage level too long, but current change direction is still \'up\'!' - msg += ' This is obviously an error, please fix it' - logger.warning(msg) - self.error_handler(msg) - - self.gen_next_current(current=self.min_allowed_current) - - elif self.active_current is not None: - # If voltage is greater than float voltage, keep the stopwatch ticking - if v > cfg.gen_floating_v and self.floating_stopwatch.is_paused(): - try: - self.floating_stopwatch.go() - except StopwatchError: - msg = 'gen_charging_program: floating_stopwatch.go() failed at (2)' - logger.warning(msg) - self.error_handler(msg) - # Otherwise, pause it - elif v <= cfg.gen_floating_v and not self.floating_stopwatch.is_paused(): - try: - self.floating_stopwatch.pause() - except StopwatchError: - msg = 'gen_charging_program: floating_stopwatch.pause() failed at (3)' - logger.warning(msg) - self.error_handler(msg) - - # Charging current monitoring - if self.current_change_direction == CurrentChangeDirection.UP: - # Generator is warming up in this code path - - if self.next_current_enter_time != 0 and pd != BatteryPowerDirection.CHARGING: - # Generator was warming up and charging, but stopped (pd has changed). - # Resetting to the minimum possible current - logger.info(f'gen_charging_program (warming path): was charging but power direction suddeny changed. resetting to minimum current') - self.next_current_enter_time = 0 - self.gen_next_current(current=self.min_allowed_current) - - elif self.next_current_enter_time == 0 and pd == BatteryPowerDirection.CHARGING: - self.next_current_enter_time = time.time() + cfg.gen_raise_intervals[self.active_current_idx] - logger.info(f'gen_charging_program (warming path): set next_current_enter_time to {self.next_current_enter_time}') - - elif self.next_current_enter_time != 0 and time.time() >= self.next_current_enter_time: - logger.info('gen_charging_program (warming path): hit next_current_enter_time, calling gen_next_current()') - self.gen_next_current() - else: - # Gradually lower the current level, based on how close - # battery voltage has come to the bulk level. - if self.active_current >= 30: - upper_bound = cfg.gen_cur30_v_limit - elif self.active_current == 20: - upper_bound = cfg.gen_cur20_v_limit - else: - upper_bound = cfg.gen_cur10_v_limit - - # Voltage is high enough already and it's close to bulk level; we hit the upper bound, - # so let's lower the current - if v >= upper_bound: - self.gen_next_current() - - elif self.charging_state == ChargingState.AC_DONE: - # We've already finished charging, but AC was connected. Not that it's disconnected, - # set the appropriate state and notify users. - if not ac: - self.gen_stop(ChargingState.NOT_CHARGING) - - def gen_start(self, pd: BatteryPowerDirection): - if pd == BatteryPowerDirection.CHARGING: - self.charging_state = ChargingState.AC_OK - self.charging_event_handler(ChargingEvent.AC_CHARGING_STARTED) - logger.info('AC line connected and charging, entering AC_OK state') - - # Continue the stopwatch, if needed - try: - self.floating_stopwatch.go() - except StopwatchError: - msg = 'floating_stopwatch.go() failed at ac_charging_start(), AC_OK path' - logger.warning(msg) - self.error_handler(msg) - else: - self.charging_state = ChargingState.AC_WAITING - self.charging_event_handler(ChargingEvent.AC_NOT_CHARGING) - logger.info('AC line connected but not charging yet, entering AC_WAITING state') - - # Pause the stopwatch, if needed - try: - if not self.floating_stopwatch.is_paused(): - self.floating_stopwatch.pause() - except StopwatchError: - msg = 'floating_stopwatch.pause() failed at ac_charging_start(), AC_WAITING path' - logger.warning(msg) - self.error_handler(msg) - - # idx == -1 means haven't started our program yet. - if self.active_current_idx == -1: - self.gen_next_current() - # self.set_hw_charging_current(self.min_allowed_current) - - def gen_stop(self, reason: ChargingState): - self.charging_state = reason - - if reason == ChargingState.AC_DONE: - event = ChargingEvent.AC_CHARGING_FINISHED - elif reason == ChargingState.NOT_CHARGING: - event = ChargingEvent.AC_DISCONNECTED - else: - raise ValueError(f'ac_charging_stop: unexpected reason {reason}') - - logger.info(f'charging is finished, entering {reason} state') - self.charging_event_handler(event) - - self.next_current_enter_time = 0 - self.mostly_charged = False - self.active_current_idx = -1 - self.floating_stopwatch.reset() - self.current_change_direction = CurrentChangeDirection.UP - - self.set_hw_charging_current(self.min_allowed_current) - - def gen_next_current(self, current=None): - if current is None: - try: - current = self._next_current() - logger.debug(f'gen_next_current: ready to change charging current to {current} A') - except IndexError: - logger.debug('gen_next_current: was going to change charging current, but no currents left; finishing charging program') - self.gen_stop(ChargingState.AC_DONE) - return - - else: - try: - idx = self.currents.index(current) - except ValueError: - msg = f'gen_next_current: got current={current} but it\'s not in the currents list' - logger.error(msg) - self.error_handler(msg) - return - self.active_current_idx = idx - - if self.current_change_direction == CurrentChangeDirection.DOWN: - if current == self.currents[0]: - self.mostly_charged = True - self.gen_stop(ChargingState.AC_DONE) - - elif current == self.currents[1] and not self.mostly_charged: - self.mostly_charged = True - self.charging_event_handler(ChargingEvent.AC_MOSTLY_CHARGED) - - self.set_hw_charging_current(current) - - def set_hw_charging_current(self, current: int): - try: - response = inverter.exec('set-max-ac-charge-current', (0, current)) - if response['result'] != 'ok': - logger.error(f'failed to change AC charging current to {current} A') - raise InverterError('set-max-ac-charge-current: inverterd reported error') - else: - self.charging_event_handler(ChargingEvent.AC_CURRENT_CHANGED, current=current) - logger.info(f'changed AC charging current to {current} A') - except InverterError as e: - self.error_handler(f'failed to set charging current to {current} A (caught InverterError)') - logger.exception(e) - - def _next_current(self): - if self.current_change_direction == CurrentChangeDirection.UP: - self.active_current_idx += 1 - if self.active_current_idx == len(self.currents)-1: - logger.info('_next_current: charging current power direction to DOWN') - self.current_change_direction = CurrentChangeDirection.DOWN - self.next_current_enter_time = 0 - else: - if self.active_current_idx == 0: - raise IndexError('can\'t go lower') - self.active_current_idx -= 1 - - logger.info(f'_next_current: active_current_idx set to {self.active_current_idx}, returning current of {self.currents[self.active_current_idx]} A') - return self.currents[self.active_current_idx] - - def low_voltage_program(self, v: float, load_watts: int): - crit_level = cfg.vcrit - low_level = cfg.vlow - - if v <= crit_level: - state = BatteryState.CRITICAL - elif v <= low_level: - state = BatteryState.LOW - else: - state = BatteryState.NORMAL - - if state != self.battery_state: - self.battery_state = state - self.battery_event_handler(state, v, load_watts) - - def set_charging_event_handler(self, handler: Callable): - self.charging_event_handler = handler - - def set_battery_event_handler(self, handler: Callable): - self.battery_event_handler = handler - - def set_util_event_handler(self, handler: Callable): - self.util_event_handler = handler - - def set_error_handler(self, handler: Callable): - self.error_handler = handler - - def set_osp_need_change_callback(self, cb: Callable): - self.osp_change_cb = cb - - def set_ac_mode(self, mode: ACMode): - self.ac_mode = mode - - def notify_osp(self, osp: OutputSourcePriority): - self.osp = osp - - def stop(self): - self.interrupted = True - - def dump_status(self) -> dict: - return { - 'interrupted': self.interrupted, - 'currents': self.currents, - 'active_current': self.active_current, - 'current_change_direction': self.current_change_direction.name, - 'battery_state': self.battery_state.name, - 'charging_state': self.charging_state.name, - 'mostly_charged': self.mostly_charged, - 'floating_stopwatch_paused': self.floating_stopwatch.is_paused(), - 'floating_stopwatch_elapsed': self.floating_stopwatch.get_elapsed_time(), - 'time_now': time.time(), - 'next_current_enter_time': self.next_current_enter_time, - 'ac_mode': self.ac_mode, - 'osp': self.osp, - 'util_ac_present': self.util_ac_present, - 'util_pd': self.util_pd.name, - 'util_solar': self.util_solar - } diff --git a/src/home/inverter/types.py b/src/home/inverter/types.py deleted file mode 100644 index 57021f1..0000000 --- a/src/home/inverter/types.py +++ /dev/null @@ -1,64 +0,0 @@ -from enum import Enum, auto - - -class BatteryPowerDirection(Enum): - DISCHARGING = auto() - CHARGING = auto() - DO_NOTHING = auto() - - -class ChargingEvent(Enum): - AC_CHARGING_UNAVAILABLE_BECAUSE_SOLAR = auto() - AC_NOT_CHARGING = auto() - AC_CHARGING_STARTED = auto() - AC_DISCONNECTED = auto() - AC_CURRENT_CHANGED = auto() - AC_MOSTLY_CHARGED = auto() - AC_CHARGING_FINISHED = auto() - - UTIL_CHARGING_STARTED = auto() - UTIL_CHARGING_STOPPED = auto() - UTIL_CHARGING_STOPPED_SOLAR = auto() - - -class ACPresentEvent(Enum): - CONNECTED = auto() - DISCONNECTED = auto() - - -class ChargingState(Enum): - NOT_CHARGING = auto() - AC_BUT_SOLAR = auto() - AC_WAITING = auto() - AC_OK = auto() - AC_DONE = auto() - - -class CurrentChangeDirection(Enum): - UP = auto() - DOWN = auto() - - -class BatteryState(Enum): - NORMAL = auto() - LOW = auto() - CRITICAL = auto() - - -class ACMode(Enum): - GENERATOR = 'generator' - UTILITIES = 'utilities' - - -class OutputSourcePriority(Enum): - SolarUtilityBattery = 'SUB' - SolarBatteryUtility = 'SBU' - - @classmethod - def from_text(cls, s: str): - if s == 'Solar-Battery-Utility': - return cls.SolarBatteryUtility - elif s == 'Solar-Utility-Battery': - return cls.SolarUtilityBattery - else: - raise ValueError(f'unknown value: {s}') \ No newline at end of file diff --git a/src/home/inverter/util.py b/src/home/inverter/util.py deleted file mode 100644 index a577e6a..0000000 --- a/src/home/inverter/util.py +++ /dev/null @@ -1,8 +0,0 @@ -import re - - -def beautify_table(s): - lines = s.split('\n') - lines = list(map(lambda line: re.sub(r'\s+', ' ', line), lines)) - lines = list(map(lambda line: re.sub(r'(.*?): (.*)', r'\1: \2', line), lines)) - return '\n'.join(lines) diff --git a/src/home/media/__init__.py b/src/home/media/__init__.py deleted file mode 100644 index 6923105..0000000 --- a/src/home/media/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -import importlib -import itertools - -__map__ = { - 'types': ['MediaNodeType'], - 'record_client': ['SoundRecordClient', 'CameraRecordClient', 'RecordClient'], - 'node_server': ['MediaNodeServer'], - 'node_client': ['SoundNodeClient', 'CameraNodeClient', 'MediaNodeClient'], - 'storage': ['SoundRecordStorage', 'ESP32CameraRecordStorage', 'SoundRecordFile', 'CameraRecordFile', 'RecordFile'], - 'record': ['SoundRecorder', 'CameraRecorder'] -} - -__all__ = list(itertools.chain(*__map__.values())) - - -def __getattr__(name): - if name in __all__: - for file, names in __map__.items(): - if name in names: - module = importlib.import_module(f'.{file}', __name__) - return getattr(module, name) - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/home/media/__init__.pyi b/src/home/media/__init__.pyi deleted file mode 100644 index 77c2176..0000000 --- a/src/home/media/__init__.pyi +++ /dev/null @@ -1,27 +0,0 @@ -from .types import ( - MediaNodeType as MediaNodeType -) -from .record_client import ( - SoundRecordClient as SoundRecordClient, - CameraRecordClient as CameraRecordClient, - RecordClient as RecordClient -) -from .node_server import ( - MediaNodeServer as MediaNodeServer -) -from .node_client import ( - SoundNodeClient as SoundNodeClient, - CameraNodeClient as CameraNodeClient, - MediaNodeClient as MediaNodeClient -) -from .storage import ( - SoundRecordStorage as SoundRecordStorage, - ESP32CameraRecordStorage as ESP32CameraRecordStorage, - SoundRecordFile as SoundRecordFile, - CameraRecordFile as CameraRecordFile, - RecordFile as RecordFile -) -from .record import ( - SoundRecorder as SoundRecorder, - CameraRecorder as CameraRecorder -) \ No newline at end of file diff --git a/src/home/media/node_client.py b/src/home/media/node_client.py deleted file mode 100644 index eb39898..0000000 --- a/src/home/media/node_client.py +++ /dev/null @@ -1,119 +0,0 @@ -import requests -import shutil -import logging - -from typing import Optional, Union, List -from .storage import RecordFile -from ..util import Addr -from ..api.errors import ApiResponseError - - -class MediaNodeClient: - def __init__(self, addr: Addr): - self.endpoint = f'http://{addr[0]}:{addr[1]}' - self.logger = logging.getLogger(self.__class__.__name__) - - def record(self, duration: int): - return self._call('record/', params={"duration": duration}) - - def record_info(self, record_id: int): - return self._call(f'record/info/{record_id}/') - - def record_forget(self, record_id: int): - return self._call(f'record/forget/{record_id}/') - - def record_download(self, record_id: int, output: str): - return self._call(f'record/download/{record_id}/', save_to=output) - - def storage_list(self, extended=False, as_objects=False) -> Union[List[str], List[dict], List[RecordFile]]: - r = self._call('storage/list/', params={'extended': int(extended)}) - files = r['files'] - if as_objects: - return self.record_list_from_serialized(files) - return files - - @staticmethod - def record_list_from_serialized(files: Union[List[str], List[dict]]): - new_files = [] - for f in files: - kwargs = {'remote': True} - if isinstance(f, dict): - name = f['filename'] - kwargs['remote_filesize'] = f['filesize'] - else: - name = f - item = RecordFile.create(name, **kwargs) - new_files.append(item) - return new_files - - def storage_delete(self, file_id: str): - return self._call('storage/delete/', params={'file_id': file_id}) - - def storage_download(self, file_id: str, output: str): - return self._call('storage/download/', params={'file_id': file_id}, save_to=output) - - def _call(self, - method: str, - params: dict = None, - save_to: Optional[str] = None): - kwargs = {} - if isinstance(params, dict): - kwargs['params'] = params - if save_to: - kwargs['stream'] = True - - url = f'{self.endpoint}/{method}' - self.logger.debug(f'calling {url}, kwargs: {kwargs}') - - r = requests.get(url, **kwargs) - if r.status_code != 200: - response = r.json() - raise ApiResponseError(status_code=r.status_code, - error_type=response['error'], - error_message=response['message'] or None, - error_stacktrace=response['stacktrace'] if 'stacktrace' in response else None) - - if save_to: - r.raise_for_status() - with open(save_to, 'wb') as f: - shutil.copyfileobj(r.raw, f) - return True - - return r.json()['response'] - - -class SoundNodeClient(MediaNodeClient): - def amixer_get_all(self): - return self._call('amixer/get-all/') - - def amixer_get(self, control: str): - return self._call(f'amixer/get/{control}/') - - def amixer_incr(self, control: str, step: Optional[int] = None): - params = {'step': step} if step is not None else None - return self._call(f'amixer/incr/{control}/', params=params) - - def amixer_decr(self, control: str, step: Optional[int] = None): - params = {'step': step} if step is not None else None - return self._call(f'amixer/decr/{control}/', params=params) - - def amixer_mute(self, control: str): - return self._call(f'amixer/mute/{control}/') - - def amixer_unmute(self, control: str): - return self._call(f'amixer/unmute/{control}/') - - def amixer_cap(self, control: str): - return self._call(f'amixer/cap/{control}/') - - def amixer_nocap(self, control: str): - return self._call(f'amixer/nocap/{control}/') - - -class CameraNodeClient(MediaNodeClient): - def capture(self, - save_to: str, - with_flash: bool = False): - return self._call('capture/', - {'with_flash': int(with_flash)}, - save_to=save_to) diff --git a/src/home/media/node_server.py b/src/home/media/node_server.py deleted file mode 100644 index 5d0803c..0000000 --- a/src/home/media/node_server.py +++ /dev/null @@ -1,86 +0,0 @@ -from .. import http -from .record import Recorder -from .types import RecordStatus -from .storage import RecordStorage - - -class MediaNodeServer(http.HTTPServer): - recorder: Recorder - storage: RecordStorage - - def __init__(self, - recorder: Recorder, - storage: RecordStorage, - *args, **kwargs): - super().__init__(*args, **kwargs) - - self.recorder = recorder - self.storage = storage - - self.get('/record/', self.do_record) - self.get('/record/info/{id}/', self.record_info) - self.get('/record/forget/{id}/', self.record_forget) - self.get('/record/download/{id}/', self.record_download) - - self.get('/storage/list/', self.storage_list) - self.get('/storage/delete/', self.storage_delete) - self.get('/storage/download/', self.storage_download) - - async def do_record(self, request: http.Request): - duration = int(request.query['duration']) - max = Recorder.get_max_record_time()*15 - if not 0 < duration <= max: - raise ValueError(f'invalid duration: max duration is {max}') - - record_id = self.recorder.record(duration) - return http.ok({'id': record_id}) - - async def record_info(self, request: http.Request): - record_id = int(request.match_info['id']) - info = self.recorder.get_info(record_id) - return http.ok(info.as_dict()) - - async def record_forget(self, request: http.Request): - record_id = int(request.match_info['id']) - - info = self.recorder.get_info(record_id) - assert info.status in (RecordStatus.FINISHED, RecordStatus.ERROR), f"can't forget: record status is {info.status}" - - self.recorder.forget(record_id) - return http.ok() - - async def record_download(self, request: http.Request): - record_id = int(request.match_info['id']) - - info = self.recorder.get_info(record_id) - assert info.status == RecordStatus.FINISHED, f"record status is {info.status}" - - return http.FileResponse(info.file.path) - - async def storage_list(self, request: http.Request): - extended = 'extended' in request.query and int(request.query['extended']) == 1 - - files = self.storage.getfiles(as_objects=extended) - if extended: - files = list(map(lambda file: file.__dict__(), files)) - - return http.ok({ - 'files': files - }) - - async def storage_delete(self, request: http.Request): - file_id = request.query['file_id'] - file = self.storage.find(file_id) - if not file: - raise ValueError(f'file {file} not found') - - self.storage.delete(file) - return http.ok() - - async def storage_download(self, request): - file_id = request.query['file_id'] - file = self.storage.find(file_id) - if not file: - raise ValueError(f'file {file} not found') - - return http.FileResponse(file.path) diff --git a/src/home/media/record.py b/src/home/media/record.py deleted file mode 100644 index cd7447a..0000000 --- a/src/home/media/record.py +++ /dev/null @@ -1,461 +0,0 @@ -import os -import threading -import logging -import time -import subprocess -import signal - -from typing import Optional, List, Dict -from ..util import find_child_processes, Addr -from ..config import config -from .storage import RecordFile, RecordStorage -from .types import RecordStatus -from ..camera.types import CameraType - - -_history_item_timeout = 7200 -_history_cleanup_freq = 3600 - - -class RecordHistoryItem: - id: int - request_time: float - start_time: float - stop_time: float - relations: List[int] - status: RecordStatus - error: Optional[Exception] - file: Optional[RecordFile] - creation_time: float - - def __init__(self, id): - self.id = id - self.request_time = 0 - self.start_time = 0 - self.stop_time = 0 - self.relations = [] - self.status = RecordStatus.WAITING - self.file = None - self.error = None - self.creation_time = time.time() - - def add_relation(self, related_id: int): - self.relations.append(related_id) - - def mark_started(self, start_time: float): - self.start_time = start_time - self.status = RecordStatus.RECORDING - - def mark_finished(self, end_time: float, file: RecordFile): - self.stop_time = end_time - self.file = file - self.status = RecordStatus.FINISHED - - def mark_failed(self, error: Exception): - self.status = RecordStatus.ERROR - self.error = error - - def as_dict(self) -> dict: - data = { - 'id': self.id, - 'request_time': self.request_time, - 'status': self.status.value, - 'relations': self.relations, - 'start_time': self.start_time, - 'stop_time': self.stop_time, - } - if self.error: - data['error'] = str(self.error) - if self.file: - data['file'] = self.file.__dict__() - return data - - -class RecordingNotFoundError(Exception): - pass - - -class RecordHistory: - history: Dict[int, RecordHistoryItem] - - def __init__(self): - self.history = {} - self.logger = logging.getLogger(self.__class__.__name__) - - def add(self, record_id: int): - self.logger.debug(f'add: record_id={record_id}') - - r = RecordHistoryItem(record_id) - r.request_time = time.time() - - self.history[record_id] = r - - def delete(self, record_id: int): - self.logger.debug(f'delete: record_id={record_id}') - del self.history[record_id] - - def cleanup(self): - del_ids = [] - for rid, item in self.history.items(): - if item.creation_time < time.time()-_history_item_timeout: - del_ids.append(rid) - for rid in del_ids: - self.delete(rid) - - def __getitem__(self, key): - if key not in self.history: - raise RecordingNotFoundError() - - return self.history[key] - - def __setitem__(self, key, value): - raise NotImplementedError('setting history item this way is prohibited') - - def __contains__(self, key): - return key in self.history - - -class Recording: - RECORDER_PROGRAM = None - - start_time: float - stop_time: float - duration: int - record_id: int - recorder_program_pid: Optional[int] - process: Optional[subprocess.Popen] - - g_record_id = 1 - - def __init__(self): - if self.RECORDER_PROGRAM is None: - raise RuntimeError('this is abstract class') - - self.start_time = 0 - self.stop_time = 0 - self.duration = 0 - self.process = None - self.recorder_program_pid = None - self.record_id = Recording.next_id() - self.logger = logging.getLogger(self.__class__.__name__) - - def is_started(self) -> bool: - return self.start_time > 0 and self.stop_time > 0 - - def is_waiting(self): - return self.duration > 0 - - def ask_for(self, duration) -> int: - overtime = 0 - orig_duration = duration - - if self.is_started(): - already_passed = time.time() - self.start_time - max_duration = Recorder.get_max_record_time() - already_passed - self.logger.debug(f'ask_for({orig_duration}): recording is in progress, already passed {already_passed}s, max_duration set to {max_duration}') - else: - max_duration = Recorder.get_max_record_time() - - if duration > max_duration: - overtime = duration - max_duration - duration = max_duration - - self.logger.debug(f'ask_for({orig_duration}): requested duration ({orig_duration}) is greater than max ({max_duration}), overtime is {overtime}') - - self.duration += duration - if self.is_started(): - til_end = self.stop_time - time.time() - if til_end < 0: - til_end = 0 - - _prev_stop_time = self.stop_time - _to_add = duration - til_end - if _to_add < 0: - _to_add = 0 - - self.stop_time += _to_add - self.logger.debug(f'ask_for({orig_duration}): adding {_to_add} to stop_time (before: {_prev_stop_time}, after: {self.stop_time})') - - return overtime - - def start(self, output: str): - assert self.start_time == 0 and self.stop_time == 0, "already started?!" - assert self.process is None, "self.process is not None, what the hell?" - - cur = time.time() - self.start_time = cur - self.stop_time = cur + self.duration - - cmd = self.get_command(output) - self.logger.debug(f'start: running `{cmd}`') - self.process = subprocess.Popen(cmd, shell=True, stdin=None, stdout=None, stderr=None, close_fds=True) - - sh_pid = self.process.pid - self.logger.debug(f'start: started, pid of shell is {sh_pid}') - - pid = self.find_recorder_program_pid(sh_pid) - if pid is not None: - self.recorder_program_pid = pid - self.logger.debug(f'start: pid of {self.RECORDER_PROGRAM} is {pid}') - - def get_command(self, output: str) -> str: - pass - - def stop(self): - if self.process: - if self.recorder_program_pid is None: - self.recorder_program_pid = self.find_recorder_program_pid(self.process.pid) - - if self.recorder_program_pid is not None: - os.kill(self.recorder_program_pid, signal.SIGINT) - timeout = config['node']['process_wait_timeout'] - - self.logger.debug(f'stop: sent SIGINT to {self.recorder_program_pid}. now waiting up to {timeout} seconds...') - try: - self.process.wait(timeout=timeout) - except subprocess.TimeoutExpired: - self.logger.warning(f'stop: wait({timeout}): timeout expired, killing it') - try: - os.kill(self.recorder_program_pid, signal.SIGKILL) - self.process.terminate() - except Exception as exc: - self.logger.exception(exc) - else: - self.logger.warning(f'stop: pid of {self.RECORDER_PROGRAM} is unknown, calling terminate()') - self.process.terminate() - - rc = self.process.returncode - self.logger.debug(f'stop: rc={rc}') - - self.process = None - self.recorder_program_pid = 0 - - self.duration = 0 - self.start_time = 0 - self.stop_time = 0 - - def find_recorder_program_pid(self, sh_pid: int): - try: - children = find_child_processes(sh_pid) - except OSError as exc: - self.logger.warning(f'failed to find child process of {sh_pid}: ' + str(exc)) - return None - - for child in children: - if self.RECORDER_PROGRAM in child.cmd: - return child.pid - - return None - - @staticmethod - def next_id() -> int: - cur_id = Recording.g_record_id - Recording.g_record_id += 1 - return cur_id - - def increment_id(self): - self.record_id = Recording.next_id() - - -class Recorder: - TEMP_NAME = None - - interrupted: bool - lock: threading.Lock - history_lock: threading.Lock - recording: Optional[Recording] - overtime: int - history: RecordHistory - next_history_cleanup_time: float - storage: RecordStorage - - def __init__(self, - storage: RecordStorage, - recording: Recording): - if self.TEMP_NAME is None: - raise RuntimeError('this is abstract class') - - self.storage = storage - self.recording = recording - self.interrupted = False - self.lock = threading.Lock() - self.history_lock = threading.Lock() - self.overtime = 0 - self.history = RecordHistory() - self.next_history_cleanup_time = 0 - self.logger = logging.getLogger(self.__class__.__name__) - - def start_thread(self): - t = threading.Thread(target=self.loop) - t.daemon = True - t.start() - - def loop(self) -> None: - tempname = os.path.join(self.storage.root, self.TEMP_NAME) - - while not self.interrupted: - cur = time.time() - stopped = False - cur_record_id = None - - if self.next_history_cleanup_time == 0: - self.next_history_cleanup_time = time.time() + _history_cleanup_freq - elif self.next_history_cleanup_time <= time.time(): - self.logger.debug('loop: calling history.cleanup()') - try: - self.history.cleanup() - except Exception as e: - self.logger.error('loop: error while history.cleanup(): ' + str(e)) - self.next_history_cleanup_time = time.time() + _history_cleanup_freq - - with self.lock: - cur_record_id = self.recording.record_id - # self.logger.debug(f'cur_record_id={cur_record_id}') - - if not self.recording.is_started(): - if self.recording.is_waiting(): - try: - if os.path.exists(tempname): - self.logger.warning(f'loop: going to start new recording, but {tempname} still exists, unlinking..') - try: - os.unlink(tempname) - except OSError as e: - self.logger.exception(e) - self.recording.start(tempname) - with self.history_lock: - self.history[cur_record_id].mark_started(self.recording.start_time) - except Exception as exc: - self.logger.exception(exc) - - # there should not be any errors, but still.. - try: - self.recording.stop() - except Exception as exc: - self.logger.exception(exc) - - with self.history_lock: - self.history[cur_record_id].mark_failed(exc) - - self.logger.debug(f'loop: start exc path: calling increment_id()') - self.recording.increment_id() - else: - if cur >= self.recording.stop_time: - try: - start_time = self.recording.start_time - stop_time = self.recording.stop_time - self.recording.stop() - - saved_name = self.storage.save(tempname, - record_id=cur_record_id, - start_time=int(start_time), - stop_time=int(stop_time)) - - with self.history_lock: - self.history[cur_record_id].mark_finished(stop_time, saved_name) - except Exception as exc: - self.logger.exception(exc) - with self.history_lock: - self.history[cur_record_id].mark_failed(exc) - finally: - self.logger.debug(f'loop: stop exc final path: calling increment_id()') - self.recording.increment_id() - - stopped = True - - if stopped and self.overtime > 0: - self.logger.info(f'recording {cur_record_id} is stopped, but we\'ve got overtime ({self.overtime})') - _overtime = self.overtime - self.overtime = 0 - - related_id = self.record(_overtime) - self.logger.info(f'enqueued another record with id {related_id}') - - if cur_record_id is not None: - with self.history_lock: - self.history[cur_record_id].add_relation(related_id) - - time.sleep(0.2) - - def record(self, duration: int) -> int: - self.logger.debug(f'record: duration={duration}') - with self.lock: - overtime = self.recording.ask_for(duration) - self.logger.debug(f'overtime={overtime}') - - if overtime > self.overtime: - self.overtime = overtime - - if not self.recording.is_started(): - with self.history_lock: - self.history.add(self.recording.record_id) - - return self.recording.record_id - - def stop(self): - self.interrupted = True - - def get_info(self, record_id: int) -> RecordHistoryItem: - with self.history_lock: - return self.history[record_id] - - def forget(self, record_id: int): - with self.history_lock: - self.logger.info(f'forget: removing record {record_id} from history') - self.history.delete(record_id) - - @staticmethod - def get_max_record_time() -> int: - return config['node']['record_max_time'] - - -class SoundRecorder(Recorder): - TEMP_NAME = 'temp.mp3' - - def __init__(self, *args, **kwargs): - super().__init__(recording=SoundRecording(), - *args, **kwargs) - - -class CameraRecorder(Recorder): - TEMP_NAME = 'temp.mp4' - - def __init__(self, - camera_type: CameraType, - *args, **kwargs): - if camera_type == CameraType.ESP32: - recording = ESP32CameraRecording(stream_addr=kwargs['stream_addr']) - del kwargs['stream_addr'] - else: - raise RuntimeError(f'unsupported camera type {camera_type}') - - super().__init__(recording=recording, - *args, **kwargs) - - -class SoundRecording(Recording): - RECORDER_PROGRAM = 'arecord' - - def get_command(self, output: str) -> str: - arecord = config['arecord']['bin'] - lame = config['lame']['bin'] - b = config['lame']['bitrate'] - - return f'{arecord} -f S16 -r 44100 -t raw 2>/dev/null | {lame} -r -s 44.1 -b {b} -m m - {output} >/dev/null 2>/dev/null' - - -class ESP32CameraRecording(Recording): - RECORDER_PROGRAM = 'esp32_capture.py' - - stream_addr: Addr - - def __init__(self, stream_addr: Addr): - super().__init__() - self.stream_addr = stream_addr - - def get_command(self, output: str) -> str: - bin = config['esp32_capture']['bin'] - return f'{bin} --addr {self.stream_addr[0]}:{self.stream_addr[1]} --output-directory {output} >/dev/null 2>/dev/null' - - def start(self, output: str): - output = os.path.dirname(output) - return super().start(output) \ No newline at end of file diff --git a/src/home/media/record_client.py b/src/home/media/record_client.py deleted file mode 100644 index 322495c..0000000 --- a/src/home/media/record_client.py +++ /dev/null @@ -1,166 +0,0 @@ -import time -import logging -import threading -import os.path - -from tempfile import gettempdir -from .record import RecordStatus -from .node_client import SoundNodeClient, MediaNodeClient, CameraNodeClient -from ..util import Addr -from typing import Optional, Callable, Dict - - -class RecordClient: - DOWNLOAD_EXTENSION = None - - interrupted: bool - logger: logging.Logger - clients: Dict[str, MediaNodeClient] - awaiting: Dict[str, Dict[int, Optional[dict]]] - error_handler: Optional[Callable] - finished_handler: Optional[Callable] - download_on_finish: bool - - def __init__(self, - nodes: Dict[str, Addr], - error_handler: Optional[Callable] = None, - finished_handler: Optional[Callable] = None, - download_on_finish=False): - if self.DOWNLOAD_EXTENSION is None: - raise RuntimeError('this is abstract class') - - self.interrupted = False - self.logger = logging.getLogger(self.__class__.__name__) - self.clients = {} - self.awaiting = {} - - self.download_on_finish = download_on_finish - self.error_handler = error_handler - self.finished_handler = finished_handler - - self.awaiting_lock = threading.Lock() - - self.make_clients(nodes) - - try: - t = threading.Thread(target=self.loop) - t.daemon = True - t.start() - except (KeyboardInterrupt, SystemExit) as exc: - self.stop() - self.logger.exception(exc) - - def make_clients(self, nodes: Dict[str, Addr]): - pass - - def stop(self): - self.interrupted = True - - def loop(self): - while not self.interrupted: - for node in self.awaiting.keys(): - with self.awaiting_lock: - record_ids = list(self.awaiting[node].keys()) - if not record_ids: - continue - - self.logger.debug(f'loop: node `{node}` awaiting list: {record_ids}') - - cl = self.getclient(node) - del_ids = [] - for rid in record_ids: - info = cl.record_info(rid) - - if info['relations']: - for relid in info['relations']: - self.wait_for_record(node, relid, self.awaiting[node][rid], is_relative=True) - - status = RecordStatus(info['status']) - if status in (RecordStatus.FINISHED, RecordStatus.ERROR): - if status == RecordStatus.FINISHED: - if self.download_on_finish: - local_fn = self.download(node, rid, info['file']['fileid']) - else: - local_fn = None - self._report_finished(info, local_fn, self.awaiting[node][rid]) - else: - self._report_error(info, self.awaiting[node][rid]) - del_ids.append(rid) - self.logger.debug(f'record {rid}: status {status}') - - if del_ids: - self.logger.debug(f'deleting {del_ids} from {node}\'s awaiting list') - with self.awaiting_lock: - for del_id in del_ids: - del self.awaiting[node][del_id] - - time.sleep(5) - - self.logger.info('loop ended') - - def getclient(self, node: str): - return self.clients[node] - - def record(self, - node: str, - duration: int, - userdata: Optional[dict] = None) -> int: - self.logger.debug(f'record: node={node}, duration={duration}, userdata={userdata}') - - cl = self.getclient(node) - record_id = cl.record(duration)['id'] - self.logger.debug(f'record: request sent, record_id={record_id}') - - self.wait_for_record(node, record_id, userdata) - return record_id - - def wait_for_record(self, - node: str, - record_id: int, - userdata: Optional[dict] = None, - is_relative=False): - with self.awaiting_lock: - if record_id not in self.awaiting[node]: - msg = f'wait_for_record: adding {record_id} to {node}' - if is_relative: - msg += ' (by relation)' - self.logger.debug(msg) - - self.awaiting[node][record_id] = userdata - - def download(self, node: str, record_id: int, fileid: str): - dst = os.path.join(gettempdir(), f'{node}_{fileid}.{self.DOWNLOAD_EXTENSION}') - cl = self.getclient(node) - cl.record_download(record_id, dst) - return dst - - def forget(self, node: str, rid: int): - self.getclient(node).record_forget(rid) - - def _report_finished(self, *args): - if self.finished_handler: - self.finished_handler(*args) - - def _report_error(self, *args): - if self.error_handler: - self.error_handler(*args) - - -class SoundRecordClient(RecordClient): - DOWNLOAD_EXTENSION = 'mp3' - # clients: Dict[str, SoundNodeClient] - - def make_clients(self, nodes: Dict[str, Addr]): - for node, addr in nodes.items(): - self.clients[node] = SoundNodeClient(addr) - self.awaiting[node] = {} - - -class CameraRecordClient(RecordClient): - DOWNLOAD_EXTENSION = 'mp4' - # clients: Dict[str, CameraNodeClient] - - def make_clients(self, nodes: Dict[str, Addr]): - for node, addr in nodes.items(): - self.clients[node] = CameraNodeClient(addr) - self.awaiting[node] = {} \ No newline at end of file diff --git a/src/home/media/storage.py b/src/home/media/storage.py deleted file mode 100644 index dd74ff8..0000000 --- a/src/home/media/storage.py +++ /dev/null @@ -1,210 +0,0 @@ -import os -import re -import shutil -import logging - -from typing import Optional, Union, List -from datetime import datetime -from ..util import strgen - -logger = logging.getLogger(__name__) - - -# record file -# ----------- - -class RecordFile: - EXTENSION = None - - start_time: Optional[datetime] - stop_time: Optional[datetime] - record_id: Optional[int] - name: str - file_id: Optional[str] - remote: bool - remote_filesize: int - storage_root: str - - human_date_dmt = '%d.%m.%y' - human_time_fmt = '%H:%M:%S' - - @staticmethod - def create(filename: str, *args, **kwargs): - if filename.endswith(f'.{SoundRecordFile.EXTENSION}'): - return SoundRecordFile(filename, *args, **kwargs) - elif filename.endswith(f'.{CameraRecordFile.EXTENSION}'): - return CameraRecordFile(filename, *args, **kwargs) - else: - raise RuntimeError(f'unsupported file extension: {filename}') - - def __init__(self, filename: str, remote=False, remote_filesize=None, storage_root='/'): - if self.EXTENSION is None: - raise RuntimeError('this is abstract class') - - self.name = filename - self.storage_root = storage_root - - self.remote = remote - self.remote_filesize = remote_filesize - - m = re.match(r'^(\d{6}-\d{6})_(\d{6}-\d{6})_id(\d+)(_\w+)?\.'+self.EXTENSION+'$', filename) - if m: - self.start_time = datetime.strptime(m.group(1), RecordStorage.time_fmt) - self.stop_time = datetime.strptime(m.group(2), RecordStorage.time_fmt) - self.record_id = int(m.group(3)) - self.file_id = (m.group(1) + '_' + m.group(2)).replace('-', '_') - else: - logger.warning(f'unexpected filename: {filename}') - self.start_time = None - self.stop_time = None - self.record_id = None - self.file_id = None - - @property - def path(self): - if self.remote: - return RuntimeError('remote recording, can\'t get real path') - - return os.path.realpath(os.path.join( - self.storage_root, self.name - )) - - @property - def start_humantime(self) -> str: - if self.start_time is None: - return '?' - fmt = f'{RecordFile.human_date_dmt} {RecordFile.human_time_fmt}' - return self.start_time.strftime(fmt) - - @property - def stop_humantime(self) -> str: - if self.stop_time is None: - return '?' - fmt = RecordFile.human_time_fmt - if self.start_time.date() != self.stop_time.date(): - fmt = f'{RecordFile.human_date_dmt} {fmt}' - return self.stop_time.strftime(fmt) - - @property - def start_unixtime(self) -> int: - if self.start_time is None: - return 0 - return int(self.start_time.timestamp()) - - @property - def stop_unixtime(self) -> int: - if self.stop_time is None: - return 0 - return int(self.stop_time.timestamp()) - - @property - def filesize(self): - if self.remote: - if self.remote_filesize is None: - raise RuntimeError('file is remote and remote_filesize is not set') - return self.remote_filesize - return os.path.getsize(self.path) - - def __dict__(self) -> dict: - return { - 'start_unixtime': self.start_unixtime, - 'stop_unixtime': self.stop_unixtime, - 'filename': self.name, - 'filesize': self.filesize, - 'fileid': self.file_id, - 'record_id': self.record_id or 0, - } - - -class PseudoRecordFile(RecordFile): - EXTENSION = 'null' - - def __init__(self): - super().__init__('pseudo.null') - - @property - def filesize(self): - return 0 - - -class SoundRecordFile(RecordFile): - EXTENSION = 'mp3' - - -class CameraRecordFile(RecordFile): - EXTENSION = 'mp4' - - -# record storage -# -------------- - -class RecordStorage: - EXTENSION = None - - time_fmt = '%d%m%y-%H%M%S' - - def __init__(self, root: str): - if self.EXTENSION is None: - raise RuntimeError('this is abstract class') - - self.root = root - - def getfiles(self, as_objects=False) -> Union[List[str], List[RecordFile]]: - files = [] - for name in os.listdir(self.root): - path = os.path.join(self.root, name) - if os.path.isfile(path) and name.endswith(f'.{self.EXTENSION}'): - files.append(name if not as_objects else RecordFile.create(name, storage_root=self.root)) - return files - - def find(self, file_id: str) -> Optional[RecordFile]: - for name in os.listdir(self.root): - if os.path.isfile(os.path.join(self.root, name)) and name.endswith(f'.{self.EXTENSION}'): - item = RecordFile.create(name, storage_root=self.root) - if item.file_id == file_id: - return item - return None - - def purge(self): - files = self.getfiles() - if files: - logger = logging.getLogger(self.__name__) - for f in files: - try: - path = os.path.join(self.root, f) - logger.debug(f'purge: deleting {path}') - os.unlink(path) - except OSError as exc: - logger.exception(exc) - - def delete(self, file: RecordFile): - os.unlink(file.path) - - def save(self, - fn: str, - record_id: int, - start_time: int, - stop_time: int) -> RecordFile: - - start_time_s = datetime.fromtimestamp(start_time).strftime(self.time_fmt) - stop_time_s = datetime.fromtimestamp(stop_time).strftime(self.time_fmt) - - dst_fn = f'{start_time_s}_{stop_time_s}_id{record_id}' - if os.path.exists(os.path.join(self.root, dst_fn)): - dst_fn += strgen(4) - dst_fn += f'.{self.EXTENSION}' - dst_path = os.path.join(self.root, dst_fn) - - shutil.move(fn, dst_path) - return RecordFile.create(dst_fn, storage_root=self.root) - - -class SoundRecordStorage(RecordStorage): - EXTENSION = 'mp3' - - -class ESP32CameraRecordStorage(RecordStorage): - EXTENSION = 'jpg' # not used anyway - - def save(self, *args, **kwargs): - return PseudoRecordFile() \ No newline at end of file diff --git a/src/home/media/types.py b/src/home/media/types.py deleted file mode 100644 index acbc291..0000000 --- a/src/home/media/types.py +++ /dev/null @@ -1,13 +0,0 @@ -from enum import Enum, auto - - -class MediaNodeType(Enum): - SOUND = auto() - CAMERA = auto() - - -class RecordStatus(Enum): - WAITING = auto() - RECORDING = auto() - FINISHED = auto() - ERROR = auto() diff --git a/src/home/mqtt/__init__.py b/src/home/mqtt/__init__.py deleted file mode 100644 index 707d59c..0000000 --- a/src/home/mqtt/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from ._mqtt import Mqtt -from ._node import MqttNode -from ._module import MqttModule -from ._wrapper import MqttWrapper -from ._config import MqttConfig, MqttCreds, MqttNodesConfig -from ._payload import MqttPayload, MqttPayloadCustomField -from ._util import get_modules as get_mqtt_modules \ No newline at end of file diff --git a/src/home/mqtt/_config.py b/src/home/mqtt/_config.py deleted file mode 100644 index 9ba9443..0000000 --- a/src/home/mqtt/_config.py +++ /dev/null @@ -1,165 +0,0 @@ -from ..config import ConfigUnit -from typing import Optional, Union -from ..util import Addr -from collections import namedtuple - -MqttCreds = namedtuple('MqttCreds', 'username, password') - - -class MqttConfig(ConfigUnit): - NAME = 'mqtt' - - @classmethod - def schema(cls) -> Optional[dict]: - addr_schema = { - 'type': 'dict', - 'required': True, - 'schema': { - 'host': {'type': 'string', 'required': True}, - 'port': {'type': 'integer', 'required': True} - } - } - - schema = {} - for key in ('local', 'remote'): - schema[f'{key}_addr'] = addr_schema - - schema['creds'] = { - 'type': 'dict', - 'required': True, - 'keysrules': {'type': 'string'}, - 'valuesrules': { - 'type': 'dict', - 'schema': { - 'username': {'type': 'string', 'required': True}, - 'password': {'type': 'string', 'required': True}, - } - } - } - - for key in ('client', 'server'): - schema[f'default_{key}_creds'] = {'type': 'string', 'required': True} - - return schema - - def remote_addr(self) -> Addr: - return Addr(host=self['remote_addr']['host'], - port=self['remote_addr']['port']) - - def local_addr(self) -> Addr: - return Addr(host=self['local_addr']['host'], - port=self['local_addr']['port']) - - def creds_by_name(self, name: str) -> MqttCreds: - return MqttCreds(username=self['creds'][name]['username'], - password=self['creds'][name]['password']) - - def creds(self) -> MqttCreds: - return self.creds_by_name(self['default_client_creds']) - - def server_creds(self) -> MqttCreds: - return self.creds_by_name(self['default_server_creds']) - - -class MqttNodesConfig(ConfigUnit): - NAME = 'mqtt_nodes' - - @classmethod - def schema(cls) -> Optional[dict]: - return { - 'common': { - 'type': 'dict', - 'schema': { - 'temphum': { - 'type': 'dict', - 'schema': { - 'interval': {'type': 'integer'} - } - }, - 'password': {'type': 'string'} - } - }, - 'nodes': { - 'type': 'dict', - 'required': True, - 'keysrules': {'type': 'string'}, - 'valuesrules': { - 'type': 'dict', - 'schema': { - 'type': {'type': 'string', 'required': True, 'allowed': ['esp8266', 'linux', 'none'],}, - 'board': {'type': 'string', 'allowed': ['nodemcu', 'd1_mini_lite', 'esp12e']}, - 'temphum': { - 'type': 'dict', - 'schema': { - 'module': {'type': 'string', 'required': True, 'allowed': ['si7021', 'dht12']}, - 'interval': {'type': 'integer'}, - 'i2c_bus': {'type': 'integer'}, - 'tcpserver': { - 'type': 'dict', - 'schema': { - 'port': {'type': 'integer', 'required': True} - } - } - } - }, - 'relay': { - 'type': 'dict', - 'schema': { - 'device_type': {'type': 'string', 'allowed': ['lamp', 'pump', 'solenoid'], 'required': True}, - 'legacy_topics': {'type': 'boolean'} - } - }, - 'password': {'type': 'string'} - } - } - } - } - - @staticmethod - def custom_validator(data): - for name, node in data['nodes'].items(): - if 'temphum' in node: - if node['type'] == 'linux': - if 'i2c_bus' not in node['temphum']: - raise KeyError(f'nodes.{name}.temphum: i2c_bus is missing but required for type=linux') - if node['type'] in ('esp8266',) and 'board' not in node: - raise KeyError(f'nodes.{name}: board is missing but required for type={node["type"]}') - - def get_node(self, name: str) -> dict: - node = self['nodes'][name] - if node['type'] == 'none': - return node - - try: - if 'password' not in node: - node['password'] = self['common']['password'] - except KeyError: - pass - - try: - if 'temphum' in node: - for ckey, cval in self['common']['temphum'].items(): - if ckey not in node['temphum']: - node['temphum'][ckey] = cval - except KeyError: - pass - - return node - - def get_nodes(self, - filters: Optional[Union[list[str], tuple[str]]] = None, - only_names=False) -> Union[dict, list[str]]: - if filters: - for f in filters: - if f not in ('temphum', 'relay'): - raise ValueError(f'{self.__class__.__name__}::get_node(): invalid filter {f}') - reslist = [] - resdict = {} - for name in self['nodes'].keys(): - node = self.get_node(name) - if (not filters) or ('temphum' in filters and 'temphum' in node) or ('relay' in filters and 'relay' in node): - if only_names: - reslist.append(name) - else: - resdict[name] = node - return reslist if only_names else resdict diff --git a/src/home/mqtt/_module.py b/src/home/mqtt/_module.py deleted file mode 100644 index 80f27bb..0000000 --- a/src/home/mqtt/_module.py +++ /dev/null @@ -1,70 +0,0 @@ -from __future__ import annotations - -import abc -import logging -import threading - -from time import sleep -from ..util import next_tick_gen - -from typing import TYPE_CHECKING, Optional -if TYPE_CHECKING: - from ._node import MqttNode - from ._payload import MqttPayload - - -class MqttModule(abc.ABC): - _tick_interval: int - _initialized: bool - _connected: bool - _ticker: Optional[threading.Thread] - _mqtt_node_ref: Optional[MqttNode] - - def __init__(self, tick_interval=0): - self._tick_interval = tick_interval - self._initialized = False - self._ticker = None - self._logger = logging.getLogger(self.__class__.__name__) - self._connected = False - self._mqtt_node_ref = None - - def on_connect(self, mqtt: MqttNode): - self._connected = True - self._mqtt_node_ref = mqtt - if self._tick_interval: - self._start_ticker() - - def on_disconnect(self, mqtt: MqttNode): - self._connected = False - self._mqtt_node_ref = None - - def is_initialized(self): - return self._initialized - - def set_initialized(self): - self._initialized = True - - def unset_initialized(self): - self._initialized = False - - def tick(self): - pass - - def _tick(self): - g = next_tick_gen(self._tick_interval) - while self._connected: - sleep(next(g)) - if not self._connected: - break - self.tick() - - def _start_ticker(self): - if not self._ticker or not self._ticker.is_alive(): - name_part = f'{self._mqtt_node_ref.id}/' if self._mqtt_node_ref else '' - self._ticker = None - self._ticker = threading.Thread(target=self._tick, - name=f'mqtt:{self.__class__.__name__}/{name_part}ticker') - self._ticker.start() - - def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: - pass diff --git a/src/home/mqtt/_mqtt.py b/src/home/mqtt/_mqtt.py deleted file mode 100644 index 746ae2e..0000000 --- a/src/home/mqtt/_mqtt.py +++ /dev/null @@ -1,86 +0,0 @@ -import os.path -import paho.mqtt.client as mqtt -import ssl -import logging - -from ._config import MqttCreds, MqttConfig -from typing import Optional - - -class Mqtt: - _connected: bool - _is_server: bool - _mqtt_config: MqttConfig - - def __init__(self, - clean_session=True, - client_id='', - creds: Optional[MqttCreds] = None, - is_server=False): - if not client_id: - raise ValueError('client_id must not be empty') - - self._client = mqtt.Client(client_id=client_id, - protocol=mqtt.MQTTv311, - clean_session=clean_session) - self._client.on_connect = self.on_connect - self._client.on_disconnect = self.on_disconnect - self._client.on_message = self.on_message - self._client.on_log = self.on_log - self._client.on_publish = self.on_publish - self._loop_started = False - self._connected = False - self._is_server = is_server - self._mqtt_config = MqttConfig() - self._logger = logging.getLogger(self.__class__.__name__) - - if not creds: - creds = self._mqtt_config.creds() if not is_server else self._mqtt_config.server_creds() - - self._client.username_pw_set(creds.username, creds.password) - - def _configure_tls(self): - ca_certs = os.path.realpath(os.path.join( - os.path.dirname(os.path.realpath(__file__)), - '..', - '..', - '..', - 'assets', - 'mqtt_ca.crt' - )) - self._client.tls_set(ca_certs=ca_certs, - cert_reqs=ssl.CERT_REQUIRED, - tls_version=ssl.PROTOCOL_TLSv1_2) - - def connect_and_loop(self, loop_forever=True): - self._configure_tls() - addr = self._mqtt_config.local_addr() if self._is_server else self._mqtt_config.remote_addr() - self._client.connect(addr.host, addr.port, 60) - if loop_forever: - self._client.loop_forever() - else: - self._client.loop_start() - self._loop_started = True - - def disconnect(self): - self._client.disconnect() - self._client.loop_write() - self._client.loop_stop() - - def on_connect(self, client: mqtt.Client, userdata, flags, rc): - self._logger.info("Connected with result code " + str(rc)) - self._connected = True - - def on_disconnect(self, client: mqtt.Client, userdata, rc): - self._logger.info("Disconnected with result code " + str(rc)) - self._connected = False - - def on_log(self, client: mqtt.Client, userdata, level, buf): - level = mqtt.LOGGING_LEVEL[level] if level in mqtt.LOGGING_LEVEL else logging.INFO - self._logger.log(level, f'MQTT: {buf}') - - def on_message(self, client: mqtt.Client, userdata, msg): - self._logger.debug(msg.topic + ": " + str(msg.payload)) - - def on_publish(self, client: mqtt.Client, userdata, mid): - self._logger.debug(f'publish done, mid={mid}') diff --git a/src/home/mqtt/_node.py b/src/home/mqtt/_node.py deleted file mode 100644 index 4e259a4..0000000 --- a/src/home/mqtt/_node.py +++ /dev/null @@ -1,92 +0,0 @@ -import logging -import importlib - -from typing import List, TYPE_CHECKING, Optional -from ._payload import MqttPayload -from ._module import MqttModule -if TYPE_CHECKING: - from ._wrapper import MqttWrapper -else: - MqttWrapper = None - - -class MqttNode: - _modules: List[MqttModule] - _module_subscriptions: dict[str, MqttModule] - _node_id: str - _node_secret: str - _payload_callbacks: list[callable] - _wrapper: Optional[MqttWrapper] - - def __init__(self, - node_id: str, - node_secret: Optional[str] = None): - self._modules = [] - self._module_subscriptions = {} - self._node_id = node_id - self._node_secret = node_secret - self._payload_callbacks = [] - self._logger = logging.getLogger(self.__class__.__name__) - self._wrapper = None - - def on_connect(self, wrapper: MqttWrapper): - self._wrapper = wrapper - for module in self._modules: - if not module.is_initialized(): - module.on_connect(self) - module.set_initialized() - - def on_disconnect(self): - self._wrapper = None - for module in self._modules: - module.unset_initialized() - - def on_message(self, topic, payload): - if topic in self._module_subscriptions: - payload = self._module_subscriptions[topic].handle_payload(self, topic, payload) - if isinstance(payload, MqttPayload): - for f in self._payload_callbacks: - f(self, payload) - - def load_module(self, module_name: str, *args, **kwargs) -> MqttModule: - module = importlib.import_module(f'..module.{module_name}', __name__) - if not hasattr(module, 'MODULE_NAME'): - raise RuntimeError(f'MODULE_NAME not found in module {module}') - cl = getattr(module, getattr(module, 'MODULE_NAME')) - instance = cl(*args, **kwargs) - self.add_module(instance) - return instance - - def add_module(self, module: MqttModule): - self._modules.append(module) - if self._wrapper and self._wrapper._connected: - module.on_connect(self) - module.set_initialized() - - def subscribe_module(self, topic: str, module: MqttModule, qos: int = 1): - if not self._wrapper or not self._wrapper._connected: - raise RuntimeError('not connected') - - self._module_subscriptions[topic] = module - self._wrapper.subscribe(self.id, topic, qos) - - def publish(self, - topic: str, - payload: bytes, - qos: int = 1): - self._wrapper.publish(self.id, topic, payload, qos) - - def add_payload_callback(self, callback: callable): - self._payload_callbacks.append(callback) - - @property - def id(self) -> str: - return self._node_id - - @property - def secret(self) -> str: - return self._node_secret - - @secret.setter - def secret(self, secret: str) -> None: - self._node_secret = secret diff --git a/src/home/mqtt/_payload.py b/src/home/mqtt/_payload.py deleted file mode 100644 index 58eeae3..0000000 --- a/src/home/mqtt/_payload.py +++ /dev/null @@ -1,145 +0,0 @@ -import struct -import abc -import re - -from typing import Optional, Tuple - - -def pldstr(self) -> str: - attrs = [] - for field in self.__class__.__annotations__: - if hasattr(self, field): - attr = getattr(self, field) - attrs.append(f'{field}={attr}') - if attrs: - attrs_s = ' ' - attrs_s += ', '.join(attrs) - else: - attrs_s = '' - return f'<%s{attrs_s}>' % (self.__class__.__name__,) - - -class MqttPayload(abc.ABC): - FORMAT = '' - PACKER = {} - UNPACKER = {} - - def __init__(self, **kwargs): - for field in self.__class__.__annotations__: - setattr(self, field, kwargs[field]) - - def pack(self): - args = [] - bf_number = -1 - bf_arg = 0 - bf_progress = 0 - - for field, field_type in self.__class__.__annotations__.items(): - bfp = _bit_field_params(field_type) - if bfp: - n, s, b = bfp - if n != bf_number: - if bf_number != -1: - args.append(bf_arg) - bf_number = n - bf_progress = 0 - bf_arg = 0 - bf_arg |= (getattr(self, field) & (2 ** b - 1)) << bf_progress - bf_progress += b - - else: - if bf_number != -1: - args.append(bf_arg) - bf_number = -1 - bf_progress = 0 - bf_arg = 0 - - args.append(self._pack_field(field)) - - if bf_number != -1: - args.append(bf_arg) - - return struct.pack(self.FORMAT, *args) - - @classmethod - def unpack(cls, buf: bytes): - data = struct.unpack(cls.FORMAT, buf) - kwargs = {} - i = 0 - bf_number = -1 - bf_progress = 0 - - for field, field_type in cls.__annotations__.items(): - bfp = _bit_field_params(field_type) - if bfp: - n, s, b = bfp - if n != bf_number: - bf_number = n - bf_progress = 0 - kwargs[field] = (data[i] >> bf_progress) & (2 ** b - 1) - bf_progress += b - continue # don't increment i - - if bf_number != -1: - bf_number = -1 - i += 1 - - if issubclass(field_type, MqttPayloadCustomField): - kwargs[field] = field_type.unpack(data[i]) - else: - kwargs[field] = cls._unpack_field(field, data[i]) - i += 1 - - return cls(**kwargs) - - def _pack_field(self, name): - val = getattr(self, name) - if self.PACKER and name in self.PACKER: - return self.PACKER[name](val) - else: - return val - - @classmethod - def _unpack_field(cls, name, val): - if isinstance(val, MqttPayloadCustomField): - return - if cls.UNPACKER and name in cls.UNPACKER: - return cls.UNPACKER[name](val) - else: - return val - - def __str__(self): - return pldstr(self) - - -class MqttPayloadCustomField(abc.ABC): - def __init__(self, **kwargs): - for field in self.__class__.__annotations__: - setattr(self, field, kwargs[field]) - - @abc.abstractmethod - def __index__(self): - pass - - @classmethod - @abc.abstractmethod - def unpack(cls, *args, **kwargs): - pass - - def __str__(self): - return pldstr(self) - - -def bit_field(seq_no: int, total_bits: int, bits: int): - return type(f'MQTTPayloadBitField_{seq_no}_{total_bits}_{bits}', (object,), { - 'seq_no': seq_no, - 'total_bits': total_bits, - 'bits': bits - }) - - -def _bit_field_params(cl) -> Optional[Tuple[int, ...]]: - match = re.match(r'MQTTPayloadBitField_(\d+)_(\d+)_(\d)$', cl.__name__) - if match is not None: - return tuple([int(match.group(i)) for i in range(1, 4)]) - return None \ No newline at end of file diff --git a/src/home/mqtt/_util.py b/src/home/mqtt/_util.py deleted file mode 100644 index 390d463..0000000 --- a/src/home/mqtt/_util.py +++ /dev/null @@ -1,15 +0,0 @@ -import os -import re - -from typing import List - - -def get_modules() -> List[str]: - modules = [] - modules_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'module') - for name in os.listdir(modules_dir): - if os.path.isdir(os.path.join(modules_dir, name)): - continue - name = re.sub(r'\.py$', '', name) - modules.append(name) - return modules diff --git a/src/home/mqtt/_wrapper.py b/src/home/mqtt/_wrapper.py deleted file mode 100644 index 3c2774c..0000000 --- a/src/home/mqtt/_wrapper.py +++ /dev/null @@ -1,60 +0,0 @@ -import paho.mqtt.client as mqtt - -from ._mqtt import Mqtt -from ._node import MqttNode -from ..util import strgen - - -class MqttWrapper(Mqtt): - _nodes: list[MqttNode] - - def __init__(self, - client_id: str, - topic_prefix='hk', - randomize_client_id=False, - clean_session=True): - if randomize_client_id: - client_id += '_'+strgen(6) - super().__init__(clean_session=clean_session, - client_id=client_id) - self._nodes = [] - self._topic_prefix = topic_prefix - - def on_connect(self, client: mqtt.Client, userdata, flags, rc): - super().on_connect(client, userdata, flags, rc) - for node in self._nodes: - node.on_connect(self) - - def on_disconnect(self, client: mqtt.Client, userdata, rc): - super().on_disconnect(client, userdata, rc) - for node in self._nodes: - node.on_disconnect() - - def on_message(self, client: mqtt.Client, userdata, msg): - try: - topic = msg.topic - topic_node = topic[len(self._topic_prefix)+1:topic.find('/', len(self._topic_prefix)+1)] - for node in self._nodes: - if node.id in ('+', topic_node): - node.on_message(topic[len(f'{self._topic_prefix}/{node.id}/'):], msg.payload) - except Exception as e: - self._logger.exception(str(e)) - - def add_node(self, node: MqttNode): - self._nodes.append(node) - if self._connected: - node.on_connect(self) - - def subscribe(self, - node_id: str, - topic: str, - qos: int): - self._client.subscribe(f'{self._topic_prefix}/{node_id}/{topic}', qos) - - def publish(self, - node_id: str, - topic: str, - payload: bytes, - qos: int): - self._client.publish(f'{self._topic_prefix}/{node_id}/{topic}', payload, qos) - self._client.loop_write() diff --git a/src/home/mqtt/module/diagnostics.py b/src/home/mqtt/module/diagnostics.py deleted file mode 100644 index 5db5e99..0000000 --- a/src/home/mqtt/module/diagnostics.py +++ /dev/null @@ -1,64 +0,0 @@ -from .._payload import MqttPayload, MqttPayloadCustomField -from .._node import MqttNode, MqttModule -from typing import Optional - -MODULE_NAME = 'MqttDiagnosticsModule' - - -class DiagnosticsFlags(MqttPayloadCustomField): - state: bool - config_changed_value_present: bool - config_changed: bool - - @staticmethod - def unpack(flags: int): - # _logger.debug(f'StatFlags.unpack: flags={flags}') - state = flags & 0x1 - ccvp = (flags >> 1) & 0x1 - cc = (flags >> 2) & 0x1 - # _logger.debug(f'StatFlags.unpack: state={state}') - return DiagnosticsFlags(state=(state == 1), - config_changed_value_present=(ccvp == 1), - config_changed=(cc == 1)) - - def __index__(self): - bits = 0 - bits |= (int(self.state) & 0x1) - bits |= (int(self.config_changed_value_present) & 0x1) << 1 - bits |= (int(self.config_changed) & 0x1) << 2 - return bits - - -class InitialDiagnosticsPayload(MqttPayload): - FORMAT = '=IBbIB' - - ip: int - fw_version: int - rssi: int - free_heap: int - flags: DiagnosticsFlags - - -class DiagnosticsPayload(MqttPayload): - FORMAT = '=bIB' - - rssi: int - free_heap: int - flags: DiagnosticsFlags - - -class MqttDiagnosticsModule(MqttModule): - def on_connect(self, mqtt: MqttNode): - super().on_connect(mqtt) - for topic in ('diag', 'd1ag', 'stat', 'stat1'): - mqtt.subscribe_module(topic, self) - - def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: - message = None - if topic in ('stat', 'diag'): - message = DiagnosticsPayload.unpack(payload) - elif topic in ('stat1', 'd1ag'): - message = InitialDiagnosticsPayload.unpack(payload) - if message: - self._logger.debug(message) - return message diff --git a/src/home/mqtt/module/inverter.py b/src/home/mqtt/module/inverter.py deleted file mode 100644 index d927a06..0000000 --- a/src/home/mqtt/module/inverter.py +++ /dev/null @@ -1,195 +0,0 @@ -import time -import json -import datetime -try: - import inverterd -except: - pass - -from typing import Optional -from .._module import MqttModule -from .._node import MqttNode -from .._payload import MqttPayload, bit_field -try: - from home.database import InverterDatabase -except: - pass - -_mult_10 = lambda n: int(n*10) -_div_10 = lambda n: n/10 - - -MODULE_NAME = 'MqttInverterModule' - -STATUS_TOPIC = 'status' -GENERATION_TOPIC = 'generation' - - -class MqttInverterStatusPayload(MqttPayload): - # 46 bytes - FORMAT = 'IHHHHHHBHHHHHBHHHHHHHH' - - PACKER = { - 'grid_voltage': _mult_10, - 'grid_freq': _mult_10, - 'ac_output_voltage': _mult_10, - 'ac_output_freq': _mult_10, - 'battery_voltage': _mult_10, - 'battery_voltage_scc': _mult_10, - 'battery_voltage_scc2': _mult_10, - 'pv1_input_voltage': _mult_10, - 'pv2_input_voltage': _mult_10 - } - UNPACKER = { - 'grid_voltage': _div_10, - 'grid_freq': _div_10, - 'ac_output_voltage': _div_10, - 'ac_output_freq': _div_10, - 'battery_voltage': _div_10, - 'battery_voltage_scc': _div_10, - 'battery_voltage_scc2': _div_10, - 'pv1_input_voltage': _div_10, - 'pv2_input_voltage': _div_10 - } - - time: int - grid_voltage: float - grid_freq: float - ac_output_voltage: float - ac_output_freq: float - ac_output_apparent_power: int - ac_output_active_power: int - output_load_percent: int - battery_voltage: float - battery_voltage_scc: float - battery_voltage_scc2: float - battery_discharge_current: int - battery_charge_current: int - battery_capacity: int - inverter_heat_sink_temp: int - mppt1_charger_temp: int - mppt2_charger_temp: int - pv1_input_power: int - pv2_input_power: int - pv1_input_voltage: float - pv2_input_voltage: float - - # H - mppt1_charger_status: bit_field(0, 16, 2) - mppt2_charger_status: bit_field(0, 16, 2) - battery_power_direction: bit_field(0, 16, 2) - dc_ac_power_direction: bit_field(0, 16, 2) - line_power_direction: bit_field(0, 16, 2) - load_connected: bit_field(0, 16, 1) - - -class MqttInverterGenerationPayload(MqttPayload): - # 8 bytes - FORMAT = 'II' - - time: int - wh: int - - -class MqttInverterModule(MqttModule): - _status_poll_freq: int - _generation_poll_freq: int - _inverter: Optional[inverterd.Client] - _database: Optional[InverterDatabase] - _gen_prev: float - - def __init__(self, status_poll_freq=0, generation_poll_freq=0): - super().__init__(tick_interval=status_poll_freq) - self._status_poll_freq = status_poll_freq - self._generation_poll_freq = generation_poll_freq - - # this defines whether this is a publisher or a subscriber - if status_poll_freq > 0: - self._inverter = inverterd.Client() - self._inverter.connect() - self._inverter.format(inverterd.Format.SIMPLE_JSON) - self._database = None - else: - self._inverter = None - self._database = InverterDatabase() - - self._gen_prev = 0 - - def on_connect(self, mqtt: MqttNode): - super().on_connect(mqtt) - if not self._inverter: - mqtt.subscribe_module(STATUS_TOPIC, self) - mqtt.subscribe_module(GENERATION_TOPIC, self) - - def tick(self): - if not self._inverter: - return - - # read status - now = time.time() - try: - raw = self._inverter.exec('get-status') - except inverterd.InverterError as e: - self._logger.error(f'inverter error: {str(e)}') - # TODO send to server - return - - data = json.loads(raw)['data'] - status = MqttInverterStatusPayload(time=round(now), **data) - self._mqtt_node_ref.publish(STATUS_TOPIC, status.pack()) - - # read today's generation stat - now = time.time() - if self._gen_prev == 0 or now - self._gen_prev >= self._generation_poll_freq: - self._gen_prev = now - today = datetime.date.today() - try: - raw = self._inverter.exec('get-day-generated', (today.year, today.month, today.day)) - except inverterd.InverterError as e: - self._logger.error(f'inverter error: {str(e)}') - # TODO send to server - return - - data = json.loads(raw)['data'] - gen = MqttInverterGenerationPayload(time=round(now), wh=data['wh']) - self._mqtt_node_ref.publish(GENERATION_TOPIC, gen.pack()) - - def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: - home_id = 1 # legacy compat - - if topic == STATUS_TOPIC: - s = MqttInverterStatusPayload.unpack(payload) - self._database.add_status(home_id=home_id, - client_time=s.time, - grid_voltage=int(s.grid_voltage*10), - grid_freq=int(s.grid_freq * 10), - ac_output_voltage=int(s.ac_output_voltage * 10), - ac_output_freq=int(s.ac_output_freq * 10), - ac_output_apparent_power=s.ac_output_apparent_power, - ac_output_active_power=s.ac_output_active_power, - output_load_percent=s.output_load_percent, - battery_voltage=int(s.battery_voltage * 10), - battery_voltage_scc=int(s.battery_voltage_scc * 10), - battery_voltage_scc2=int(s.battery_voltage_scc2 * 10), - battery_discharge_current=s.battery_discharge_current, - battery_charge_current=s.battery_charge_current, - battery_capacity=s.battery_capacity, - inverter_heat_sink_temp=s.inverter_heat_sink_temp, - mppt1_charger_temp=s.mppt1_charger_temp, - mppt2_charger_temp=s.mppt2_charger_temp, - pv1_input_power=s.pv1_input_power, - pv2_input_power=s.pv2_input_power, - pv1_input_voltage=int(s.pv1_input_voltage * 10), - pv2_input_voltage=int(s.pv2_input_voltage * 10), - mppt1_charger_status=s.mppt1_charger_status, - mppt2_charger_status=s.mppt2_charger_status, - battery_power_direction=s.battery_power_direction, - dc_ac_power_direction=s.dc_ac_power_direction, - line_power_direction=s.line_power_direction, - load_connected=s.load_connected) - return s - - elif topic == GENERATION_TOPIC: - gen = MqttInverterGenerationPayload.unpack(payload) - self._database.add_generation(home_id, gen.time, gen.wh) - return gen diff --git a/src/home/mqtt/module/ota.py b/src/home/mqtt/module/ota.py deleted file mode 100644 index cd34332..0000000 --- a/src/home/mqtt/module/ota.py +++ /dev/null @@ -1,77 +0,0 @@ -import hashlib - -from typing import Optional -from .._payload import MqttPayload -from .._node import MqttModule, MqttNode - -MODULE_NAME = 'MqttOtaModule' - - -class OtaResultPayload(MqttPayload): - FORMAT = '=BB' - result: int - error_code: int - - -class OtaPayload(MqttPayload): - secret: str - filename: str - - # structure of returned data: - # - # uint8_t[len(secret)] secret; - # uint8_t[16] md5; - # *uint8_t data - - def pack(self): - buf = bytearray(self.secret.encode()) - m = hashlib.md5() - with open(self.filename, 'rb') as fd: - content = fd.read() - m.update(content) - buf.extend(m.digest()) - buf.extend(content) - return buf - - def unpack(cls, buf: bytes): - raise RuntimeError(f'{cls.__class__.__name__}.unpack: not implemented') - # secret = buf[:12].decode() - # filename = buf[12:].decode() - # return OTAPayload(secret=secret, filename=filename) - - -class MqttOtaModule(MqttModule): - _ota_request: Optional[tuple[str, int]] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._ota_request = None - - def on_connect(self, mqtt: MqttNode): - super().on_connect(mqtt) - mqtt.subscribe_module("otares", self) - - if self._ota_request is not None: - filename, qos = self._ota_request - self._ota_request = None - self.do_push_ota(self._mqtt_node_ref.secret, filename, qos) - - def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: - if topic == 'otares': - message = OtaResultPayload.unpack(payload) - self._logger.debug(message) - return message - - def do_push_ota(self, secret: str, filename: str, qos: int): - payload = OtaPayload(secret=secret, filename=filename) - self._mqtt_node_ref.publish('ota', - payload=payload.pack(), - qos=qos) - - def push_ota(self, - filename: str, - qos: int): - if not self._initialized: - self._ota_request = (filename, qos) - else: - self.do_push_ota(filename, qos) diff --git a/src/home/mqtt/module/relay.py b/src/home/mqtt/module/relay.py deleted file mode 100644 index e968031..0000000 --- a/src/home/mqtt/module/relay.py +++ /dev/null @@ -1,92 +0,0 @@ -import datetime - -from typing import Optional -from .. import MqttModule, MqttPayload, MqttNode - -MODULE_NAME = 'MqttRelayModule' - - -class MqttPowerSwitchPayload(MqttPayload): - FORMAT = '=12sB' - PACKER = { - 'state': lambda n: int(n), - 'secret': lambda s: s.encode('utf-8') - } - UNPACKER = { - 'state': lambda n: bool(n), - 'secret': lambda s: s.decode('utf-8') - } - - secret: str - state: bool - - -class MqttPowerStatusPayload(MqttPayload): - FORMAT = '=B' - PACKER = { - 'opened': lambda n: int(n), - } - UNPACKER = { - 'opened': lambda n: bool(n), - } - - opened: bool - - -class MqttRelayState: - enabled: bool - update_time: datetime.datetime - rssi: int - fw_version: int - ever_updated: bool - - def __init__(self): - self.ever_updated = False - self.enabled = False - self.rssi = 0 - - def update(self, - enabled: bool, - rssi: int, - fw_version=None): - self.ever_updated = True - self.enabled = enabled - self.rssi = rssi - self.update_time = datetime.datetime.now() - if fw_version: - self.fw_version = fw_version - - -class MqttRelayModule(MqttModule): - _legacy_topics: bool - - def __init__(self, legacy_topics=False, *args, **kwargs): - super().__init__(*args, **kwargs) - self._legacy_topics = legacy_topics - - def on_connect(self, mqtt: MqttNode): - super().on_connect(mqtt) - mqtt.subscribe_module(self._get_switch_topic(), self) - mqtt.subscribe_module('relay/status', self) - - def switchpower(self, - enable: bool): - payload = MqttPowerSwitchPayload(secret=self._mqtt_node_ref.secret, - state=enable) - self._mqtt_node_ref.publish(self._get_switch_topic(), - payload=payload.pack()) - - def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: - message = None - - if topic == self._get_switch_topic(): - message = MqttPowerSwitchPayload.unpack(payload) - elif topic == 'relay/status': - message = MqttPowerStatusPayload.unpack(payload) - - if message is not None: - self._logger.debug(message) - return message - - def _get_switch_topic(self) -> str: - return 'relay/power' if self._legacy_topics else 'relay/switch' diff --git a/src/home/mqtt/module/temphum.py b/src/home/mqtt/module/temphum.py deleted file mode 100644 index fd02cca..0000000 --- a/src/home/mqtt/module/temphum.py +++ /dev/null @@ -1,82 +0,0 @@ -from .._node import MqttNode -from .._module import MqttModule -from .._payload import MqttPayload -from typing import Optional -from ...temphum import BaseSensor - -two_digits_precision = lambda x: round(x, 2) - -MODULE_NAME = 'MqttTempHumModule' -DATA_TOPIC = 'temphum/data' - - -class MqttTemphumDataPayload(MqttPayload): - FORMAT = '=ddb' - UNPACKER = { - 'temp': two_digits_precision, - 'rh': two_digits_precision - } - - temp: float - rh: float - error: int - - -# class MqttTempHumNodes(HashableEnum): -# KBN_SH_HALL = auto() -# KBN_SH_BATHROOM = auto() -# KBN_SH_LIVINGROOM = auto() -# KBN_SH_BEDROOM = auto() -# -# KBN_BH_2FL = auto() -# KBN_BH_2FL_STREET = auto() -# KBN_BH_1FL_LIVINGROOM = auto() -# KBN_BH_1FL_BEDROOM = auto() -# KBN_BH_1FL_BATHROOM = auto() -# -# KBN_NH_1FL_INV = auto() -# KBN_NH_1FL_CENTER = auto() -# KBN_NH_1LF_KT = auto() -# KBN_NH_1FL_DS = auto() -# KBN_NH_1FS_EZ = auto() -# -# SPB_FLAT120_CABINET = auto() - - -class MqttTempHumModule(MqttModule): - def __init__(self, - sensor: Optional[BaseSensor] = None, - write_to_database=False, - *args, **kwargs): - if sensor is not None: - kwargs['tick_interval'] = 10 - super().__init__(*args, **kwargs) - self._sensor = sensor - - def on_connect(self, mqtt: MqttNode): - super().on_connect(mqtt) - mqtt.subscribe_module(DATA_TOPIC, self) - - def tick(self): - if not self._sensor: - return - - error = 0 - temp = 0 - rh = 0 - try: - temp = self._sensor.temperature() - rh = self._sensor.humidity() - except: - error = 1 - pld = MqttTemphumDataPayload(temp=temp, rh=rh, error=error) - self._mqtt_node_ref.publish(DATA_TOPIC, pld.pack()) - - def handle_payload(self, - mqtt: MqttNode, - topic: str, - payload: bytes) -> Optional[MqttPayload]: - if topic == DATA_TOPIC: - message = MqttTemphumDataPayload.unpack(payload) - self._logger.debug(message) - return message diff --git a/src/home/pio/__init__.py b/src/home/pio/__init__.py deleted file mode 100644 index 7216bc4..0000000 --- a/src/home/pio/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .products import get_products, platformio_ini \ No newline at end of file diff --git a/src/home/pio/exceptions.py b/src/home/pio/exceptions.py deleted file mode 100644 index a6afd20..0000000 --- a/src/home/pio/exceptions.py +++ /dev/null @@ -1,2 +0,0 @@ -class ProductConfigNotFoundError(Exception): - pass diff --git a/src/home/pio/products.py b/src/home/pio/products.py deleted file mode 100644 index 388da03..0000000 --- a/src/home/pio/products.py +++ /dev/null @@ -1,113 +0,0 @@ -import os -import logging - -from io import StringIO -from collections import OrderedDict - - -_logger = logging.getLogger(__name__) -_products_dir = os.path.join( - os.path.dirname(__file__), - '..', '..', '..', - 'platformio' -) - - -def get_products(): - products = [] - for f in os.listdir(_products_dir): - if f in ('common',): - continue - - if os.path.isdir(os.path.join(_products_dir, f)): - products.append(f) - - return products - - -def platformio_ini(product_config: dict, - target: str, - # node_id: str, - build_specific_defines: dict, - build_specific_defines_enums: list[str], - platform: str, - framework: str = 'arduino', - upload_port: str = '/dev/ttyUSB0', - monitor_speed: int = 115200, - debug=False, - debug_network=False) -> str: - node_id = build_specific_defines['CONFIG_NODE_ID'] - - # defines - defines = { - **product_config['common_defines'], - 'CONFIG_NODE_ID': node_id, - 'CONFIG_WIFI_AP_SSID': ('HK_'+node_id)[:31] - } - try: - defines.update(product_config['target_defines'][target]) - except KeyError: - pass - defines['CONFIG_NODE_SECRET_SIZE'] = len(defines['CONFIG_NODE_SECRET']) - defines['CONFIG_MQTT_CLIENT_ID'] = node_id - - build_type = 'release' - if debug: - defines['DEBUG'] = True - build_type = 'debug' - if debug_network: - defines['DEBUG'] = True - defines['DEBUG_ESP_SSL'] = True - defines['DEBUG_ESP_PORT'] = 'Serial' - build_type = 'debug' - if build_specific_defines: - for k, v in build_specific_defines.items(): - defines[k] = v - defines = OrderedDict(sorted(defines.items(), key=lambda t: t[0])) - - # libs - libs = [] - if 'common_libs' in product_config: - libs.extend(product_config['common_libs']) - if 'target_libs' in product_config and target in product_config['target_libs']: - libs.extend(product_config['target_libs'][target]) - libs = list(set(libs)) - libs.sort() - - try: - target_real_name = product_config['target_board_names'][target] - except KeyError: - target_real_name = target - - buf = StringIO() - - buf.write('; Generated by pio_ini.py\n\n') - buf.write(f'[env:{target_real_name}]\n') - buf.write(f'platform = {platform}\n') - buf.write(f'board = {target_real_name}\n') - buf.write(f'framework = {framework}\n') - buf.write(f'upload_port = {upload_port}\n') - buf.write(f'monitor_speed = {monitor_speed}\n') - if libs: - buf.write(f'lib_deps =') - for lib in libs: - buf.write(f' {lib}\n') - buf.write(f'build_flags =\n') - if defines: - for name, value in defines.items(): - buf.write(f' -D{name}') - is_enum = name in build_specific_defines_enums - if type(value) is not bool: - buf.write('=') - if type(value) is str: - if not is_enum: - buf.write('"\\"') - value = value.replace('"', '\\"') - buf.write(f'{value}') - if type(value) is str and not is_enum: - buf.write('"\\"') - buf.write('\n') - buf.write(f' -I../common/include') - buf.write(f'\nbuild_type = {build_type}') - - return buf.getvalue() diff --git a/src/home/relay/__init__.py b/src/home/relay/__init__.py deleted file mode 100644 index 406403d..0000000 --- a/src/home/relay/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -import importlib - -__all__ = ['RelayClient', 'RelayServer'] - - -def __getattr__(name): - _map = { - 'RelayClient': '.sunxi_h3_client', - 'RelayServer': '.sunxi_h3_server' - } - - if name in __all__: - module = importlib.import_module(_map[name], __name__) - return getattr(module, name) - - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/home/relay/__init__.pyi b/src/home/relay/__init__.pyi deleted file mode 100644 index 7a4a2f4..0000000 --- a/src/home/relay/__init__.pyi +++ /dev/null @@ -1,2 +0,0 @@ -from .sunxi_h3_client import RelayClient as RelayClient -from .sunxi_h3_server import RelayServer as RelayServer diff --git a/src/home/relay/sunxi_h3_client.py b/src/home/relay/sunxi_h3_client.py deleted file mode 100644 index 8c8d6c4..0000000 --- a/src/home/relay/sunxi_h3_client.py +++ /dev/null @@ -1,39 +0,0 @@ -import socket - - -class RelayClient: - def __init__(self, port=8307, host='127.0.0.1'): - self._host = host - self._port = port - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - - def __del__(self): - self.sock.close() - - def connect(self): - self.sock.connect((self._host, self._port)) - - def _write(self, line): - self.sock.sendall((line+'\r\n').encode()) - - def _read(self): - buf = bytearray() - while True: - buf.extend(self.sock.recv(256)) - if b'\r\n' in buf: - break - - response = buf.decode().strip() - return response - - def on(self): - self._write('on') - return self._read() - - def off(self): - self._write('off') - return self._read() - - def status(self): - self._write('get') - return self._read() diff --git a/src/home/relay/sunxi_h3_server.py b/src/home/relay/sunxi_h3_server.py deleted file mode 100644 index 1f33969..0000000 --- a/src/home/relay/sunxi_h3_server.py +++ /dev/null @@ -1,82 +0,0 @@ -import asyncio -import logging - -from pyA20.gpio import gpio -from pyA20.gpio import port as gpioport -from ..util import Addr - -logger = logging.getLogger(__name__) - - -class RelayServer: - OFF = 1 - ON = 0 - - def __init__(self, - pinname: str, - addr: Addr): - if not hasattr(gpioport, pinname): - raise ValueError(f'invalid pin {pinname}') - - self.pin = getattr(gpioport, pinname) - self.addr = addr - - gpio.init() - gpio.setcfg(self.pin, gpio.OUTPUT) - - self.lock = asyncio.Lock() - - def run(self): - asyncio.run(self.run_server()) - - async def relay_set(self, value): - async with self.lock: - gpio.output(self.pin, value) - - async def relay_get(self): - async with self.lock: - return int(gpio.input(self.pin)) == RelayServer.ON - - async def handle_client(self, reader, writer): - request = None - while request != 'quit': - try: - request = await reader.read(255) - if request == b'\x04': - break - request = request.decode('utf-8').strip() - except Exception: - break - - data = 'unknown' - if request == 'on': - await self.relay_set(RelayServer.ON) - logger.debug('set on') - data = 'ok' - - elif request == 'off': - await self.relay_set(RelayServer.OFF) - logger.debug('set off') - data = 'ok' - - elif request == 'get': - status = await self.relay_get() - data = 'on' if status is True else 'off' - - writer.write((data + '\r\n').encode('utf-8')) - try: - await writer.drain() - except ConnectionError: - break - - try: - writer.close() - except ConnectionError: - pass - - async def run_server(self): - host, port = self.addr - server = await asyncio.start_server(self.handle_client, host, port) - async with server: - logger.info('Server started.') - await server.serve_forever() diff --git a/src/home/soundsensor/__init__.py b/src/home/soundsensor/__init__.py deleted file mode 100644 index 30052f8..0000000 --- a/src/home/soundsensor/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -import importlib - -__all__ = [ - 'SoundSensorNode', - 'SoundSensorHitHandler', - 'SoundSensorServer', - 'SoundSensorServerGuardClient' -] - - -def __getattr__(name): - if name in __all__: - if name == 'SoundSensorNode': - file = 'node' - elif name == 'SoundSensorServerGuardClient': - file = 'server_client' - else: - file = 'server' - module = importlib.import_module(f'.{file}', __name__) - return getattr(module, name) - - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/home/soundsensor/__init__.pyi b/src/home/soundsensor/__init__.pyi deleted file mode 100644 index cb34972..0000000 --- a/src/home/soundsensor/__init__.pyi +++ /dev/null @@ -1,8 +0,0 @@ -from .server import ( - SoundSensorHitHandler as SoundSensorHitHandler, - SoundSensorServer as SoundSensorServer, -) -from .server_client import ( - SoundSensorServerGuardClient as SoundSensorServerGuardClient -) -from .node import SoundSensorNode as SoundSensorNode diff --git a/src/home/soundsensor/node.py b/src/home/soundsensor/node.py deleted file mode 100644 index 292452f..0000000 --- a/src/home/soundsensor/node.py +++ /dev/null @@ -1,75 +0,0 @@ -import logging -import threading - -from typing import Optional -from time import sleep -from ..util import stringify, send_datagram, Addr - -from pyA20.gpio import gpio -from pyA20.gpio import port as gpioport - -logger = logging.getLogger(__name__) - - -class SoundSensorNode: - def __init__(self, - name: str, - pinname: str, - server_addr: Optional[Addr], - threshold: int = 1, - delay=0.005): - - if not hasattr(gpioport, pinname): - raise ValueError(f'invalid pin {pinname}') - - self.pin = getattr(gpioport, pinname) - self.name = name - self.delay = delay - self.threshold = threshold - - self.server_addr = server_addr - - self.hits = 0 - self.hitlock = threading.Lock() - - self.interrupted = False - - def run(self): - try: - t = threading.Thread(target=self.sensor_reader) - t.daemon = True - t.start() - - while True: - with self.hitlock: - hits = self.hits - self.hits = 0 - - if hits >= self.threshold: - try: - if self.server_addr is not None: - send_datagram(stringify([self.name, hits]), self.server_addr) - else: - logger.debug(f'server reporting disabled, skipping reporting {hits} hits') - except OSError as exc: - logger.exception(exc) - - sleep(1) - - except (KeyboardInterrupt, SystemExit) as e: - self.interrupted = True - logger.info(str(e)) - - def sensor_reader(self): - gpio.init() - gpio.setcfg(self.pin, gpio.INPUT) - gpio.pullup(self.pin, gpio.PULLUP) - - while not self.interrupted: - state = gpio.input(self.pin) - sleep(self.delay) - - if not state: - with self.hitlock: - logger.debug('got a hit') - self.hits += 1 diff --git a/src/home/soundsensor/server.py b/src/home/soundsensor/server.py deleted file mode 100644 index a627390..0000000 --- a/src/home/soundsensor/server.py +++ /dev/null @@ -1,128 +0,0 @@ -import asyncio -import json -import logging -import threading - -from ..database.sqlite import SQLiteBase -from ..config import config -from .. import http - -from typing import Type -from ..util import Addr - -logger = logging.getLogger(__name__) - - -class SoundSensorHitHandler(asyncio.DatagramProtocol): - def datagram_received(self, data, addr): - try: - data = json.loads(data) - except json.JSONDecodeError as e: - logger.error('failed to parse json datagram') - logger.exception(e) - return - - try: - name, hits = data - except (ValueError, IndexError) as e: - logger.error('failed to unpack data') - logger.exception(e) - return - - self.handler(name, hits) - - def handler(self, name: str, hits: int): - pass - - -class Database(SQLiteBase): - SCHEMA = 1 - - def __init__(self): - super().__init__(dbname='sound_sensor_server') - - def schema_init(self, version: int) -> None: - cursor = self.cursor() - - if version < 1: - cursor.execute("CREATE TABLE IF NOT EXISTS status (guard_enabled INTEGER NOT NULL)") - cursor.execute("INSERT INTO status (guard_enabled) VALUES (-1)") - - self.commit() - - def get_guard_enabled(self) -> int: - cur = self.cursor() - cur.execute("SELECT guard_enabled FROM status LIMIT 1") - return int(cur.fetchone()[0]) - - def set_guard_enabled(self, enabled: bool) -> None: - cur = self.cursor() - cur.execute("UPDATE status SET guard_enabled=?", (int(enabled),)) - self.commit() - - -class SoundSensorServer: - def __init__(self, - addr: Addr, - handler_impl: Type[SoundSensorHitHandler]): - self.addr = addr - self.impl = handler_impl - self.db = Database() - - self._recording_lock = threading.Lock() - self._recording_enabled = True - - if self.guard_control_enabled(): - current_status = self.db.get_guard_enabled() - if current_status == -1: - self.set_recording(config['server']['guard_recording_default'] - if 'guard_recording_default' in config['server'] - else False, - update=False) - else: - self.set_recording(bool(current_status), update=False) - - @staticmethod - def guard_control_enabled() -> bool: - return 'guard_control' in config['server'] and config['server']['guard_control'] is True - - def set_recording(self, enabled: bool, update=True): - with self._recording_lock: - self._recording_enabled = enabled - if update: - self.db.set_guard_enabled(enabled) - - def is_recording_enabled(self) -> bool: - with self._recording_lock: - return self._recording_enabled - - def run(self): - if self.guard_control_enabled(): - t = threading.Thread(target=self.run_guard_server) - t.daemon = True - t.start() - - loop = asyncio.get_event_loop() - t = loop.create_datagram_endpoint(self.impl, local_addr=self.addr) - loop.run_until_complete(t) - loop.run_forever() - - def run_guard_server(self): - routes = http.routes() - - @routes.post('/guard/enable') - async def guard_enable(request): - self.set_recording(True) - return http.ok() - - @routes.post('/guard/disable') - async def guard_disable(request): - self.set_recording(False) - return http.ok() - - @routes.get('/guard/status') - async def guard_status(request): - return http.ok({'enabled': self.is_recording_enabled()}) - - asyncio.set_event_loop(asyncio.new_event_loop()) # need to create new event loop in new thread - http.serve(self.addr, routes, handle_signals=False) # handle_signals=True doesn't work in separate thread diff --git a/src/home/soundsensor/server_client.py b/src/home/soundsensor/server_client.py deleted file mode 100644 index 7eef996..0000000 --- a/src/home/soundsensor/server_client.py +++ /dev/null @@ -1,38 +0,0 @@ -import requests -import logging - -from ..util import Addr -from ..api.errors import ApiResponseError - - -class SoundSensorServerGuardClient: - def __init__(self, addr: Addr): - self.endpoint = f'http://{addr[0]}:{addr[1]}' - self.logger = logging.getLogger(self.__class__.__name__) - - def guard_enable(self): - return self._call('guard/enable', is_post=True) - - def guard_disable(self): - return self._call('guard/disable', is_post=True) - - def guard_status(self): - return self._call('guard/status') - - def _call(self, - method: str, - is_post=False): - - url = f'{self.endpoint}/{method}' - self.logger.debug(f'calling {url}') - - r = requests.get(url) if not is_post else requests.post(url) - - if r.status_code != 200: - response = r.json() - raise ApiResponseError(status_code=r.status_code, - error_type=response['error'], - error_message=response['message'] or None, - error_stacktrace=response['stacktrace'] if 'stacktrace' in response else None) - - return r.json()['response'] diff --git a/src/home/telegram/__init__.py b/src/home/telegram/__init__.py deleted file mode 100644 index a68dae1..0000000 --- a/src/home/telegram/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .telegram import send_message, send_photo diff --git a/src/home/telegram/_botcontext.py b/src/home/telegram/_botcontext.py deleted file mode 100644 index a143bfe..0000000 --- a/src/home/telegram/_botcontext.py +++ /dev/null @@ -1,86 +0,0 @@ -from typing import Optional, List - -from telegram import Update, User, CallbackQuery -from telegram.constants import ParseMode -from telegram.ext import CallbackContext - -from ._botdb import BotDatabase -from ._botlang import lang -from ._botutil import IgnoreMarkup, exc2text - - -class Context: - _update: Optional[Update] - _callback_context: Optional[CallbackContext] - _markup_getter: callable - db: Optional[BotDatabase] - _user_lang: Optional[str] - - def __init__(self, - update: Optional[Update], - callback_context: Optional[CallbackContext], - markup_getter: callable, - store: Optional[BotDatabase]): - self._update = update - self._callback_context = callback_context - self._markup_getter = markup_getter - self._store = store - self._user_lang = None - - async def reply(self, text, markup=None): - if markup is None: - markup = self._markup_getter(self) - kwargs = dict(parse_mode=ParseMode.HTML) - if not isinstance(markup, IgnoreMarkup): - kwargs['reply_markup'] = markup - return await self._update.message.reply_text(text, **kwargs) - - async def reply_exc(self, e: Exception) -> None: - await self.reply(exc2text(e), markup=IgnoreMarkup()) - - async def answer(self, text: str = None): - await self.callback_query.answer(text) - - async def edit(self, text, markup=None): - kwargs = dict(parse_mode=ParseMode.HTML) - if not isinstance(markup, IgnoreMarkup): - kwargs['reply_markup'] = markup - await self.callback_query.edit_message_text(text, **kwargs) - - @property - def text(self) -> str: - return self._update.message.text - - @property - def callback_query(self) -> CallbackQuery: - return self._update.callback_query - - @property - def args(self) -> Optional[List[str]]: - return self._callback_context.args - - @property - def user_id(self) -> int: - return self.user.id - - @property - def user_data(self): - return self._callback_context.user_data - - @property - def user(self) -> User: - return self._update.effective_user - - @property - def user_lang(self) -> str: - if self._user_lang is None: - self._user_lang = self._store.get_user_lang(self.user_id) - return self._user_lang - - def lang(self, key: str, *args) -> str: - return lang.get(key, self.user_lang, *args) - - def is_callback_context(self) -> bool: - return self._update.callback_query \ - and self._update.callback_query.data \ - and self._update.callback_query.data != '' diff --git a/src/home/telegram/_botdb.py b/src/home/telegram/_botdb.py deleted file mode 100644 index 9e9cf94..0000000 --- a/src/home/telegram/_botdb.py +++ /dev/null @@ -1,32 +0,0 @@ -from home.database.sqlite import SQLiteBase - - -class BotDatabase(SQLiteBase): - def __init__(self): - super().__init__() - - def schema_init(self, version: int) -> None: - if version < 1: - cursor = self.cursor() - cursor.execute("""CREATE TABLE IF NOT EXISTS users ( - id INTEGER PRIMARY KEY, - lang TEXT NOT NULL - )""") - self.commit() - - def get_user_lang(self, user_id: int, default: str = 'en') -> str: - cursor = self.cursor() - cursor.execute('SELECT lang FROM users WHERE id=?', (user_id,)) - row = cursor.fetchone() - - if row is None: - cursor.execute('INSERT INTO users (id, lang) VALUES (?, ?)', (user_id, default)) - self.commit() - return default - else: - return row[0] - - def set_user_lang(self, user_id: int, lang: str) -> None: - cursor = self.cursor() - cursor.execute('UPDATE users SET lang=? WHERE id=?', (lang, user_id)) - self.commit() diff --git a/src/home/telegram/_botlang.py b/src/home/telegram/_botlang.py deleted file mode 100644 index f5f85bb..0000000 --- a/src/home/telegram/_botlang.py +++ /dev/null @@ -1,120 +0,0 @@ -import logging - -from typing import Optional, Dict, List, Union - -_logger = logging.getLogger(__name__) - - -class LangStrings(dict): - _lang: Optional[str] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._lang = None - - def setlang(self, lang: str): - self._lang = lang - - def __missing__(self, key): - _logger.warning(f'key {key} is missing in language {self._lang}') - return '{%s}' % key - - def __setitem__(self, key, value): - raise NotImplementedError(f'setting translation strings this way is prohibited (was trying to set {key}={value})') - - -class LangPack: - strings: Dict[str, LangStrings[str, str]] - default_lang: str - - def __init__(self): - self.strings = {} - self.default_lang = 'en' - - def ru(self, **kwargs) -> None: - self.set(kwargs, 'ru') - - def en(self, **kwargs) -> None: - self.set(kwargs, 'en') - - def set(self, - strings: Union[LangStrings, dict], - lang: str) -> None: - - if isinstance(strings, dict) and not isinstance(strings, LangStrings): - strings = LangStrings(**strings) - strings.setlang(lang) - - if lang not in self.strings: - self.strings[lang] = strings - else: - self.strings[lang].update(strings) - - def all(self, key): - result = [] - for strings in self.strings.values(): - result.append(strings[key]) - return result - - @property - def languages(self) -> List[str]: - return list(self.strings.keys()) - - def get(self, key: str, lang: str, *args) -> str: - if args: - return self.strings[lang][key] % args - else: - return self.strings[lang][key] - - def get_langpack(self, _lang: str) -> dict: - return self.strings[_lang] - - def __call__(self, *args, **kwargs): - return self.strings[self.default_lang][args[0]] - - def __getitem__(self, key): - return self.strings[self.default_lang][key] - - def __setitem__(self, key, value): - raise NotImplementedError('setting translation strings this way is prohibited') - - def __contains__(self, key): - return key in self.strings[self.default_lang] - - @staticmethod - def pfx(prefix: str, l: list) -> list: - return list(map(lambda s: f'{prefix}{s}', l)) - - - -languages = { - 'en': 'English', - 'ru': 'Русский' -} - - -lang = LangPack() -lang.en( - en='English', - ru='Russian', - start_message="Select command on the keyboard.", - unknown_message="Unknown message", - cancel="🚫 Cancel", - back='🔙 Back', - select_language="Select language on the keyboard.", - invalid_language="Invalid language. Please try again.", - saved='Saved.', - please_wait="⏳ Please wait..." -) -lang.ru( - en='Английский', - ru='Русский', - start_message="Выберите команду на клавиатуре.", - unknown_message="Неизвестная команда", - cancel="🚫 Отмена", - back='🔙 Назад', - select_language="Выберите язык на клавиатуре.", - invalid_language="Неверный язык. Пожалуйста, попробуйте снова", - saved="Настройки сохранены.", - please_wait="⏳ Ожидайте..." -) \ No newline at end of file diff --git a/src/home/telegram/_botutil.py b/src/home/telegram/_botutil.py deleted file mode 100644 index b551a55..0000000 --- a/src/home/telegram/_botutil.py +++ /dev/null @@ -1,47 +0,0 @@ -import logging -import traceback - -from html import escape -from telegram import User -from home.api import WebApiClient as APIClient -from home.api.types import BotType -from home.api.errors import ApiResponseError - -_logger = logging.getLogger(__name__) - - -def user_any_name(user: User) -> str: - name = [user.first_name, user.last_name] - name = list(filter(lambda s: s is not None, name)) - name = ' '.join(name).strip() - - if not name: - name = user.username - - if not name: - name = str(user.id) - - return name - - -class ReportingHelper: - def __init__(self, client: APIClient, bot_type: BotType): - self.client = client - self.bot_type = bot_type - - def report(self, message, text: str = None) -> None: - if text is None: - text = message.text - try: - self.client.log_bot_request(self.bot_type, message.chat_id, text) - except ApiResponseError as error: - _logger.exception(error) - - -def exc2text(e: Exception) -> str: - tb = ''.join(traceback.format_tb(e.__traceback__)) - return f'{e.__class__.__name__}: ' + escape(str(e)) + "\n\n" + escape(tb) - - -class IgnoreMarkup: - pass diff --git a/src/home/telegram/aio.py b/src/home/telegram/aio.py deleted file mode 100644 index fc87c1c..0000000 --- a/src/home/telegram/aio.py +++ /dev/null @@ -1,18 +0,0 @@ -import functools -import asyncio - -from .telegram import ( - send_message as _send_message_sync, - send_photo as _send_photo_sync -) - - -async def send_message(*args, **kwargs): - loop = asyncio.get_event_loop() - return await loop.run_in_executor(None, functools.partial(_send_message_sync, *args, **kwargs)) - - -async def send_photo(*args, **kwargs): - loop = asyncio.get_event_loop() - return await loop.run_in_executor(None, functools.partial(_send_photo_sync, *args, **kwargs)) - diff --git a/src/home/telegram/bot.py b/src/home/telegram/bot.py deleted file mode 100644 index e6ebc6e..0000000 --- a/src/home/telegram/bot.py +++ /dev/null @@ -1,583 +0,0 @@ -from __future__ import annotations - -import logging -import itertools - -from enum import Enum, auto -from functools import wraps -from typing import Optional, Union, Tuple, Coroutine - -from telegram import Update, ReplyKeyboardMarkup -from telegram.ext import ( - Application, - filters, - CommandHandler, - MessageHandler, - CallbackQueryHandler, - CallbackContext, - ConversationHandler -) -from telegram.ext.filters import BaseFilter -from telegram.error import TimedOut - -from home.config import config -from home.api import WebApiClient -from home.api.types import BotType - -from ._botlang import lang, languages -from ._botdb import BotDatabase -from ._botutil import ReportingHelper, exc2text, IgnoreMarkup, user_any_name -from ._botcontext import Context - - -db: Optional[BotDatabase] = None - -_user_filter: Optional[BaseFilter] = None -_cancel_filter = filters.Text(lang.all('cancel')) -_back_filter = filters.Text(lang.all('back')) -_cancel_and_back_filter = filters.Text(lang.all('back') + lang.all('cancel')) - -_logger = logging.getLogger(__name__) -_application: Optional[Application] = None -_reporting: Optional[ReportingHelper] = None -_exception_handler: Optional[Coroutine] = None -_dispatcher = None -_markup_getter: Optional[callable] = None -_start_handler_ref: Optional[Coroutine] = None - - -def text_filter(*args): - if not _user_filter: - raise RuntimeError('user_filter is not initialized') - return filters.Text(args[0] if isinstance(args[0], list) else [*args]) & _user_filter - - -async def _handler_of_handler(*args, **kwargs): - self = None - context = None - update = None - - _args = list(args) - while len(_args): - v = _args[0] - if isinstance(v, conversation): - self = v - _args.pop(0) - elif isinstance(v, Update): - update = v - _args.pop(0) - elif isinstance(v, CallbackContext): - context = v - _args.pop(0) - break - - ctx = Context(update, - callback_context=context, - markup_getter=lambda _ctx: None if not _markup_getter else _markup_getter(_ctx), - store=db) - try: - _args.insert(0, ctx) - - f = kwargs['f'] - del kwargs['f'] - - if 'return_with_context' in kwargs: - return_with_context = True - del kwargs['return_with_context'] - else: - return_with_context = False - - if 'argument' in kwargs and kwargs['argument'] == 'message_key': - del kwargs['argument'] - mkey = None - for k, v in lang.get_langpack(ctx.user_lang).items(): - if ctx.text == v: - mkey = k - break - _args.insert(0, mkey) - - if self: - _args.insert(0, self) - - result = await f(*_args, **kwargs) - return result if not return_with_context else (result, ctx) - - except Exception as e: - if _exception_handler: - if not _exception_handler(e, ctx) and not isinstance(e, TimedOut): - _logger.exception(e) - if not ctx.is_callback_context(): - await ctx.reply_exc(e) - else: - notify_user(ctx.user_id, exc2text(e)) - else: - _logger.exception(e) - - -def handler(**kwargs): - def inner(f): - @wraps(f) - async def _handler(*args, **inner_kwargs): - if 'argument' in kwargs and kwargs['argument'] == 'message_key': - inner_kwargs['argument'] = 'message_key' - return await _handler_of_handler(f=f, *args, **inner_kwargs) - - messages = [] - texts = [] - - if 'messages' in kwargs: - messages += kwargs['messages'] - if 'message' in kwargs: - messages.append(kwargs['message']) - - if 'text' in kwargs: - texts.append(kwargs['text']) - if 'texts' in kwargs: - texts += kwargs['texts'] - - if messages or texts: - new_messages = list(itertools.chain.from_iterable([lang.all(m) for m in messages])) - texts += new_messages - texts = list(set(texts)) - _application.add_handler( - MessageHandler(text_filter(*texts), _handler), - group=0 - ) - - if 'command' in kwargs: - _application.add_handler(CommandHandler(kwargs['command'], _handler), group=0) - - if 'callback' in kwargs: - _application.add_handler(CallbackQueryHandler(_handler, pattern=kwargs['callback']), group=0) - - return _handler - - return inner - - -def simplehandler(f: Coroutine): - @wraps(f) - async def _handler(*args, **kwargs): - return await _handler_of_handler(f=f, *args, **kwargs) - return _handler - - -def callbackhandler(*args, **kwargs): - def inner(f): - @wraps(f) - async def _handler(*args, **kwargs): - return await _handler_of_handler(f=f, *args, **kwargs) - pattern_kwargs = {} - if kwargs['callback'] != '*': - pattern_kwargs['pattern'] = kwargs['callback'] - _application.add_handler(CallbackQueryHandler(_handler, **pattern_kwargs), group=0) - return _handler - return inner - - -async def exceptionhandler(f: callable): - global _exception_handler - if _exception_handler: - _logger.warning('exception handler already set, we will overwrite it') - _exception_handler = f - - -def defaultreplymarkup(f: callable): - global _markup_getter - _markup_getter = f - - -def convinput(state, is_enter=False, **kwargs): - def inner(f): - f.__dict__['_conv_data'] = dict( - orig_f=f, - enter=is_enter, - type=ConversationMethodType.ENTRY if is_enter and state == 0 else ConversationMethodType.STATE_HANDLER, - state=state, - **kwargs - ) - - @wraps(f) - async def _impl(*args, **kwargs): - result, ctx = await _handler_of_handler(f=f, *args, **kwargs, return_with_context=True) - if result == conversation.END: - await start(ctx) - return result - - return _impl - - return inner - - -def conventer(state, **kwargs): - return convinput(state, is_enter=True, **kwargs) - - -class ConversationMethodType(Enum): - ENTRY = auto() - STATE_HANDLER = auto() - - -class conversation: - END = ConversationHandler.END - STATE_SEQS = [] - - def __init__(self, enable_back=False): - self._logger = logging.getLogger(self.__class__.__name__) - self._user_state_cache = {} - self._back_enabled = enable_back - - def make_handlers(self, f: callable, **kwargs) -> list: - messages = {} - handlers = [] - - if 'messages' in kwargs: - if isinstance(kwargs['messages'], dict): - messages = kwargs['messages'] - else: - for m in kwargs['messages']: - messages[m] = None - - if 'message' in kwargs: - if isinstance(kwargs['message'], str): - messages[kwargs['message']] = None - else: - AttributeError('invalid message type: ' + type(kwargs['message'])) - - if messages: - for message, target_state in messages.items(): - if not target_state: - handlers.append(MessageHandler(text_filter(lang.all(message) if 'messages_lang_completed' not in kwargs else message), f)) - else: - handlers.append(MessageHandler(text_filter(lang.all(message) if 'messages_lang_completed' not in kwargs else message), self.make_invoker(target_state))) - - if 'regex' in kwargs: - handlers.append(MessageHandler(filters.Regex(kwargs['regex']) & _user_filter, f)) - - if 'command' in kwargs: - handlers.append(CommandHandler(kwargs['command'], f, _user_filter)) - - return handlers - - def make_invoker(self, state): - def _invoke(update: Update, context: CallbackContext): - ctx = Context(update, - callback_context=context, - markup_getter=lambda _ctx: None if not _markup_getter else _markup_getter(_ctx), - store=db) - return self.invoke(state, ctx) - return _invoke - - def invoke(self, state, ctx: Context): - self._logger.debug(f'invoke, state={state}') - for item in dir(self): - f = getattr(self, item) - if not callable(f) or item.startswith('_') or '_conv_data' not in f.__dict__: - continue - cd = f.__dict__['_conv_data'] - if cd['enter'] and cd['state'] == state: - return cd['orig_f'](self, ctx) - - raise RuntimeError(f'invoke: failed to find method for state {state}') - - def get_handler(self) -> ConversationHandler: - entry_points = [] - states = {} - - l_cancel_filter = _cancel_filter if not self._back_enabled else _cancel_and_back_filter - - for item in dir(self): - f = getattr(self, item) - if not callable(f) or item.startswith('_') or '_conv_data' not in f.__dict__: - continue - - cd = f.__dict__['_conv_data'] - - if cd['type'] == ConversationMethodType.ENTRY: - entry_points = self.make_handlers(f, **cd) - elif cd['type'] == ConversationMethodType.STATE_HANDLER: - states[cd['state']] = self.make_handlers(f, **cd) - states[cd['state']].append( - MessageHandler(_user_filter & ~l_cancel_filter, conversation.invalid) - ) - - fallbacks = [MessageHandler(_user_filter & _cancel_filter, self.cancel)] - if self._back_enabled: - fallbacks.append(MessageHandler(_user_filter & _back_filter, self.back)) - - return ConversationHandler( - entry_points=entry_points, - states=states, - fallbacks=fallbacks - ) - - def get_user_state(self, user_id: int) -> Optional[int]: - if user_id not in self._user_state_cache: - return None - return self._user_state_cache[user_id] - - # TODO store in ctx.user_state - def set_user_state(self, user_id: int, state: Union[int, None]): - if not self._back_enabled: - return - if state is not None: - self._user_state_cache[user_id] = state - else: - del self._user_state_cache[user_id] - - @staticmethod - @simplehandler - async def invalid(ctx: Context): - await ctx.reply(ctx.lang('invalid_input'), markup=IgnoreMarkup()) - # return 0 # FIXME is this needed - - @simplehandler - async def cancel(self, ctx: Context): - await start(ctx) - self.set_user_state(ctx.user_id, None) - return conversation.END - - @simplehandler - async def back(self, ctx: Context): - cur_state = self.get_user_state(ctx.user_id) - if cur_state is None: - await start(ctx) - self.set_user_state(ctx.user_id, None) - return conversation.END - - new_state = None - for seq in self.STATE_SEQS: - if cur_state in seq: - idx = seq.index(cur_state) - if idx > 0: - return self.invoke(seq[idx-1], ctx) - - if new_state is None: - raise RuntimeError('failed to determine state to go back to') - - @classmethod - def add_cancel_button(cls, ctx: Context, buttons): - buttons.append([ctx.lang('cancel')]) - - @classmethod - def add_back_button(cls, ctx: Context, buttons): - # buttons.insert(0, [ctx.lang('back')]) - buttons.append([ctx.lang('back')]) - - def reply(self, - ctx: Context, - state: Union[int, Enum], - text: str, - buttons: Optional[list], - with_cancel=False, - with_back=False, - buttons_lang_completed=False): - - if buttons: - new_buttons = [] - if not buttons_lang_completed: - for item in buttons: - if isinstance(item, list): - item = map(lambda s: ctx.lang(s), item) - new_buttons.append(list(item)) - elif isinstance(item, str): - new_buttons.append([ctx.lang(item)]) - else: - raise ValueError('invalid type: ' + type(item)) - else: - new_buttons = list(buttons) - - buttons = None - else: - if with_cancel or with_back: - new_buttons = [] - else: - new_buttons = None - - if with_cancel: - self.add_cancel_button(ctx, new_buttons) - if with_back: - if not self._back_enabled: - raise AttributeError(f'back is not enabled for this conversation ({self.__class__.__name__})') - self.add_back_button(ctx, new_buttons) - - markup = ReplyKeyboardMarkup(new_buttons, one_time_keyboard=True) if new_buttons else IgnoreMarkup() - ctx.reply(text, markup=markup) - self.set_user_state(ctx.user_id, state) - return state - - -class LangConversation(conversation): - START, = range(1) - - @conventer(START, command='lang') - async def entry(self, ctx: Context): - self._logger.debug(f'current language: {ctx.user_lang}') - - buttons = [] - for name in languages.values(): - buttons.append(name) - markup = ReplyKeyboardMarkup([buttons, [ctx.lang('cancel')]], one_time_keyboard=False) - - await ctx.reply(ctx.lang('select_language'), markup=markup) - return self.START - - @convinput(START, messages=lang.languages) - async def input(self, ctx: Context): - selected_lang = None - for key, value in languages.items(): - if value == ctx.text: - selected_lang = key - break - - if selected_lang is None: - raise ValueError('could not find the language') - - db.set_user_lang(ctx.user_id, selected_lang) - await ctx.reply(ctx.lang('saved'), markup=IgnoreMarkup()) - - return self.END - - -def initialize(): - global _user_filter - global _application - # global _updater - global _dispatcher - - # init user_filter - _user_ids = config.app_config.get_user_ids() - if len(_user_ids) > 0: - _logger.info('allowed users: ' + str(_user_ids)) - _user_filter = filters.User(_user_ids) - else: - _user_filter = filters.ALL # not sure if this is correct - - _application = Application.builder()\ - .token(config.app_config.get('bot.token'))\ - .connect_timeout(7)\ - .read_timeout(6)\ - .build() - - # transparently log all messages - # _application.dispatcher.add_handler(MessageHandler(filters.ALL & _user_filter, _logging_message_handler), group=10) - # _application.dispatcher.add_handler(CallbackQueryHandler(_logging_callback_handler), group=10) - - -def run(start_handler=None, any_handler=None): - global db - global _start_handler_ref - - if not start_handler: - start_handler = _default_start_handler - if not any_handler: - any_handler = _default_any_handler - if not db: - db = BotDatabase() - - _start_handler_ref = start_handler - - _application.add_handler(LangConversation().get_handler(), group=0) - _application.add_handler(CommandHandler('start', - callback=simplehandler(start_handler), - filters=_user_filter)) - _application.add_handler(MessageHandler(filters.ALL & _user_filter, any_handler)) - - _application.run_polling() - - -def add_conversation(conv: conversation) -> None: - _application.add_handler(conv.get_handler(), group=0) - - -def add_handler(h): - _application.add_handler(h, group=0) - - -async def start(ctx: Context): - return await _start_handler_ref(ctx) - - -async def _default_start_handler(ctx: Context): - if 'start_message' not in lang: - return await ctx.reply('Please define start_message or override start()') - await ctx.reply(ctx.lang('start_message')) - - -@simplehandler -async def _default_any_handler(ctx: Context): - if 'invalid_command' not in lang: - return await ctx.reply('Please define invalid_command or override any()') - await ctx.reply(ctx.lang('invalid_command')) - - -def _logging_message_handler(update: Update, context: CallbackContext): - if _reporting: - _reporting.report(update.message) - - -def _logging_callback_handler(update: Update, context: CallbackContext): - if _reporting: - _reporting.report(update.callback_query.message, text=update.callback_query.data) - - -def enable_logging(bot_type: BotType): - api = WebApiClient(timeout=3) - api.enable_async() - - global _reporting - _reporting = ReportingHelper(api, bot_type) - - -def notify_all(text_getter: callable, - exclude: Tuple[int] = ()) -> None: - if 'notify_users' not in config['bot']: - _logger.error('notify_all() called but no notify_users directive found in the config') - return - - for user_id in config['bot']['notify_users']: - if user_id in exclude: - continue - - text = text_getter(db.get_user_lang(user_id)) - _application.bot.send_message(chat_id=user_id, - text=text, - parse_mode='HTML') - - -def notify_user(user_id: int, text: Union[str, Exception], **kwargs) -> None: - if isinstance(text, Exception): - text = exc2text(text) - _application.bot.send_message(chat_id=user_id, - text=text, - parse_mode='HTML', - **kwargs) - - -def send_photo(user_id, **kwargs): - _application.bot.send_photo(chat_id=user_id, **kwargs) - - -def send_audio(user_id, **kwargs): - _application.bot.send_audio(chat_id=user_id, **kwargs) - - -def send_file(user_id, **kwargs): - _application.bot.send_document(chat_id=user_id, **kwargs) - - -def edit_message_text(user_id, message_id, *args, **kwargs): - _application.bot.edit_message_text(chat_id=user_id, - message_id=message_id, - parse_mode='HTML', - *args, **kwargs) - - -def delete_message(user_id, message_id): - _application.bot.delete_message(chat_id=user_id, message_id=message_id) - - -def set_database(_db: BotDatabase): - global db - db = _db - diff --git a/src/home/telegram/config.py b/src/home/telegram/config.py deleted file mode 100644 index 4c7d74b..0000000 --- a/src/home/telegram/config.py +++ /dev/null @@ -1,75 +0,0 @@ -from ..config import ConfigUnit -from typing import Optional, Union -from abc import ABC -from enum import Enum - - -class TelegramUserListType(Enum): - USERS = 'users' - NOTIFY = 'notify_users' - - -class TelegramUserIdsConfig(ConfigUnit): - NAME = 'telegram_user_ids' - - @classmethod - def schema(cls) -> Optional[dict]: - return { - 'roottype': 'dict', - 'type': 'integer' - } - - -_user_ids_config = TelegramUserIdsConfig() - - -def _user_id_mapper(user: Union[str, int]) -> int: - if isinstance(user, int): - return user - return _user_ids_config[user] - - -class TelegramChatsConfig(ConfigUnit): - NAME = 'telegram_chats' - - @classmethod - def schema(cls) -> Optional[dict]: - return { - 'type': 'dict', - 'schema': { - 'id': {'type': 'string', 'required': True}, - 'token': {'type': 'string', 'required': True}, - } - } - - -class TelegramBotConfig(ConfigUnit, ABC): - @classmethod - def schema(cls) -> Optional[dict]: - return { - 'bot': { - 'type': 'dict', - 'schema': { - 'token': {'type': 'string', 'required': True}, - TelegramUserListType.USERS: {**TelegramBotConfig._userlist_schema(), 'required': True}, - TelegramUserListType.NOTIFY: TelegramBotConfig._userlist_schema(), - } - } - } - - @staticmethod - def _userlist_schema() -> dict: - return {'type': 'list', 'schema': {'type': ['string', 'int']}} - - @staticmethod - def custom_validator(data): - for ult in TelegramUserListType: - users = data['bot'][ult.value] - for user in users: - if isinstance(user, str): - if user not in _user_ids_config: - raise ValueError(f'user {user} not found in {TelegramUserIdsConfig.NAME}') - - def get_user_ids(self, - ult: TelegramUserListType = TelegramUserListType.USERS) -> list[int]: - return list(map(_user_id_mapper, self['bot'][ult.value])) \ No newline at end of file diff --git a/src/home/telegram/telegram.py b/src/home/telegram/telegram.py deleted file mode 100644 index f42363e..0000000 --- a/src/home/telegram/telegram.py +++ /dev/null @@ -1,49 +0,0 @@ -import requests -import logging - -from typing import Tuple -from .config import TelegramChatsConfig - -_chats = TelegramChatsConfig() -_logger = logging.getLogger(__name__) - - -def send_message(text: str, - chat: str, - parse_mode: str = 'HTML', - disable_web_page_preview: bool = False,): - data, token = _send_telegram_data(text, chat, parse_mode, disable_web_page_preview) - req = requests.post('https://api.telegram.org/bot%s/sendMessage' % token, data=data) - return req.json() - - -def send_photo(filename: str, chat: str): - chat_data = _chats[chat] - data = { - 'chat_id': chat_data['id'], - } - token = chat_data['token'] - - url = f'https://api.telegram.org/bot{token}/sendPhoto' - with open(filename, "rb") as fd: - req = requests.post(url, data=data, files={"photo": fd}) - return req.json() - - -def _send_telegram_data(text: str, - chat: str, - parse_mode: str = None, - disable_web_page_preview: bool = False) -> Tuple[dict, str]: - chat_data = _chats[chat] - data = { - 'chat_id': chat_data['id'], - 'text': text - } - - if parse_mode is not None: - data['parse_mode'] = parse_mode - - if disable_web_page_preview: - data['disable_web_page_preview'] = 1 - - return data, chat_data['token'] diff --git a/src/home/temphum/__init__.py b/src/home/temphum/__init__.py deleted file mode 100644 index 46d14e6..0000000 --- a/src/home/temphum/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .base import SensorType, BaseSensor diff --git a/src/home/temphum/base.py b/src/home/temphum/base.py deleted file mode 100644 index 602cab7..0000000 --- a/src/home/temphum/base.py +++ /dev/null @@ -1,19 +0,0 @@ -from abc import ABC -from enum import Enum - - -class BaseSensor(ABC): - def __init__(self, bus: int): - super().__init__() - self.bus = smbus.SMBus(bus) - - def humidity(self) -> float: - pass - - def temperature(self) -> float: - pass - - -class SensorType(Enum): - Si7021 = 'si7021' - DHT12 = 'dht12' \ No newline at end of file diff --git a/src/home/temphum/i2c.py b/src/home/temphum/i2c.py deleted file mode 100644 index 7d8e2e3..0000000 --- a/src/home/temphum/i2c.py +++ /dev/null @@ -1,52 +0,0 @@ -import abc -import smbus - -from .base import BaseSensor, SensorType - - -class I2CSensor(BaseSensor, abc.ABC): - def __init__(self, bus: int): - super().__init__() - self.bus = smbus.SMBus(bus) - - -class DHT12(I2CSensor): - i2c_addr = 0x5C - - def _measure(self): - raw = self.bus.read_i2c_block_data(self.i2c_addr, 0, 5) - if (raw[0] + raw[1] + raw[2] + raw[3]) & 0xff != raw[4]: - raise ValueError("checksum error") - return raw - - def temperature(self) -> float: - raw = self._measure() - temp = raw[2] + (raw[3] & 0x7f) * 0.1 - if raw[3] & 0x80: - temp *= -1 - return temp - - def humidity(self) -> float: - raw = self._measure() - return raw[0] + raw[1] * 0.1 - - -class Si7021(I2CSensor): - i2c_addr = 0x40 - - def temperature(self) -> float: - raw = self.bus.read_i2c_block_data(self.i2c_addr, 0xE3, 2) - return 175.72 * (raw[0] << 8 | raw[1]) / 65536.0 - 46.85 - - def humidity(self) -> float: - raw = self.bus.read_i2c_block_data(self.i2c_addr, 0xE5, 2) - return 125.0 * (raw[0] << 8 | raw[1]) / 65536.0 - 6.0 - - -def create_sensor(type: SensorType, bus: int) -> BaseSensor: - if type == SensorType.Si7021: - return Si7021(bus) - elif type == SensorType.DHT12: - return DHT12(bus) - else: - raise ValueError('unexpected sensor type') diff --git a/src/home/util.py b/src/home/util.py deleted file mode 100644 index 11e7116..0000000 --- a/src/home/util.py +++ /dev/null @@ -1,255 +0,0 @@ -from __future__ import annotations - -import json -import socket -import time -import subprocess -import traceback -import logging -import string -import random -import re - -from enum import Enum -from datetime import datetime -from typing import Optional, List -from zlib import adler32 - -logger = logging.getLogger(__name__) - - -def validate_ipv4_or_hostname(address: str, raise_exception: bool = False) -> bool: - if re.match(r'^(\d{1,3}\.){3}\d{1,3}$', address): - parts = address.split('.') - if all(0 <= int(part) < 256 for part in parts): - return True - else: - if raise_exception: - raise ValueError(f"invalid IPv4 address: {address}") - return False - - if re.match(r'^[a-zA-Z0-9.-]+$', address): - return True - else: - if raise_exception: - raise ValueError(f"invalid hostname: {address}") - return False - - -def validate_mac_address(mac_address: str) -> bool: - mac_pattern = r'^([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})$' - if re.match(mac_pattern, mac_address): - return True - else: - return False - - -class Addr: - host: str - port: Optional[int] - - def __init__(self, host: str, port: Optional[int] = None): - self.host = host - self.port = port - - @staticmethod - def fromstring(addr: str) -> Addr: - colons = addr.count(':') - if colons != 1: - raise ValueError('invalid host:port format') - - if not colons: - host = addr - port= None - else: - host, port = addr.split(':') - - validate_ipv4_or_hostname(host, raise_exception=True) - - if port is not None: - port = int(port) - if not 0 <= port <= 65535: - raise ValueError(f'invalid port {port}') - - return Addr(host, port) - - def __str__(self): - buf = self.host - if self.port is not None: - buf += ':'+str(self.port) - return buf - - def __iter__(self): - yield self.host - yield self.port - - -# https://stackoverflow.com/questions/312443/how-do-you-split-a-list-into-evenly-sized-chunks -def chunks(lst, n): - """Yield successive n-sized chunks from lst.""" - for i in range(0, len(lst), n): - yield lst[i:i + n] - - -def json_serial(obj): - """JSON serializer for datetime objects""" - if isinstance(obj, datetime): - return obj.timestamp() - if isinstance(obj, Enum): - return obj.value - raise TypeError("Type %s not serializable" % type(obj)) - - -def stringify(v) -> str: - return json.dumps(v, separators=(',', ':'), default=json_serial) - - -def ipv4_valid(ip: str) -> bool: - try: - socket.inet_aton(ip) - return True - except socket.error: - return False - - -def strgen(n: int): - return ''.join(random.choices(string.ascii_letters + string.digits, k=n)) - - -class MySimpleSocketClient: - host: str - port: int - - def __init__(self, host: str, port: int): - self.host = host - self.port = port - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.sock.connect((self.host, self.port)) - self.sock.settimeout(5) - - def __del__(self): - self.sock.close() - - def write(self, line: str) -> None: - self.sock.sendall((line + '\r\n').encode()) - - def read(self) -> str: - buf = bytearray() - while True: - buf.extend(self.sock.recv(256)) - if b'\r\n' in buf: - break - - response = buf.decode().strip() - return response - - -def send_datagram(message: str, addr: Addr) -> None: - sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - sock.sendto(message.encode(), addr) - - -def format_tb(exc) -> Optional[List[str]]: - tb = traceback.format_tb(exc.__traceback__) - if not tb: - return None - - tb = list(map(lambda s: s.strip(), tb)) - tb.reverse() - if tb[0][-1:] == ':': - tb[0] = tb[0][:-1] - - return tb - - -class ChildProcessInfo: - pid: int - cmd: str - - def __init__(self, - pid: int, - cmd: str): - self.pid = pid - self.cmd = cmd - - -def find_child_processes(ppid: int) -> List[ChildProcessInfo]: - p = subprocess.run(['pgrep', '-P', str(ppid), '--list-full'], capture_output=True) - if p.returncode != 0: - raise OSError(f'pgrep returned {p.returncode}') - - children = [] - - lines = p.stdout.decode().strip().split('\n') - for line in lines: - try: - space_idx = line.index(' ') - except ValueError as exc: - logger.exception(exc) - continue - - pid = int(line[0:space_idx]) - cmd = line[space_idx+1:] - - children.append(ChildProcessInfo(pid, cmd)) - - return children - - -class Stopwatch: - elapsed: float - time_started: Optional[float] - - def __init__(self): - self.elapsed = 0 - self.time_started = None - - def go(self): - if self.time_started is not None: - raise StopwatchError('stopwatch was already started') - - self.time_started = time.time() - - def pause(self): - if self.time_started is None: - raise StopwatchError('stopwatch was paused') - - self.elapsed += time.time() - self.time_started - self.time_started = None - - def get_elapsed_time(self): - elapsed = self.elapsed - if self.time_started is not None: - elapsed += time.time() - self.time_started - return elapsed - - def reset(self): - self.time_started = None - self.elapsed = 0 - - def is_paused(self): - return self.time_started is None - - -class StopwatchError(RuntimeError): - pass - - -def filesize_fmt(num, suffix="B") -> str: - for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]: - if abs(num) < 1024.0: - return f"{num:3.1f} {unit}{suffix}" - num /= 1024.0 - return f"{num:.1f} Yi{suffix}" - - -class HashableEnum(Enum): - def hash(self) -> int: - return adler32(self.name.encode()) - - -def next_tick_gen(freq): - t = time.time() - while True: - t += freq - yield max(t - time.time(), 0) \ No newline at end of file -- cgit v1.2.3