summaryrefslogtreecommitdiff
path: root/src/home
diff options
context:
space:
mode:
authorEvgeny Zinoviev <me@ch1p.io>2023-05-31 09:22:00 +0300
committerEvgeny Zinoviev <me@ch1p.io>2023-06-10 02:07:23 +0300
commitf29e139cbb7e4a4d539cba6e894ef4a6acd312d6 (patch)
tree6246f126325c5c36fb573134a05f2771cd747966 /src/home
parent3e3753d726f8a02d98368f20f77dd9fa739e3d80 (diff)
WIP: big refactoring
Diffstat (limited to 'src/home')
-rw-r--r--src/home/audio/amixer.py2
-rw-r--r--src/home/config/__init__.py14
-rw-r--r--src/home/config/_configs.py55
-rw-r--r--src/home/config/config.py329
-rw-r--r--src/home/database/clickhouse.py2
-rw-r--r--src/home/database/sqlite.py25
-rw-r--r--src/home/inverter/config.py13
-rw-r--r--src/home/media/__init__.py1
-rw-r--r--src/home/mqtt/__init__.py11
-rw-r--r--src/home/mqtt/_config.py165
-rw-r--r--src/home/mqtt/_module.py70
-rw-r--r--src/home/mqtt/_mqtt.py (renamed from src/home/mqtt/mqtt.py)52
-rw-r--r--src/home/mqtt/_node.py92
-rw-r--r--src/home/mqtt/_payload.py (renamed from src/home/mqtt/payload/base_payload.py)4
-rw-r--r--src/home/mqtt/_util.py15
-rw-r--r--src/home/mqtt/_wrapper.py59
-rw-r--r--src/home/mqtt/esp.py106
-rw-r--r--src/home/mqtt/module/diagnostics.py (renamed from src/home/mqtt/payload/esp.py)56
-rw-r--r--src/home/mqtt/module/inverter.py195
-rw-r--r--src/home/mqtt/module/ota.py77
-rw-r--r--src/home/mqtt/module/relay.py92
-rw-r--r--src/home/mqtt/module/temphum.py82
-rw-r--r--src/home/mqtt/payload/__init__.py1
-rw-r--r--src/home/mqtt/payload/inverter.py73
-rw-r--r--src/home/mqtt/payload/relay.py22
-rw-r--r--src/home/mqtt/payload/sensors.py20
-rw-r--r--src/home/mqtt/payload/temphum.py15
-rw-r--r--src/home/mqtt/relay.py71
-rw-r--r--src/home/mqtt/temphum.py54
-rw-r--r--src/home/mqtt/util.py8
-rw-r--r--src/home/pio/products.py4
-rw-r--r--src/home/telegram/_botcontext.py19
-rw-r--r--src/home/telegram/bot.py149
-rw-r--r--src/home/telegram/config.py75
-rw-r--r--src/home/temphum/__init__.py19
-rw-r--r--src/home/temphum/base.py20
-rw-r--r--src/home/temphum/dht12.py22
-rw-r--r--src/home/temphum/i2c.py52
-rw-r--r--src/home/temphum/si7021.py13
-rw-r--r--src/home/util.py70
40 files changed, 1525 insertions, 699 deletions
diff --git a/src/home/audio/amixer.py b/src/home/audio/amixer.py
index 53e6bce..5133c97 100644
--- a/src/home/audio/amixer.py
+++ b/src/home/audio/amixer.py
@@ -1,6 +1,6 @@
import subprocess
-from ..config import config
+from ..config import app_config as config
from threading import Lock
from typing import Union, List
diff --git a/src/home/config/__init__.py b/src/home/config/__init__.py
index cc9c091..2fa5214 100644
--- a/src/home/config/__init__.py
+++ b/src/home/config/__init__.py
@@ -1 +1,13 @@
-from .config import ConfigStore, config, is_development_mode, setup_logging
+from .config import (
+ Config,
+ ConfigUnit,
+ AppConfigUnit,
+ Translation,
+ config,
+ is_development_mode,
+ setup_logging
+)
+from ._configs import (
+ LinuxBoardsConfig,
+ ServicesListConfig
+) \ No newline at end of file
diff --git a/src/home/config/_configs.py b/src/home/config/_configs.py
new file mode 100644
index 0000000..3a1aae5
--- /dev/null
+++ b/src/home/config/_configs.py
@@ -0,0 +1,55 @@
+from .config import ConfigUnit
+from typing import Optional
+
+
+class ServicesListConfig(ConfigUnit):
+ NAME = 'services_list'
+
+ @staticmethod
+ def schema() -> Optional[dict]:
+ return {
+ 'type': 'list',
+ 'empty': False,
+ 'schema': {
+ 'type': 'string'
+ }
+ }
+
+
+class LinuxBoardsConfig(ConfigUnit):
+ NAME = 'linux_boards'
+
+ @staticmethod
+ def schema() -> Optional[dict]:
+ return {
+ 'type': 'dict',
+ 'schema': {
+ 'mdns': {'type': 'string', 'required': True},
+ 'board': {'type': 'string', 'required': True},
+ 'network': {
+ 'type': 'list',
+ 'required': True,
+ 'empty': False,
+ 'allowed': ['wifi', 'ethernet']
+ },
+ 'ram': {'type': 'integer', 'required': True},
+ 'online': {'type': 'boolean', 'required': True},
+
+ # optional
+ 'services': {
+ 'type': 'list',
+ 'empty': False,
+ 'allowed': ServicesListConfig().get()
+ },
+ 'ext_hdd': {
+ 'type': 'list',
+ 'schema': {
+ 'type': 'dict',
+ 'schema': {
+ 'mountpoint': {'type': 'string', 'required': True},
+ 'size': {'type': 'integer', 'required': True}
+ }
+ },
+ },
+ }
+ }
diff --git a/src/home/config/config.py b/src/home/config/config.py
index 4681685..aef9ee7 100644
--- a/src/home/config/config.py
+++ b/src/home/config/config.py
@@ -1,58 +1,256 @@
-import toml
import yaml
import logging
import os
+import pprint
-from os.path import join, isdir, isfile
-from typing import Optional, Any, MutableMapping
+from abc import ABC
+from cerberus import Validator, DocumentError
+from typing import Optional, Any, MutableMapping, Union
from argparse import ArgumentParser
-from ..util import parse_addr
+from enum import Enum, auto
+from os.path import join, isdir, isfile
+from ..util import Addr
+
+
+CONFIG_DIRECTORIES = (
+ join(os.environ['HOME'], '.config', 'homekit'),
+ '/etc/homekit'
+)
+
+class RootSchemaType(Enum):
+ DEFAULT = auto()
+ DICT = auto()
+ LIST = auto()
+
+
+class BaseConfigUnit(ABC):
+ _data: MutableMapping[str, Any]
+ _logger: logging.Logger
+ def __init__(self):
+ self._data = {}
+ self._logger = logging.getLogger(self.__class__.__name__)
+
+ def __getitem__(self, key):
+ return self._data[key]
+
+ def __setitem__(self, key, value):
+ raise NotImplementedError('overwriting config values is prohibited')
-def _get_config_path(name: str) -> str:
- formats = ['toml', 'yaml']
+ def __contains__(self, key):
+ return key in self._data
- dirname = join(os.environ['HOME'], '.config', name)
+ def load_from(self, path: str):
+ with open(path, 'r') as fd:
+ self._data = yaml.safe_load(fd)
- if isdir(dirname):
- for fmt in formats:
- filename = join(dirname, f'config.{fmt}')
- if isfile(filename):
- return filename
+ def get(self,
+ key: Optional[str] = None,
+ default=None):
+ if key is None:
+ return self._data
- raise IOError(f'config not found in {dirname}')
+ cur = self._data
+ pts = key.split('.')
+ for i in range(len(pts)):
+ k = pts[i]
+ if i < len(pts)-1:
+ if k not in cur:
+ raise KeyError(f'key {k} not found')
+ else:
+ return cur[k] if k in cur else default
+ cur = self._data[k]
- else:
- filenames = [join(os.environ['HOME'], '.config', f'{name}.{format}') for format in formats]
- for file in filenames:
- if isfile(file):
- return file
+ raise KeyError(f'option {key} not found')
- raise IOError(f'config not found')
+class ConfigUnit(BaseConfigUnit):
+ NAME = 'dumb'
+
+ def __init__(self, name=None, load=True):
+ super().__init__()
+
+ self._data = {}
+ self._logger = logging.getLogger(self.__class__.__name__)
+
+ if self.NAME != 'dumb' and load:
+ self.load_from(self.get_config_path())
+ self.validate()
+
+ elif name is not None:
+ self.NAME = name
+
+ @classmethod
+ def get_config_path(cls, name=None) -> str:
+ if name is None:
+ name = cls.NAME
+ if name is None:
+ raise ValueError('get_config_path: name is none')
+
+ for dirname in CONFIG_DIRECTORIES:
+ if isdir(dirname):
+ filename = join(dirname, f'{name}.yaml')
+ if isfile(filename):
+ return filename
+
+ raise IOError(f'\'{name}.yaml\' not found')
+
+ @staticmethod
+ def schema() -> Optional[dict]:
+ return None
+
+ def validate(self):
+ schema = self.schema()
+ if not schema:
+ self._logger.warning('validate: no schema')
+ return
+
+ if isinstance(self, AppConfigUnit):
+ schema['logging'] = {
+ 'type': 'dict',
+ 'schema': {
+ 'logging': {'type': 'bool'}
+ }
+ }
+
+ rst = RootSchemaType.DEFAULT
+ try:
+ if schema['type'] == 'dict':
+ rst = RootSchemaType.DICT
+ elif schema['type'] == 'list':
+ rst = RootSchemaType.LIST
+ elif schema['roottype'] == 'dict':
+ del schema['roottype']
+ rst = RootSchemaType.DICT
+ except KeyError:
+ pass
+
+ if rst == RootSchemaType.DICT:
+ v = Validator({'document': {
+ 'type': 'dict',
+ 'keysrules': {'type': 'string'},
+ 'valuesrules': schema
+ }})
+ result = v.validate({'document': self._data})
+ elif rst == RootSchemaType.LIST:
+ v = Validator({'document': schema})
+ result = v.validate({'document': self._data})
+ else:
+ v = Validator(schema)
+ result = v.validate(self._data)
+ # pprint.pprint(self._data)
+ if not result:
+ # pprint.pprint(v.errors)
+ raise DocumentError(f'{self.__class__.__name__}: failed to validate data:\n{pprint.pformat(v.errors)}')
+ try:
+ self.custom_validator(self._data)
+ except Exception as e:
+ raise DocumentError(f'{self.__class__.__name__}: {str(e)}')
+
+ @staticmethod
+ def custom_validator(data):
+ pass
-class ConfigStore:
- data: MutableMapping[str, Any]
+ def get_addr(self, key: str):
+ return Addr.fromstring(self.get(key))
+
+
+class AppConfigUnit(ConfigUnit):
+ _logging_verbose: bool
+ _logging_fmt: Optional[str]
+ _logging_file: Optional[str]
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(load=False, *args, **kwargs)
+ self._logging_verbose = False
+ self._logging_fmt = None
+ self._logging_file = None
+
+ def logging_set_fmt(self, fmt: str) -> None:
+ self._logging_fmt = fmt
+
+ def logging_get_fmt(self) -> Optional[str]:
+ try:
+ return self['logging']['default_fmt']
+ except KeyError:
+ return self._logging_fmt
+
+ def logging_set_file(self, file: str) -> None:
+ self._logging_file = file
+
+ def logging_get_file(self) -> Optional[str]:
+ try:
+ return self['logging']['file']
+ except KeyError:
+ return self._logging_file
+
+ def logging_set_verbose(self):
+ self._logging_verbose = True
+
+ def logging_is_verbose(self) -> bool:
+ try:
+ return bool(self['logging']['verbose'])
+ except KeyError:
+ return self._logging_verbose
+
+
+class TranslationUnit(BaseConfigUnit):
+ pass
+
+
+class Translation:
+ LANGUAGES = ('en', 'ru')
+ _langs: dict[str, TranslationUnit]
+
+ def __init__(self, name: str):
+ super().__init__()
+ self._langs = {}
+ for lang in self.LANGUAGES:
+ for dirname in CONFIG_DIRECTORIES:
+ if isdir(dirname):
+ filename = join(dirname, f'i18n-{lang}', f'{name}.yaml')
+ if lang in self._langs:
+ raise RuntimeError(f'{name}: translation unit for lang \'{lang}\' already loaded')
+ self._langs[lang] = TranslationUnit()
+ self._langs[lang].load_from(filename)
+ diff = set()
+ for data in self._langs.values():
+ diff ^= data.get().keys()
+ if len(diff) > 0:
+ raise RuntimeError(f'{name}: translation units have difference in keys: ' + ', '.join(diff))
+
+ def get(self, lang: str) -> TranslationUnit:
+ return self._langs[lang]
+
+
+class Config:
app_name: Optional[str]
+ app_config: AppConfigUnit
def __init__(self):
- self.data = {}
self.app_name = None
+ self.app_config = AppConfigUnit()
+
+ def load_app(self,
+ name: Optional[Union[str, AppConfigUnit, bool]] = None,
+ use_cli=True,
+ parser: ArgumentParser = None,
+ no_config=False):
+ global app_config
+
+ if issubclass(name, AppConfigUnit) or name == AppConfigUnit:
+ self.app_name = name.NAME
+ self.app_config = name()
+ app_config = self.app_config
+ else:
+ self.app_name = name if isinstance(name, str) else None
- def load(self, name: Optional[str] = None,
- use_cli=True,
- parser: ArgumentParser = None):
- self.app_name = name
-
- if (name is None) and (not use_cli):
+ if self.app_name is None and not use_cli:
raise RuntimeError('either config name must be none or use_cli must be True')
- log_default_fmt = False
- log_file = None
- log_verbose = False
- no_config = name is False
-
+ no_config = name is False or no_config
path = None
+
if use_cli:
if parser is None:
parser = ArgumentParser()
@@ -68,75 +266,38 @@ class ConfigStore:
path = args.config
if args.verbose:
- log_verbose = True
+ self.app_config.logging_set_verbose()
if args.log_file:
- log_file = args.log_file
+ self.app_config.logging_set_file(args.log_file)
if args.log_default_fmt:
- log_default_fmt = args.log_default_fmt
+ self.app_config.logging_set_fmt(args.log_default_fmt)
- if not no_config and path is None:
- path = _get_config_path(name)
+ if not isinstance(name, ConfigUnit):
+ if not no_config and path is None:
+ path = ConfigUnit.get_config_path(name=self.app_name)
- if no_config:
- self.data = {}
- else:
- if path.endswith('.toml'):
- self.data = toml.load(path)
- elif path.endswith('.yaml'):
- with open(path, 'r') as fd:
- self.data = yaml.safe_load(fd)
-
- if 'logging' in self:
- if not log_file and 'file' in self['logging']:
- log_file = self['logging']['file']
- if log_default_fmt and 'default_fmt' in self['logging']:
- log_default_fmt = self['logging']['default_fmt']
+ if not no_config:
+ self.app_config.load_from(path)
- setup_logging(log_verbose, log_file, log_default_fmt)
+ setup_logging(self.app_config.logging_is_verbose(),
+ self.app_config.logging_get_file(),
+ self.app_config.logging_get_fmt())
if use_cli:
return args
- def __getitem__(self, key):
- return self.data[key]
-
- def __setitem__(self, key, value):
- raise NotImplementedError('overwriting config values is prohibited')
-
- def __contains__(self, key):
- return key in self.data
-
- def get(self, key: str, default=None):
- cur = self.data
- pts = key.split('.')
- for i in range(len(pts)):
- k = pts[i]
- if i < len(pts)-1:
- if k not in cur:
- raise KeyError(f'key {k} not found')
- else:
- return cur[k] if k in cur else default
- cur = self.data[k]
- raise KeyError(f'option {key} not found')
-
- def get_addr(self, key: str):
- return parse_addr(self.get(key))
-
- def items(self):
- return self.data.items()
-
-config = ConfigStore()
+config = Config()
def is_development_mode() -> bool:
if 'HK_MODE' in os.environ and os.environ['HK_MODE'] == 'dev':
return True
- return ('logging' in config) and ('verbose' in config['logging']) and (config['logging']['verbose'] is True)
+ return ('logging' in config.app_config) and ('verbose' in config.app_config['logging']) and (config.app_config['logging']['verbose'] is True)
-def setup_logging(verbose=False, log_file=None, default_fmt=False):
+def setup_logging(verbose=False, log_file=None, default_fmt=None):
logging_level = logging.INFO
if is_development_mode() or verbose:
logging_level = logging.DEBUG
diff --git a/src/home/database/clickhouse.py b/src/home/database/clickhouse.py
index ca81628..d0ec283 100644
--- a/src/home/database/clickhouse.py
+++ b/src/home/database/clickhouse.py
@@ -1,7 +1,7 @@
import logging
from zoneinfo import ZoneInfo
-from datetime import datetime, timedelta
+from datetime import datetime
from clickhouse_driver import Client as ClickhouseClient
from ..config import is_development_mode
diff --git a/src/home/database/sqlite.py b/src/home/database/sqlite.py
index bfba929..8c6145c 100644
--- a/src/home/database/sqlite.py
+++ b/src/home/database/sqlite.py
@@ -5,24 +5,27 @@ import logging
from ..config import config, is_development_mode
-def _get_database_path(name: str, dbname: str) -> str:
- return os.path.join(os.environ['HOME'], '.config', name, f'{dbname}.db')
+def _get_database_path(name: str) -> str:
+ return os.path.join(
+ os.environ['HOME'],
+ '.config',
+ 'homekit',
+ 'data',
+ f'{name}.db')
class SQLiteBase:
SCHEMA = 1
- def __init__(self, name=None, dbname='bot', check_same_thread=False):
- db_path = config.get('db_path', default=None)
- if db_path is None:
- if not name:
- name = config.app_name
- if not dbname:
- dbname = name
- db_path = _get_database_path(name, dbname)
+ def __init__(self, name=None, check_same_thread=False):
+ if name is None:
+ name = config.app_config['database_name']
+ database_path = _get_database_path(name)
+ if not os.path.exists(os.path.dirname(database_path)):
+ os.makedirs(os.path.dirname(database_path))
self.logger = logging.getLogger(self.__class__.__name__)
- self.sqlite = sqlite3.connect(db_path, check_same_thread=check_same_thread)
+ self.sqlite = sqlite3.connect(database_path, check_same_thread=check_same_thread)
if is_development_mode():
self.sql_logger = logging.getLogger(self.__class__.__name__)
diff --git a/src/home/inverter/config.py b/src/home/inverter/config.py
new file mode 100644
index 0000000..62b8859
--- /dev/null
+++ b/src/home/inverter/config.py
@@ -0,0 +1,13 @@
+from ..config import ConfigUnit
+from typing import Optional
+
+
+class InverterdConfig(ConfigUnit):
+ NAME = 'inverterd'
+
+ @staticmethod
+ def schema() -> Optional[dict]:
+ return {
+ 'remote_addr': {'type': 'string'},
+ 'local_addr': {'type': 'string'},
+ } \ No newline at end of file
diff --git a/src/home/media/__init__.py b/src/home/media/__init__.py
index 976c990..6923105 100644
--- a/src/home/media/__init__.py
+++ b/src/home/media/__init__.py
@@ -12,6 +12,7 @@ __map__ = {
__all__ = list(itertools.chain(*__map__.values()))
+
def __getattr__(name):
if name in __all__:
for file, names in __map__.items():
diff --git a/src/home/mqtt/__init__.py b/src/home/mqtt/__init__.py
index 982e2b6..707d59c 100644
--- a/src/home/mqtt/__init__.py
+++ b/src/home/mqtt/__init__.py
@@ -1,4 +1,7 @@
-from .mqtt import MqttBase
-from .util import poll_tick
-from .relay import MqttRelay, MqttRelayState
-from .temphum import MqttTempHum \ No newline at end of file
+from ._mqtt import Mqtt
+from ._node import MqttNode
+from ._module import MqttModule
+from ._wrapper import MqttWrapper
+from ._config import MqttConfig, MqttCreds, MqttNodesConfig
+from ._payload import MqttPayload, MqttPayloadCustomField
+from ._util import get_modules as get_mqtt_modules \ No newline at end of file
diff --git a/src/home/mqtt/_config.py b/src/home/mqtt/_config.py
new file mode 100644
index 0000000..f9047b4
--- /dev/null
+++ b/src/home/mqtt/_config.py
@@ -0,0 +1,165 @@
+from ..config import ConfigUnit
+from typing import Optional, Union
+from ..util import Addr
+from collections import namedtuple
+
+MqttCreds = namedtuple('MqttCreds', 'username, password')
+
+
+class MqttConfig(ConfigUnit):
+ NAME = 'mqtt'
+
+ @staticmethod
+ def schema() -> Optional[dict]:
+ addr_schema = {
+ 'type': 'dict',
+ 'required': True,
+ 'schema': {
+ 'host': {'type': 'string', 'required': True},
+ 'port': {'type': 'integer', 'required': True}
+ }
+ }
+
+ schema = {}
+ for key in ('local', 'remote'):
+ schema[f'{key}_addr'] = addr_schema
+
+ schema['creds'] = {
+ 'type': 'dict',
+ 'required': True,
+ 'keysrules': {'type': 'string'},
+ 'valuesrules': {
+ 'type': 'dict',
+ 'schema': {
+ 'username': {'type': 'string', 'required': True},
+ 'password': {'type': 'string', 'required': True},
+ }
+ }
+ }
+
+ for key in ('client', 'server'):
+ schema[f'default_{key}_creds'] = {'type': 'string', 'required': True}
+
+ return schema
+
+ def remote_addr(self) -> Addr:
+ return Addr(host=self['remote_addr']['host'],
+ port=self['remote_addr']['port'])
+
+ def local_addr(self) -> Addr:
+ return Addr(host=self['local_addr']['host'],
+ port=self['local_addr']['port'])
+
+ def creds_by_name(self, name: str) -> MqttCreds:
+ return MqttCreds(username=self['creds'][name]['username'],
+ password=self['creds'][name]['password'])
+
+ def creds(self) -> MqttCreds:
+ return self.creds_by_name(self['default_client_creds'])
+
+ def server_creds(self) -> MqttCreds:
+ return self.creds_by_name(self['default_server_creds'])
+
+
+class MqttNodesConfig(ConfigUnit):
+ NAME = 'mqtt_nodes'
+
+ @staticmethod
+ def schema() -> Optional[dict]:
+ return {
+ 'common': {
+ 'type': 'dict',
+ 'schema': {
+ 'temphum': {
+ 'type': 'dict',
+ 'schema': {
+ 'interval': {'type': 'integer'}
+ }
+ },
+ 'password': {'type': 'string'}
+ }
+ },
+ 'nodes': {
+ 'type': 'dict',
+ 'required': True,
+ 'keysrules': {'type': 'string'},
+ 'valuesrules': {
+ 'type': 'dict',
+ 'schema': {
+ 'type': {'type': 'string', 'required': True, 'allowed': ['esp8266', 'linux', 'none'],},
+ 'board': {'type': 'string', 'allowed': ['nodemcu', 'd1_mini_lite', 'esp12e']},
+ 'temphum': {
+ 'type': 'dict',
+ 'schema': {
+ 'module': {'type': 'string', 'required': True, 'allowed': ['si7021', 'dht12']},
+ 'interval': {'type': 'integer'},
+ 'i2c_bus': {'type': 'integer'},
+ 'tcpserver': {
+ 'type': 'dict',
+ 'schema': {
+ 'port': {'type': 'integer', 'required': True}
+ }
+ }
+ }
+ },
+ 'relay': {
+ 'type': 'dict',
+ 'schema': {
+ 'device_type': {'type': 'string', 'allowed': ['lamp', 'pump', 'solenoid'], 'required': True},
+ 'legacy_topics': {'type': 'boolean'}
+ }
+ },
+ 'password': {'type': 'string'}
+ }
+ }
+ }
+ }
+
+ @staticmethod
+ def custom_validator(data):
+ for name, node in data['nodes'].items():
+ if 'temphum' in node:
+ if node['type'] == 'linux':
+ if 'i2c_bus' not in node['temphum']:
+ raise KeyError(f'nodes.{name}.temphum: i2c_bus is missing but required for type=linux')
+ if node['type'] in ('esp8266',) and 'board' not in node:
+ raise KeyError(f'nodes.{name}: board is missing but required for type={node["type"]}')
+
+ def get_node(self, name: str) -> dict:
+ node = self['nodes'][name]
+ if node['type'] == 'none':
+ return node
+
+ try:
+ if 'password' not in node:
+ node['password'] = self['common']['password']
+ except KeyError:
+ pass
+
+ try:
+ if 'temphum' in node:
+ for ckey, cval in self['common']['temphum'].items():
+ if ckey not in node['temphum']:
+ node['temphum'][ckey] = cval
+ except KeyError:
+ pass
+
+ return node
+
+ def get_nodes(self,
+ filters: Optional[Union[list[str], tuple[str]]] = None,
+ only_names=False) -> Union[dict, list[str]]:
+ if filters:
+ for f in filters:
+ if f not in ('temphum', 'relay'):
+ raise ValueError(f'{self.__class__.__name__}::get_node(): invalid filter {f}')
+ reslist = []
+ resdict = {}
+ for name in self['nodes'].keys():
+ node = self.get_node(name)
+ if (not filters) or ('temphum' in filters and 'temphum' in node) or ('relay' in filters and 'relay' in node):
+ if only_names:
+ reslist.append(name)
+ else:
+ resdict[name] = node
+ return reslist if only_names else resdict
diff --git a/src/home/mqtt/_module.py b/src/home/mqtt/_module.py
new file mode 100644
index 0000000..80f27bb
--- /dev/null
+++ b/src/home/mqtt/_module.py
@@ -0,0 +1,70 @@
+from __future__ import annotations
+
+import abc
+import logging
+import threading
+
+from time import sleep
+from ..util import next_tick_gen
+
+from typing import TYPE_CHECKING, Optional
+if TYPE_CHECKING:
+ from ._node import MqttNode
+ from ._payload import MqttPayload
+
+
+class MqttModule(abc.ABC):
+ _tick_interval: int
+ _initialized: bool
+ _connected: bool
+ _ticker: Optional[threading.Thread]
+ _mqtt_node_ref: Optional[MqttNode]
+
+ def __init__(self, tick_interval=0):
+ self._tick_interval = tick_interval
+ self._initialized = False
+ self._ticker = None
+ self._logger = logging.getLogger(self.__class__.__name__)
+ self._connected = False
+ self._mqtt_node_ref = None
+
+ def on_connect(self, mqtt: MqttNode):
+ self._connected = True
+ self._mqtt_node_ref = mqtt
+ if self._tick_interval:
+ self._start_ticker()
+
+ def on_disconnect(self, mqtt: MqttNode):
+ self._connected = False
+ self._mqtt_node_ref = None
+
+ def is_initialized(self):
+ return self._initialized
+
+ def set_initialized(self):
+ self._initialized = True
+
+ def unset_initialized(self):
+ self._initialized = False
+
+ def tick(self):
+ pass
+
+ def _tick(self):
+ g = next_tick_gen(self._tick_interval)
+ while self._connected:
+ sleep(next(g))
+ if not self._connected:
+ break
+ self.tick()
+
+ def _start_ticker(self):
+ if not self._ticker or not self._ticker.is_alive():
+ name_part = f'{self._mqtt_node_ref.id}/' if self._mqtt_node_ref else ''
+ self._ticker = None
+ self._ticker = threading.Thread(target=self._tick,
+ name=f'mqtt:{self.__class__.__name__}/{name_part}ticker')
+ self._ticker.start()
+
+ def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]:
+ pass
diff --git a/src/home/mqtt/mqtt.py b/src/home/mqtt/_mqtt.py
index 4acd4f6..746ae2e 100644
--- a/src/home/mqtt/mqtt.py
+++ b/src/home/mqtt/_mqtt.py
@@ -3,19 +3,24 @@ import paho.mqtt.client as mqtt
import ssl
import logging
-from typing import Tuple
-from ..config import config
+from ._config import MqttCreds, MqttConfig
+from typing import Optional
-def username_and_password() -> Tuple[str, str]:
- username = config['mqtt']['username'] if 'username' in config['mqtt'] else None
- password = config['mqtt']['password'] if 'password' in config['mqtt'] else None
- return username, password
+class Mqtt:
+ _connected: bool
+ _is_server: bool
+ _mqtt_config: MqttConfig
+ def __init__(self,
+ clean_session=True,
+ client_id='',
+ creds: Optional[MqttCreds] = None,
+ is_server=False):
+ if not client_id:
+ raise ValueError('client_id must not be empty')
-class MqttBase:
- def __init__(self, clean_session=True):
- self._client = mqtt.Client(client_id=config['mqtt']['client_id'],
+ self._client = mqtt.Client(client_id=client_id,
protocol=mqtt.MQTTv311,
clean_session=clean_session)
self._client.on_connect = self.on_connect
@@ -24,15 +29,17 @@ class MqttBase:
self._client.on_log = self.on_log
self._client.on_publish = self.on_publish
self._loop_started = False
-
+ self._connected = False
+ self._is_server = is_server
+ self._mqtt_config = MqttConfig()
self._logger = logging.getLogger(self.__class__.__name__)
- username, password = username_and_password()
- if username and password:
- self._logger.debug(f'username={username} password={password}')
- self._client.username_pw_set(username, password)
+ if not creds:
+ creds = self._mqtt_config.creds() if not is_server else self._mqtt_config.server_creds()
+
+ self._client.username_pw_set(creds.username, creds.password)
- def configure_tls(self):
+ def _configure_tls(self):
ca_certs = os.path.realpath(os.path.join(
os.path.dirname(os.path.realpath(__file__)),
'..',
@@ -41,13 +48,14 @@ class MqttBase:
'assets',
'mqtt_ca.crt'
))
- self._client.tls_set(ca_certs=ca_certs, cert_reqs=ssl.CERT_REQUIRED, tls_version=ssl.PROTOCOL_TLSv1_2)
+ self._client.tls_set(ca_certs=ca_certs,
+ cert_reqs=ssl.CERT_REQUIRED,
+ tls_version=ssl.PROTOCOL_TLSv1_2)
def connect_and_loop(self, loop_forever=True):
- host = config['mqtt']['host']
- port = config['mqtt']['port']
-
- self._client.connect(host, port, 60)
+ self._configure_tls()
+ addr = self._mqtt_config.local_addr() if self._is_server else self._mqtt_config.remote_addr()
+ self._client.connect(addr.host, addr.port, 60)
if loop_forever:
self._client.loop_forever()
else:
@@ -61,9 +69,11 @@ class MqttBase:
def on_connect(self, client: mqtt.Client, userdata, flags, rc):
self._logger.info("Connected with result code " + str(rc))
+ self._connected = True
def on_disconnect(self, client: mqtt.Client, userdata, rc):
self._logger.info("Disconnected with result code " + str(rc))
+ self._connected = False
def on_log(self, client: mqtt.Client, userdata, level, buf):
level = mqtt.LOGGING_LEVEL[level] if level in mqtt.LOGGING_LEVEL else logging.INFO
@@ -73,4 +83,4 @@ class MqttBase:
self._logger.debug(msg.topic + ": " + str(msg.payload))
def on_publish(self, client: mqtt.Client, userdata, mid):
- self._logger.debug(f'publish done, mid={mid}') \ No newline at end of file
+ self._logger.debug(f'publish done, mid={mid}')
diff --git a/src/home/mqtt/_node.py b/src/home/mqtt/_node.py
new file mode 100644
index 0000000..4e259a4
--- /dev/null
+++ b/src/home/mqtt/_node.py
@@ -0,0 +1,92 @@
+import logging
+import importlib
+
+from typing import List, TYPE_CHECKING, Optional
+from ._payload import MqttPayload
+from ._module import MqttModule
+if TYPE_CHECKING:
+ from ._wrapper import MqttWrapper
+else:
+ MqttWrapper = None
+
+
+class MqttNode:
+ _modules: List[MqttModule]
+ _module_subscriptions: dict[str, MqttModule]
+ _node_id: str
+ _node_secret: str
+ _payload_callbacks: list[callable]
+ _wrapper: Optional[MqttWrapper]
+
+ def __init__(self,
+ node_id: str,
+ node_secret: Optional[str] = None):
+ self._modules = []
+ self._module_subscriptions = {}
+ self._node_id = node_id
+ self._node_secret = node_secret
+ self._payload_callbacks = []
+ self._logger = logging.getLogger(self.__class__.__name__)
+ self._wrapper = None
+
+ def on_connect(self, wrapper: MqttWrapper):
+ self._wrapper = wrapper
+ for module in self._modules:
+ if not module.is_initialized():
+ module.on_connect(self)
+ module.set_initialized()
+
+ def on_disconnect(self):
+ self._wrapper = None
+ for module in self._modules:
+ module.unset_initialized()
+
+ def on_message(self, topic, payload):
+ if topic in self._module_subscriptions:
+ payload = self._module_subscriptions[topic].handle_payload(self, topic, payload)
+ if isinstance(payload, MqttPayload):
+ for f in self._payload_callbacks:
+ f(self, payload)
+
+ def load_module(self, module_name: str, *args, **kwargs) -> MqttModule:
+ module = importlib.import_module(f'..module.{module_name}', __name__)
+ if not hasattr(module, 'MODULE_NAME'):
+ raise RuntimeError(f'MODULE_NAME not found in module {module}')
+ cl = getattr(module, getattr(module, 'MODULE_NAME'))
+ instance = cl(*args, **kwargs)
+ self.add_module(instance)
+ return instance
+
+ def add_module(self, module: MqttModule):
+ self._modules.append(module)
+ if self._wrapper and self._wrapper._connected:
+ module.on_connect(self)
+ module.set_initialized()
+
+ def subscribe_module(self, topic: str, module: MqttModule, qos: int = 1):
+ if not self._wrapper or not self._wrapper._connected:
+ raise RuntimeError('not connected')
+
+ self._module_subscriptions[topic] = module
+ self._wrapper.subscribe(self.id, topic, qos)
+
+ def publish(self,
+ topic: str,
+ payload: bytes,
+ qos: int = 1):
+ self._wrapper.publish(self.id, topic, payload, qos)
+
+ def add_payload_callback(self, callback: callable):
+ self._payload_callbacks.append(callback)
+
+ @property
+ def id(self) -> str:
+ return self._node_id
+
+ @property
+ def secret(self) -> str:
+ return self._node_secret
+
+ @secret.setter
+ def secret(self, secret: str) -> None:
+ self._node_secret = secret
diff --git a/src/home/mqtt/payload/base_payload.py b/src/home/mqtt/_payload.py
index 1abd898..58eeae3 100644
--- a/src/home/mqtt/payload/base_payload.py
+++ b/src/home/mqtt/_payload.py
@@ -1,5 +1,5 @@
-import abc
import struct
+import abc
import re
from typing import Optional, Tuple
@@ -142,4 +142,4 @@ def _bit_field_params(cl) -> Optional[Tuple[int, ...]]:
match = re.match(r'MQTTPayloadBitField_(\d+)_(\d+)_(\d)$', cl.__name__)
if match is not None:
return tuple([int(match.group(i)) for i in range(1, 4)])
- return None
+ return None \ No newline at end of file
diff --git a/src/home/mqtt/_util.py b/src/home/mqtt/_util.py
new file mode 100644
index 0000000..390d463
--- /dev/null
+++ b/src/home/mqtt/_util.py
@@ -0,0 +1,15 @@
+import os
+import re
+
+from typing import List
+
+
+def get_modules() -> List[str]:
+ modules = []
+ modules_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'module')
+ for name in os.listdir(modules_dir):
+ if os.path.isdir(os.path.join(modules_dir, name)):
+ continue
+ name = re.sub(r'\.py$', '', name)
+ modules.append(name)
+ return modules
diff --git a/src/home/mqtt/_wrapper.py b/src/home/mqtt/_wrapper.py
new file mode 100644
index 0000000..f858f88
--- /dev/null
+++ b/src/home/mqtt/_wrapper.py
@@ -0,0 +1,59 @@
+import paho.mqtt.client as mqtt
+
+from ._mqtt import Mqtt
+from ._node import MqttNode
+from ..config import config
+from ..util import strgen
+
+
+class MqttWrapper(Mqtt):
+ _nodes: list[MqttNode]
+
+ def __init__(self,
+ client_id: str,
+ topic_prefix='hk',
+ randomize_client_id=False,
+ clean_session=True):
+ if randomize_client_id:
+ client_id += '_'+strgen(6)
+ super().__init__(clean_session=clean_session,
+ client_id=client_id)
+ self._nodes = []
+ self._topic_prefix = topic_prefix
+
+ def on_connect(self, client: mqtt.Client, userdata, flags, rc):
+ super().on_connect(client, userdata, flags, rc)
+ for node in self._nodes:
+ node.on_connect(self)
+
+ def on_disconnect(self, client: mqtt.Client, userdata, rc):
+ super().on_disconnect(client, userdata, rc)
+ for node in self._nodes:
+ node.on_disconnect()
+
+ def on_message(self, client: mqtt.Client, userdata, msg):
+ try:
+ topic = msg.topic
+ for node in self._nodes:
+ node.on_message(topic[len(f'{self._topic_prefix}/{node.id}/'):], msg.payload)
+ except Exception as e:
+ self._logger.exception(str(e))
+
+ def add_node(self, node: MqttNode):
+ self._nodes.append(node)
+ if self._connected:
+ node.on_connect(self)
+
+ def subscribe(self,
+ node_id: str,
+ topic: str,
+ qos: int):
+ self._client.subscribe(f'{self._topic_prefix}/{node_id}/{topic}', qos)
+
+ def publish(self,
+ node_id: str,
+ topic: str,
+ payload: bytes,
+ qos: int):
+ self._client.publish(f'{self._topic_prefix}/{node_id}/{topic}', payload, qos)
+ self._client.loop_write()
diff --git a/src/home/mqtt/esp.py b/src/home/mqtt/esp.py
deleted file mode 100644
index 56ced83..0000000
--- a/src/home/mqtt/esp.py
+++ /dev/null
@@ -1,106 +0,0 @@
-import re
-import paho.mqtt.client as mqtt
-
-from .mqtt import MqttBase
-from typing import Optional, Union
-from .payload.esp import (
- OTAPayload,
- OTAResultPayload,
- DiagnosticsPayload,
- InitialDiagnosticsPayload
-)
-
-
-class MqttEspDevice:
- id: str
- secret: Optional[str]
-
- def __init__(self, id: str, secret: Optional[str] = None):
- self.id = id
- self.secret = secret
-
-
-class MqttEspBase(MqttBase):
- _devices: list[MqttEspDevice]
- _message_callback: Optional[callable]
- _ota_publish_callback: Optional[callable]
-
- TOPIC_LEAF = 'esp'
-
- def __init__(self,
- devices: Union[MqttEspDevice, list[MqttEspDevice]],
- subscribe_to_updates=True):
- super().__init__(clean_session=True)
- if not isinstance(devices, list):
- devices = [devices]
- self._devices = devices
- self._message_callback = None
- self._ota_publish_callback = None
- self._subscribe_to_updates = subscribe_to_updates
- self._ota_mid = None
-
- def on_connect(self, client: mqtt.Client, userdata, flags, rc):
- super().on_connect(client, userdata, flags, rc)
-
- if self._subscribe_to_updates:
- for device in self._devices:
- topic = f'hk/{device.id}/{self.TOPIC_LEAF}/#'
- self._logger.debug(f"subscribing to {topic}")
- client.subscribe(topic, qos=1)
-
- def on_publish(self, client: mqtt.Client, userdata, mid):
- if self._ota_mid is not None and mid == self._ota_mid and self._ota_publish_callback:
- self._ota_publish_callback()
-
- def set_message_callback(self, callback: callable):
- self._message_callback = callback
-
- def on_message(self, client: mqtt.Client, userdata, msg):
- try:
- match = re.match(self.get_mqtt_topics(), msg.topic)
- self._logger.debug(f'topic: {msg.topic}')
- if not match:
- return
-
- device_id = match.group(1)
- subtopic = match.group(2)
-
- # try:
- next(d for d in self._devices if d.id == device_id)
- # except StopIteration:h
- # return
-
- message = None
- if subtopic == 'stat':
- message = DiagnosticsPayload.unpack(msg.payload)
- elif subtopic == 'stat1':
- message = InitialDiagnosticsPayload.unpack(msg.payload)
- elif subtopic == 'otares':
- message = OTAResultPayload.unpack(msg.payload)
-
- if message and self._message_callback:
- self._message_callback(device_id, message)
- return True
-
- except Exception as e:
- self._logger.exception(str(e))
-
- def push_ota(self,
- device_id,
- filename: str,
- publish_callback: callable,
- qos: int):
- device = next(d for d in self._devices if d.id == device_id)
- assert device.secret is not None, 'device secret not specified'
-
- self._ota_publish_callback = publish_callback
- payload = OTAPayload(secret=device.secret, filename=filename)
- publish_result = self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/admin/ota',
- payload=payload.pack(),
- qos=qos)
- self._ota_mid = publish_result.mid
- self._client.loop_write()
-
- @classmethod
- def get_mqtt_topics(cls, additional_topics: Optional[list[str]] = None):
- return rf'^hk/(.*?)/{cls.TOPIC_LEAF}/(stat|stat1|otares'+('|'+('|'.join(additional_topics)) if additional_topics else '')+')$' \ No newline at end of file
diff --git a/src/home/mqtt/payload/esp.py b/src/home/mqtt/module/diagnostics.py
index 171cdb9..5db5e99 100644
--- a/src/home/mqtt/payload/esp.py
+++ b/src/home/mqtt/module/diagnostics.py
@@ -1,39 +1,8 @@
-import hashlib
+from .._payload import MqttPayload, MqttPayloadCustomField
+from .._node import MqttNode, MqttModule
+from typing import Optional
-from .base_payload import MqttPayload, MqttPayloadCustomField
-
-
-class OTAResultPayload(MqttPayload):
- FORMAT = '=BB'
- result: int
- error_code: int
-
-
-class OTAPayload(MqttPayload):
- secret: str
- filename: str
-
- # structure of returned data:
- #
- # uint8_t[len(secret)] secret;
- # uint8_t[16] md5;
- # *uint8_t data
-
- def pack(self):
- buf = bytearray(self.secret.encode())
- m = hashlib.md5()
- with open(self.filename, 'rb') as fd:
- content = fd.read()
- m.update(content)
- buf.extend(m.digest())
- buf.extend(content)
- return buf
-
- def unpack(cls, buf: bytes):
- raise RuntimeError(f'{cls.__class__.__name__}.unpack: not implemented')
- # secret = buf[:12].decode()
- # filename = buf[12:].decode()
- # return OTAPayload(secret=secret, filename=filename)
+MODULE_NAME = 'MqttDiagnosticsModule'
class DiagnosticsFlags(MqttPayloadCustomField):
@@ -76,3 +45,20 @@ class DiagnosticsPayload(MqttPayload):
rssi: int
free_heap: int
flags: DiagnosticsFlags
+
+
+class MqttDiagnosticsModule(MqttModule):
+ def on_connect(self, mqtt: MqttNode):
+ super().on_connect(mqtt)
+ for topic in ('diag', 'd1ag', 'stat', 'stat1'):
+ mqtt.subscribe_module(topic, self)
+
+ def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]:
+ message = None
+ if topic in ('stat', 'diag'):
+ message = DiagnosticsPayload.unpack(payload)
+ elif topic in ('stat1', 'd1ag'):
+ message = InitialDiagnosticsPayload.unpack(payload)
+ if message:
+ self._logger.debug(message)
+ return message
diff --git a/src/home/mqtt/module/inverter.py b/src/home/mqtt/module/inverter.py
new file mode 100644
index 0000000..d927a06
--- /dev/null
+++ b/src/home/mqtt/module/inverter.py
@@ -0,0 +1,195 @@
+import time
+import json
+import datetime
+try:
+ import inverterd
+except:
+ pass
+
+from typing import Optional
+from .._module import MqttModule
+from .._node import MqttNode
+from .._payload import MqttPayload, bit_field
+try:
+ from home.database import InverterDatabase
+except:
+ pass
+
+_mult_10 = lambda n: int(n*10)
+_div_10 = lambda n: n/10
+
+
+MODULE_NAME = 'MqttInverterModule'
+
+STATUS_TOPIC = 'status'
+GENERATION_TOPIC = 'generation'
+
+
+class MqttInverterStatusPayload(MqttPayload):
+ # 46 bytes
+ FORMAT = 'IHHHHHHBHHHHHBHHHHHHHH'
+
+ PACKER = {
+ 'grid_voltage': _mult_10,
+ 'grid_freq': _mult_10,
+ 'ac_output_voltage': _mult_10,
+ 'ac_output_freq': _mult_10,
+ 'battery_voltage': _mult_10,
+ 'battery_voltage_scc': _mult_10,
+ 'battery_voltage_scc2': _mult_10,
+ 'pv1_input_voltage': _mult_10,
+ 'pv2_input_voltage': _mult_10
+ }
+ UNPACKER = {
+ 'grid_voltage': _div_10,
+ 'grid_freq': _div_10,
+ 'ac_output_voltage': _div_10,
+ 'ac_output_freq': _div_10,
+ 'battery_voltage': _div_10,
+ 'battery_voltage_scc': _div_10,
+ 'battery_voltage_scc2': _div_10,
+ 'pv1_input_voltage': _div_10,
+ 'pv2_input_voltage': _div_10
+ }
+
+ time: int
+ grid_voltage: float
+ grid_freq: float
+ ac_output_voltage: float
+ ac_output_freq: float
+ ac_output_apparent_power: int
+ ac_output_active_power: int
+ output_load_percent: int
+ battery_voltage: float
+ battery_voltage_scc: float
+ battery_voltage_scc2: float
+ battery_discharge_current: int
+ battery_charge_current: int
+ battery_capacity: int
+ inverter_heat_sink_temp: int
+ mppt1_charger_temp: int
+ mppt2_charger_temp: int
+ pv1_input_power: int
+ pv2_input_power: int
+ pv1_input_voltage: float
+ pv2_input_voltage: float
+
+ # H
+ mppt1_charger_status: bit_field(0, 16, 2)
+ mppt2_charger_status: bit_field(0, 16, 2)
+ battery_power_direction: bit_field(0, 16, 2)
+ dc_ac_power_direction: bit_field(0, 16, 2)
+ line_power_direction: bit_field(0, 16, 2)
+ load_connected: bit_field(0, 16, 1)
+
+
+class MqttInverterGenerationPayload(MqttPayload):
+ # 8 bytes
+ FORMAT = 'II'
+
+ time: int
+ wh: int
+
+
+class MqttInverterModule(MqttModule):
+ _status_poll_freq: int
+ _generation_poll_freq: int
+ _inverter: Optional[inverterd.Client]
+ _database: Optional[InverterDatabase]
+ _gen_prev: float
+
+ def __init__(self, status_poll_freq=0, generation_poll_freq=0):
+ super().__init__(tick_interval=status_poll_freq)
+ self._status_poll_freq = status_poll_freq
+ self._generation_poll_freq = generation_poll_freq
+
+ # this defines whether this is a publisher or a subscriber
+ if status_poll_freq > 0:
+ self._inverter = inverterd.Client()
+ self._inverter.connect()
+ self._inverter.format(inverterd.Format.SIMPLE_JSON)
+ self._database = None
+ else:
+ self._inverter = None
+ self._database = InverterDatabase()
+
+ self._gen_prev = 0
+
+ def on_connect(self, mqtt: MqttNode):
+ super().on_connect(mqtt)
+ if not self._inverter:
+ mqtt.subscribe_module(STATUS_TOPIC, self)
+ mqtt.subscribe_module(GENERATION_TOPIC, self)
+
+ def tick(self):
+ if not self._inverter:
+ return
+
+ # read status
+ now = time.time()
+ try:
+ raw = self._inverter.exec('get-status')
+ except inverterd.InverterError as e:
+ self._logger.error(f'inverter error: {str(e)}')
+ # TODO send to server
+ return
+
+ data = json.loads(raw)['data']
+ status = MqttInverterStatusPayload(time=round(now), **data)
+ self._mqtt_node_ref.publish(STATUS_TOPIC, status.pack())
+
+ # read today's generation stat
+ now = time.time()
+ if self._gen_prev == 0 or now - self._gen_prev >= self._generation_poll_freq:
+ self._gen_prev = now
+ today = datetime.date.today()
+ try:
+ raw = self._inverter.exec('get-day-generated', (today.year, today.month, today.day))
+ except inverterd.InverterError as e:
+ self._logger.error(f'inverter error: {str(e)}')
+ # TODO send to server
+ return
+
+ data = json.loads(raw)['data']
+ gen = MqttInverterGenerationPayload(time=round(now), wh=data['wh'])
+ self._mqtt_node_ref.publish(GENERATION_TOPIC, gen.pack())
+
+ def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]:
+ home_id = 1 # legacy compat
+
+ if topic == STATUS_TOPIC:
+ s = MqttInverterStatusPayload.unpack(payload)
+ self._database.add_status(home_id=home_id,
+ client_time=s.time,
+ grid_voltage=int(s.grid_voltage*10),
+ grid_freq=int(s.grid_freq * 10),
+ ac_output_voltage=int(s.ac_output_voltage * 10),
+ ac_output_freq=int(s.ac_output_freq * 10),
+ ac_output_apparent_power=s.ac_output_apparent_power,
+ ac_output_active_power=s.ac_output_active_power,
+ output_load_percent=s.output_load_percent,
+ battery_voltage=int(s.battery_voltage * 10),
+ battery_voltage_scc=int(s.battery_voltage_scc * 10),
+ battery_voltage_scc2=int(s.battery_voltage_scc2 * 10),
+ battery_discharge_current=s.battery_discharge_current,
+ battery_charge_current=s.battery_charge_current,
+ battery_capacity=s.battery_capacity,
+ inverter_heat_sink_temp=s.inverter_heat_sink_temp,
+ mppt1_charger_temp=s.mppt1_charger_temp,
+ mppt2_charger_temp=s.mppt2_charger_temp,
+ pv1_input_power=s.pv1_input_power,
+ pv2_input_power=s.pv2_input_power,
+ pv1_input_voltage=int(s.pv1_input_voltage * 10),
+ pv2_input_voltage=int(s.pv2_input_voltage * 10),
+ mppt1_charger_status=s.mppt1_charger_status,
+ mppt2_charger_status=s.mppt2_charger_status,
+ battery_power_direction=s.battery_power_direction,
+ dc_ac_power_direction=s.dc_ac_power_direction,
+ line_power_direction=s.line_power_direction,
+ load_connected=s.load_connected)
+ return s
+
+ elif topic == GENERATION_TOPIC:
+ gen = MqttInverterGenerationPayload.unpack(payload)
+ self._database.add_generation(home_id, gen.time, gen.wh)
+ return gen
diff --git a/src/home/mqtt/module/ota.py b/src/home/mqtt/module/ota.py
new file mode 100644
index 0000000..cd34332
--- /dev/null
+++ b/src/home/mqtt/module/ota.py
@@ -0,0 +1,77 @@
+import hashlib
+
+from typing import Optional
+from .._payload import MqttPayload
+from .._node import MqttModule, MqttNode
+
+MODULE_NAME = 'MqttOtaModule'
+
+
+class OtaResultPayload(MqttPayload):
+ FORMAT = '=BB'
+ result: int
+ error_code: int
+
+
+class OtaPayload(MqttPayload):
+ secret: str
+ filename: str
+
+ # structure of returned data:
+ #
+ # uint8_t[len(secret)] secret;
+ # uint8_t[16] md5;
+ # *uint8_t data
+
+ def pack(self):
+ buf = bytearray(self.secret.encode())
+ m = hashlib.md5()
+ with open(self.filename, 'rb') as fd:
+ content = fd.read()
+ m.update(content)
+ buf.extend(m.digest())
+ buf.extend(content)
+ return buf
+
+ def unpack(cls, buf: bytes):
+ raise RuntimeError(f'{cls.__class__.__name__}.unpack: not implemented')
+ # secret = buf[:12].decode()
+ # filename = buf[12:].decode()
+ # return OTAPayload(secret=secret, filename=filename)
+
+
+class MqttOtaModule(MqttModule):
+ _ota_request: Optional[tuple[str, int]]
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._ota_request = None
+
+ def on_connect(self, mqtt: MqttNode):
+ super().on_connect(mqtt)
+ mqtt.subscribe_module("otares", self)
+
+ if self._ota_request is not None:
+ filename, qos = self._ota_request
+ self._ota_request = None
+ self.do_push_ota(self._mqtt_node_ref.secret, filename, qos)
+
+ def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]:
+ if topic == 'otares':
+ message = OtaResultPayload.unpack(payload)
+ self._logger.debug(message)
+ return message
+
+ def do_push_ota(self, secret: str, filename: str, qos: int):
+ payload = OtaPayload(secret=secret, filename=filename)
+ self._mqtt_node_ref.publish('ota',
+ payload=payload.pack(),
+ qos=qos)
+
+ def push_ota(self,
+ filename: str,
+ qos: int):
+ if not self._initialized:
+ self._ota_request = (filename, qos)
+ else:
+ self.do_push_ota(filename, qos)
diff --git a/src/home/mqtt/module/relay.py b/src/home/mqtt/module/relay.py
new file mode 100644
index 0000000..e968031
--- /dev/null
+++ b/src/home/mqtt/module/relay.py
@@ -0,0 +1,92 @@
+import datetime
+
+from typing import Optional
+from .. import MqttModule, MqttPayload, MqttNode
+
+MODULE_NAME = 'MqttRelayModule'
+
+
+class MqttPowerSwitchPayload(MqttPayload):
+ FORMAT = '=12sB'
+ PACKER = {
+ 'state': lambda n: int(n),
+ 'secret': lambda s: s.encode('utf-8')
+ }
+ UNPACKER = {
+ 'state': lambda n: bool(n),
+ 'secret': lambda s: s.decode('utf-8')
+ }
+
+ secret: str
+ state: bool
+
+
+class MqttPowerStatusPayload(MqttPayload):
+ FORMAT = '=B'
+ PACKER = {
+ 'opened': lambda n: int(n),
+ }
+ UNPACKER = {
+ 'opened': lambda n: bool(n),
+ }
+
+ opened: bool
+
+
+class MqttRelayState:
+ enabled: bool
+ update_time: datetime.datetime
+ rssi: int
+ fw_version: int
+ ever_updated: bool
+
+ def __init__(self):
+ self.ever_updated = False
+ self.enabled = False
+ self.rssi = 0
+
+ def update(self,
+ enabled: bool,
+ rssi: int,
+ fw_version=None):
+ self.ever_updated = True
+ self.enabled = enabled
+ self.rssi = rssi
+ self.update_time = datetime.datetime.now()
+ if fw_version:
+ self.fw_version = fw_version
+
+
+class MqttRelayModule(MqttModule):
+ _legacy_topics: bool
+
+ def __init__(self, legacy_topics=False, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._legacy_topics = legacy_topics
+
+ def on_connect(self, mqtt: MqttNode):
+ super().on_connect(mqtt)
+ mqtt.subscribe_module(self._get_switch_topic(), self)
+ mqtt.subscribe_module('relay/status', self)
+
+ def switchpower(self,
+ enable: bool):
+ payload = MqttPowerSwitchPayload(secret=self._mqtt_node_ref.secret,
+ state=enable)
+ self._mqtt_node_ref.publish(self._get_switch_topic(),
+ payload=payload.pack())
+
+ def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]:
+ message = None
+
+ if topic == self._get_switch_topic():
+ message = MqttPowerSwitchPayload.unpack(payload)
+ elif topic == 'relay/status':
+ message = MqttPowerStatusPayload.unpack(payload)
+
+ if message is not None:
+ self._logger.debug(message)
+ return message
+
+ def _get_switch_topic(self) -> str:
+ return 'relay/power' if self._legacy_topics else 'relay/switch'
diff --git a/src/home/mqtt/module/temphum.py b/src/home/mqtt/module/temphum.py
new file mode 100644
index 0000000..fd02cca
--- /dev/null
+++ b/src/home/mqtt/module/temphum.py
@@ -0,0 +1,82 @@
+from .._node import MqttNode
+from .._module import MqttModule
+from .._payload import MqttPayload
+from typing import Optional
+from ...temphum import BaseSensor
+
+two_digits_precision = lambda x: round(x, 2)
+
+MODULE_NAME = 'MqttTempHumModule'
+DATA_TOPIC = 'temphum/data'
+
+
+class MqttTemphumDataPayload(MqttPayload):
+ FORMAT = '=ddb'
+ UNPACKER = {
+ 'temp': two_digits_precision,
+ 'rh': two_digits_precision
+ }
+
+ temp: float
+ rh: float
+ error: int
+
+
+# class MqttTempHumNodes(HashableEnum):
+# KBN_SH_HALL = auto()
+# KBN_SH_BATHROOM = auto()
+# KBN_SH_LIVINGROOM = auto()
+# KBN_SH_BEDROOM = auto()
+#
+# KBN_BH_2FL = auto()
+# KBN_BH_2FL_STREET = auto()
+# KBN_BH_1FL_LIVINGROOM = auto()
+# KBN_BH_1FL_BEDROOM = auto()
+# KBN_BH_1FL_BATHROOM = auto()
+#
+# KBN_NH_1FL_INV = auto()
+# KBN_NH_1FL_CENTER = auto()
+# KBN_NH_1LF_KT = auto()
+# KBN_NH_1FL_DS = auto()
+# KBN_NH_1FS_EZ = auto()
+#
+# SPB_FLAT120_CABINET = auto()
+
+
+class MqttTempHumModule(MqttModule):
+ def __init__(self,
+ sensor: Optional[BaseSensor] = None,
+ write_to_database=False,
+ *args, **kwargs):
+ if sensor is not None:
+ kwargs['tick_interval'] = 10
+ super().__init__(*args, **kwargs)
+ self._sensor = sensor
+
+ def on_connect(self, mqtt: MqttNode):
+ super().on_connect(mqtt)
+ mqtt.subscribe_module(DATA_TOPIC, self)
+
+ def tick(self):
+ if not self._sensor:
+ return
+
+ error = 0
+ temp = 0
+ rh = 0
+ try:
+ temp = self._sensor.temperature()
+ rh = self._sensor.humidity()
+ except:
+ error = 1
+ pld = MqttTemphumDataPayload(temp=temp, rh=rh, error=error)
+ self._mqtt_node_ref.publish(DATA_TOPIC, pld.pack())
+
+ def handle_payload(self,
+ mqtt: MqttNode,
+ topic: str,
+ payload: bytes) -> Optional[MqttPayload]:
+ if topic == DATA_TOPIC:
+ message = MqttTemphumDataPayload.unpack(payload)
+ self._logger.debug(message)
+ return message
diff --git a/src/home/mqtt/payload/__init__.py b/src/home/mqtt/payload/__init__.py
deleted file mode 100644
index eee6709..0000000
--- a/src/home/mqtt/payload/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .base_payload import MqttPayload \ No newline at end of file
diff --git a/src/home/mqtt/payload/inverter.py b/src/home/mqtt/payload/inverter.py
deleted file mode 100644
index 09388df..0000000
--- a/src/home/mqtt/payload/inverter.py
+++ /dev/null
@@ -1,73 +0,0 @@
-import struct
-
-from .base_payload import MqttPayload, bit_field
-from typing import Tuple
-
-_mult_10 = lambda n: int(n*10)
-_div_10 = lambda n: n/10
-
-
-class Status(MqttPayload):
- # 46 bytes
- FORMAT = 'IHHHHHHBHHHHHBHHHHHHHH'
-
- PACKER = {
- 'grid_voltage': _mult_10,
- 'grid_freq': _mult_10,
- 'ac_output_voltage': _mult_10,
- 'ac_output_freq': _mult_10,
- 'battery_voltage': _mult_10,
- 'battery_voltage_scc': _mult_10,
- 'battery_voltage_scc2': _mult_10,
- 'pv1_input_voltage': _mult_10,
- 'pv2_input_voltage': _mult_10
- }
- UNPACKER = {
- 'grid_voltage': _div_10,
- 'grid_freq': _div_10,
- 'ac_output_voltage': _div_10,
- 'ac_output_freq': _div_10,
- 'battery_voltage': _div_10,
- 'battery_voltage_scc': _div_10,
- 'battery_voltage_scc2': _div_10,
- 'pv1_input_voltage': _div_10,
- 'pv2_input_voltage': _div_10
- }
-
- time: int
- grid_voltage: float
- grid_freq: float
- ac_output_voltage: float
- ac_output_freq: float
- ac_output_apparent_power: int
- ac_output_active_power: int
- output_load_percent: int
- battery_voltage: float
- battery_voltage_scc: float
- battery_voltage_scc2: float
- battery_discharge_current: int
- battery_charge_current: int
- battery_capacity: int
- inverter_heat_sink_temp: int
- mppt1_charger_temp: int
- mppt2_charger_temp: int
- pv1_input_power: int
- pv2_input_power: int
- pv1_input_voltage: float
- pv2_input_voltage: float
-
- # H
- mppt1_charger_status: bit_field(0, 16, 2)
- mppt2_charger_status: bit_field(0, 16, 2)
- battery_power_direction: bit_field(0, 16, 2)
- dc_ac_power_direction: bit_field(0, 16, 2)
- line_power_direction: bit_field(0, 16, 2)
- load_connected: bit_field(0, 16, 1)
-
-
-class Generation(MqttPayload):
- # 8 bytes
- FORMAT = 'II'
-
- time: int
- wh: int
diff --git a/src/home/mqtt/payload/relay.py b/src/home/mqtt/payload/relay.py
deleted file mode 100644
index 4902991..0000000
--- a/src/home/mqtt/payload/relay.py
+++ /dev/null
@@ -1,22 +0,0 @@
-from .base_payload import MqttPayload
-from .esp import (
- OTAResultPayload,
- OTAPayload,
- InitialDiagnosticsPayload,
- DiagnosticsPayload
-)
-
-
-class PowerPayload(MqttPayload):
- FORMAT = '=12sB'
- PACKER = {
- 'state': lambda n: int(n),
- 'secret': lambda s: s.encode('utf-8')
- }
- UNPACKER = {
- 'state': lambda n: bool(n),
- 'secret': lambda s: s.decode('utf-8')
- }
-
- secret: str
- state: bool
diff --git a/src/home/mqtt/payload/sensors.py b/src/home/mqtt/payload/sensors.py
deleted file mode 100644
index f99b307..0000000
--- a/src/home/mqtt/payload/sensors.py
+++ /dev/null
@@ -1,20 +0,0 @@
-from .base_payload import MqttPayload
-
-_mult_100 = lambda n: int(n*100)
-_div_100 = lambda n: n/100
-
-
-class Temperature(MqttPayload):
- FORMAT = 'IhH'
- PACKER = {
- 'temp': _mult_100,
- 'rh': _mult_100,
- }
- UNPACKER = {
- 'temp': _div_100,
- 'rh': _div_100,
- }
-
- time: int
- temp: float
- rh: float
diff --git a/src/home/mqtt/payload/temphum.py b/src/home/mqtt/payload/temphum.py
deleted file mode 100644
index c0b744e..0000000
--- a/src/home/mqtt/payload/temphum.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from .base_payload import MqttPayload
-
-two_digits_precision = lambda x: round(x, 2)
-
-
-class TempHumDataPayload(MqttPayload):
- FORMAT = '=ddb'
- UNPACKER = {
- 'temp': two_digits_precision,
- 'rh': two_digits_precision
- }
-
- temp: float
- rh: float
- error: int
diff --git a/src/home/mqtt/relay.py b/src/home/mqtt/relay.py
deleted file mode 100644
index a90f19c..0000000
--- a/src/home/mqtt/relay.py
+++ /dev/null
@@ -1,71 +0,0 @@
-import paho.mqtt.client as mqtt
-import re
-import datetime
-
-from .payload.relay import (
- PowerPayload,
-)
-from .esp import MqttEspBase
-
-
-class MqttRelay(MqttEspBase):
- TOPIC_LEAF = 'relay'
-
- def set_power(self, device_id, enable: bool, secret=None):
- device = next(d for d in self._devices if d.id == device_id)
- secret = secret if secret else device.secret
-
- assert secret is not None, 'device secret not specified'
-
- payload = PowerPayload(secret=secret,
- state=enable)
- self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/power',
- payload=payload.pack(),
- qos=1)
- self._client.loop_write()
-
- def on_message(self, client: mqtt.Client, userdata, msg):
- if super().on_message(client, userdata, msg):
- return
-
- try:
- match = re.match(self.get_mqtt_topics(['power']), msg.topic)
- if not match:
- return
-
- device_id = match.group(1)
- subtopic = match.group(2)
-
- message = None
- if subtopic == 'power':
- message = PowerPayload.unpack(msg.payload)
-
- if message and self._message_callback:
- self._message_callback(device_id, message)
-
- except Exception as e:
- self._logger.exception(str(e))
-
-
-class MqttRelayState:
- enabled: bool
- update_time: datetime.datetime
- rssi: int
- fw_version: int
- ever_updated: bool
-
- def __init__(self):
- self.ever_updated = False
- self.enabled = False
- self.rssi = 0
-
- def update(self,
- enabled: bool,
- rssi: int,
- fw_version=None):
- self.ever_updated = True
- self.enabled = enabled
- self.rssi = rssi
- self.update_time = datetime.datetime.now()
- if fw_version:
- self.fw_version = fw_version
diff --git a/src/home/mqtt/temphum.py b/src/home/mqtt/temphum.py
deleted file mode 100644
index 44810ef..0000000
--- a/src/home/mqtt/temphum.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import paho.mqtt.client as mqtt
-import re
-
-from enum import auto
-from .payload.temphum import TempHumDataPayload
-from .esp import MqttEspBase
-from ..util import HashableEnum
-
-
-class MqttTempHumNodes(HashableEnum):
- KBN_SH_HALL = auto()
- KBN_SH_BATHROOM = auto()
- KBN_SH_LIVINGROOM = auto()
- KBN_SH_BEDROOM = auto()
-
- KBN_BH_2FL = auto()
- KBN_BH_2FL_STREET = auto()
- KBN_BH_1FL_LIVINGROOM = auto()
- KBN_BH_1FL_BEDROOM = auto()
- KBN_BH_1FL_BATHROOM = auto()
-
- KBN_NH_1FL_INV = auto()
- KBN_NH_1FL_CENTER = auto()
- KBN_NH_1LF_KT = auto()
- KBN_NH_1FL_DS = auto()
- KBN_NH_1FS_EZ = auto()
-
- SPB_FLAT120_CABINET = auto()
-
-
-class MqttTempHum(MqttEspBase):
- TOPIC_LEAF = 'temphum'
-
- def on_message(self, client: mqtt.Client, userdata, msg):
- if super().on_message(client, userdata, msg):
- return
-
- try:
- match = re.match(self.get_mqtt_topics(['data']), msg.topic)
- if not match:
- return
-
- device_id = match.group(1)
- subtopic = match.group(2)
-
- message = None
- if subtopic == 'data':
- message = TempHumDataPayload.unpack(msg.payload)
-
- if message and self._message_callback:
- self._message_callback(device_id, message)
-
- except Exception as e:
- self._logger.exception(str(e))
diff --git a/src/home/mqtt/util.py b/src/home/mqtt/util.py
deleted file mode 100644
index f71ffd8..0000000
--- a/src/home/mqtt/util.py
+++ /dev/null
@@ -1,8 +0,0 @@
-import time
-
-
-def poll_tick(freq):
- t = time.time()
- while True:
- t += freq
- yield max(t - time.time(), 0)
diff --git a/src/home/pio/products.py b/src/home/pio/products.py
index 7649078..388da03 100644
--- a/src/home/pio/products.py
+++ b/src/home/pio/products.py
@@ -16,10 +16,6 @@ _products_dir = os.path.join(
def get_products():
products = []
for f in os.listdir(_products_dir):
- # temp hack
- if f.endswith('-esp01'):
- continue
- # skip the common dir
if f in ('common',):
continue
diff --git a/src/home/telegram/_botcontext.py b/src/home/telegram/_botcontext.py
index f343eeb..a143bfe 100644
--- a/src/home/telegram/_botcontext.py
+++ b/src/home/telegram/_botcontext.py
@@ -1,6 +1,7 @@
from typing import Optional, List
-from telegram import Update, ParseMode, User, CallbackQuery
+from telegram import Update, User, CallbackQuery
+from telegram.constants import ParseMode
from telegram.ext import CallbackContext
from ._botdb import BotDatabase
@@ -26,25 +27,25 @@ class Context:
self._store = store
self._user_lang = None
- def reply(self, text, markup=None):
+ async def reply(self, text, markup=None):
if markup is None:
markup = self._markup_getter(self)
kwargs = dict(parse_mode=ParseMode.HTML)
if not isinstance(markup, IgnoreMarkup):
kwargs['reply_markup'] = markup
- return self._update.message.reply_text(text, **kwargs)
+ return await self._update.message.reply_text(text, **kwargs)
- def reply_exc(self, e: Exception) -> None:
- self.reply(exc2text(e), markup=IgnoreMarkup())
+ async def reply_exc(self, e: Exception) -> None:
+ await self.reply(exc2text(e), markup=IgnoreMarkup())
- def answer(self, text: str = None):
- self.callback_query.answer(text)
+ async def answer(self, text: str = None):
+ await self.callback_query.answer(text)
- def edit(self, text, markup=None):
+ async def edit(self, text, markup=None):
kwargs = dict(parse_mode=ParseMode.HTML)
if not isinstance(markup, IgnoreMarkup):
kwargs['reply_markup'] = markup
- self.callback_query.edit_message_text(text, **kwargs)
+ await self.callback_query.edit_message_text(text, **kwargs)
@property
def text(self) -> str:
diff --git a/src/home/telegram/bot.py b/src/home/telegram/bot.py
index 10bfe06..7e22263 100644
--- a/src/home/telegram/bot.py
+++ b/src/home/telegram/bot.py
@@ -5,19 +5,19 @@ import itertools
from enum import Enum, auto
from functools import wraps
-from typing import Optional, Union, Tuple
+from typing import Optional, Union, Tuple, Coroutine
from telegram import Update, ReplyKeyboardMarkup
from telegram.ext import (
- Updater,
- Filters,
- BaseFilter,
+ Application,
+ filters,
CommandHandler,
MessageHandler,
CallbackQueryHandler,
CallbackContext,
ConversationHandler
)
+from telegram.ext.filters import BaseFilter
from telegram.error import TimedOut
from home.config import config
@@ -33,26 +33,26 @@ from ._botcontext import Context
db: Optional[BotDatabase] = None
_user_filter: Optional[BaseFilter] = None
-_cancel_filter = Filters.text(lang.all('cancel'))
-_back_filter = Filters.text(lang.all('back'))
-_cancel_and_back_filter = Filters.text(lang.all('back') + lang.all('cancel'))
+_cancel_filter = filters.Text(lang.all('cancel'))
+_back_filter = filters.Text(lang.all('back'))
+_cancel_and_back_filter = filters.Text(lang.all('back') + lang.all('cancel'))
_logger = logging.getLogger(__name__)
-_updater: Optional[Updater] = None
+_application: Optional[Application] = None
_reporting: Optional[ReportingHelper] = None
-_exception_handler: Optional[callable] = None
+_exception_handler: Optional[Coroutine] = None
_dispatcher = None
_markup_getter: Optional[callable] = None
-_start_handler_ref: Optional[callable] = None
+_start_handler_ref: Optional[Coroutine] = None
def text_filter(*args):
if not _user_filter:
raise RuntimeError('user_filter is not initialized')
- return Filters.text(args[0] if isinstance(args[0], list) else [*args]) & _user_filter
+ return filters.Text(args[0] if isinstance(args[0], list) else [*args]) & _user_filter
-def _handler_of_handler(*args, **kwargs):
+async def _handler_of_handler(*args, **kwargs):
self = None
context = None
update = None
@@ -99,7 +99,7 @@ def _handler_of_handler(*args, **kwargs):
if self:
_args.insert(0, self)
- result = f(*_args, **kwargs)
+ result = await f(*_args, **kwargs)
return result if not return_with_context else (result, ctx)
except Exception as e:
@@ -107,7 +107,7 @@ def _handler_of_handler(*args, **kwargs):
if not _exception_handler(e, ctx) and not isinstance(e, TimedOut):
_logger.exception(e)
if not ctx.is_callback_context():
- ctx.reply_exc(e)
+ await ctx.reply_exc(e)
else:
notify_user(ctx.user_id, exc2text(e))
else:
@@ -117,10 +117,10 @@ def _handler_of_handler(*args, **kwargs):
def handler(**kwargs):
def inner(f):
@wraps(f)
- def _handler(*args, **inner_kwargs):
+ async def _handler(*args, **inner_kwargs):
if 'argument' in kwargs and kwargs['argument'] == 'message_key':
inner_kwargs['argument'] = 'message_key'
- return _handler_of_handler(f=f, *args, **inner_kwargs)
+ return await _handler_of_handler(f=f, *args, **inner_kwargs)
messages = []
texts = []
@@ -139,43 +139,43 @@ def handler(**kwargs):
new_messages = list(itertools.chain.from_iterable([lang.all(m) for m in messages]))
texts += new_messages
texts = list(set(texts))
- _updater.dispatcher.add_handler(
+ _application.add_handler(
MessageHandler(text_filter(*texts), _handler),
group=0
)
if 'command' in kwargs:
- _updater.dispatcher.add_handler(CommandHandler(kwargs['command'], _handler), group=0)
+ _application.add_handler(CommandHandler(kwargs['command'], _handler), group=0)
if 'callback' in kwargs:
- _updater.dispatcher.add_handler(CallbackQueryHandler(_handler, pattern=kwargs['callback']), group=0)
+ _application.add_handler(CallbackQueryHandler(_handler, pattern=kwargs['callback']), group=0)
return _handler
return inner
-def simplehandler(f: callable):
+def simplehandler(f: Coroutine):
@wraps(f)
- def _handler(*args, **kwargs):
- return _handler_of_handler(f=f, *args, **kwargs)
+ async def _handler(*args, **kwargs):
+ return await _handler_of_handler(f=f, *args, **kwargs)
return _handler
def callbackhandler(*args, **kwargs):
def inner(f):
@wraps(f)
- def _handler(*args, **kwargs):
- return _handler_of_handler(f=f, *args, **kwargs)
+ async def _handler(*args, **kwargs):
+ return await _handler_of_handler(f=f, *args, **kwargs)
pattern_kwargs = {}
if kwargs['callback'] != '*':
pattern_kwargs['pattern'] = kwargs['callback']
- _updater.dispatcher.add_handler(CallbackQueryHandler(_handler, **pattern_kwargs), group=0)
+ _application.add_handler(CallbackQueryHandler(_handler, **pattern_kwargs), group=0)
return _handler
return inner
-def exceptionhandler(f: callable):
+async def exceptionhandler(f: callable):
global _exception_handler
if _exception_handler:
_logger.warning('exception handler already set, we will overwrite it')
@@ -198,10 +198,10 @@ def convinput(state, is_enter=False, **kwargs):
)
@wraps(f)
- def _impl(*args, **kwargs):
- result, ctx = _handler_of_handler(f=f, *args, **kwargs, return_with_context=True)
+ async def _impl(*args, **kwargs):
+ result, ctx = await _handler_of_handler(f=f, *args, **kwargs, return_with_context=True)
if result == conversation.END:
- start(ctx)
+ await start(ctx)
return result
return _impl
@@ -252,7 +252,7 @@ class conversation:
handlers.append(MessageHandler(text_filter(lang.all(message) if 'messages_lang_completed' not in kwargs else message), self.make_invoker(target_state)))
if 'regex' in kwargs:
- handlers.append(MessageHandler(Filters.regex(kwargs['regex']) & _user_filter, f))
+ handlers.append(MessageHandler(filters.Regex(kwargs['regex']) & _user_filter, f))
if 'command' in kwargs:
handlers.append(CommandHandler(kwargs['command'], f, _user_filter))
@@ -327,21 +327,21 @@ class conversation:
@staticmethod
@simplehandler
- def invalid(ctx: Context):
- ctx.reply(ctx.lang('invalid_input'), markup=IgnoreMarkup())
+ async def invalid(ctx: Context):
+ await ctx.reply(ctx.lang('invalid_input'), markup=IgnoreMarkup())
# return 0 # FIXME is this needed
@simplehandler
- def cancel(self, ctx: Context):
- start(ctx)
+ async def cancel(self, ctx: Context):
+ await start(ctx)
self.set_user_state(ctx.user_id, None)
return conversation.END
@simplehandler
- def back(self, ctx: Context):
+ async def back(self, ctx: Context):
cur_state = self.get_user_state(ctx.user_id)
if cur_state is None:
- start(ctx)
+ await start(ctx)
self.set_user_state(ctx.user_id, None)
return conversation.END
@@ -411,7 +411,7 @@ class LangConversation(conversation):
START, = range(1)
@conventer(START, command='lang')
- def entry(self, ctx: Context):
+ async def entry(self, ctx: Context):
self._logger.debug(f'current language: {ctx.user_lang}')
buttons = []
@@ -419,11 +419,11 @@ class LangConversation(conversation):
buttons.append(name)
markup = ReplyKeyboardMarkup([buttons, [ctx.lang('cancel')]], one_time_keyboard=False)
- ctx.reply(ctx.lang('select_language'), markup=markup)
+ await ctx.reply(ctx.lang('select_language'), markup=markup)
return self.START
@convinput(START, messages=lang.languages)
- def input(self, ctx: Context):
+ async def input(self, ctx: Context):
selected_lang = None
for key, value in languages.items():
if value == ctx.text:
@@ -434,30 +434,34 @@ class LangConversation(conversation):
raise ValueError('could not find the language')
db.set_user_lang(ctx.user_id, selected_lang)
- ctx.reply(ctx.lang('saved'), markup=IgnoreMarkup())
+ await ctx.reply(ctx.lang('saved'), markup=IgnoreMarkup())
return self.END
def initialize():
global _user_filter
- global _updater
+ global _application
+ # global _updater
global _dispatcher
# init user_filter
- if 'users' in config['bot']:
- _logger.info('allowed users: ' + str(config['bot']['users']))
- _user_filter = Filters.user(config['bot']['users'])
+ _user_ids = config.app_config.get_user_ids()
+ if len(_user_ids) > 0:
+ _logger.info('allowed users: ' + str(_user_ids))
+ _user_filter = filters.User(_user_ids)
else:
- _user_filter = Filters.all # not sure if this is correct
+ _user_filter = filters.ALL # not sure if this is correct
- # init updater
- _updater = Updater(config['bot']['token'],
- request_kwargs={'read_timeout': 6, 'connect_timeout': 7})
+ _application = Application.builder()\
+ .token(config.app_config.get('bot.token'))\
+ .connect_timeout(7)\
+ .read_timeout(6)\
+ .build()
# transparently log all messages
- _updater.dispatcher.add_handler(MessageHandler(Filters.all & _user_filter, _logging_message_handler), group=10)
- _updater.dispatcher.add_handler(CallbackQueryHandler(_logging_callback_handler), group=10)
+ # _application.dispatcher.add_handler(MessageHandler(filters.ALL & _user_filter, _logging_message_handler), group=10)
+ # _application.dispatcher.add_handler(CallbackQueryHandler(_logging_callback_handler), group=10)
def run(start_handler=None, any_handler=None):
@@ -473,37 +477,38 @@ def run(start_handler=None, any_handler=None):
_start_handler_ref = start_handler
- _updater.dispatcher.add_handler(LangConversation().get_handler(), group=0)
- _updater.dispatcher.add_handler(CommandHandler('start', simplehandler(start_handler), _user_filter))
- _updater.dispatcher.add_handler(MessageHandler(Filters.all & _user_filter, any_handler))
+ _application.add_handler(LangConversation().get_handler(), group=0)
+ _application.add_handler(CommandHandler('start',
+ callback=simplehandler(start_handler),
+ filters=_user_filter))
+ _application.add_handler(MessageHandler(filters.ALL & _user_filter, any_handler))
- _updater.start_polling()
- _updater.idle()
+ _application.run_polling()
def add_conversation(conv: conversation) -> None:
- _updater.dispatcher.add_handler(conv.get_handler(), group=0)
+ _application.add_handler(conv.get_handler(), group=0)
def add_handler(h):
- _updater.dispatcher.add_handler(h, group=0)
+ _application.add_handler(h, group=0)
-def start(ctx: Context):
- return _start_handler_ref(ctx)
+async def start(ctx: Context):
+ return await _start_handler_ref(ctx)
-def _default_start_handler(ctx: Context):
+async def _default_start_handler(ctx: Context):
if 'start_message' not in lang:
- return ctx.reply('Please define start_message or override start()')
- ctx.reply(ctx.lang('start_message'))
+ return await ctx.reply('Please define start_message or override start()')
+ await ctx.reply(ctx.lang('start_message'))
@simplehandler
-def _default_any_handler(ctx: Context):
+async def _default_any_handler(ctx: Context):
if 'invalid_command' not in lang:
- return ctx.reply('Please define invalid_command or override any()')
- ctx.reply(ctx.lang('invalid_command'))
+ return await ctx.reply('Please define invalid_command or override any()')
+ await ctx.reply(ctx.lang('invalid_command'))
def _logging_message_handler(update: Update, context: CallbackContext):
@@ -535,7 +540,7 @@ def notify_all(text_getter: callable,
continue
text = text_getter(db.get_user_lang(user_id))
- _updater.bot.send_message(chat_id=user_id,
+ _application.bot.send_message(chat_id=user_id,
text=text,
parse_mode='HTML')
@@ -543,33 +548,33 @@ def notify_all(text_getter: callable,
def notify_user(user_id: int, text: Union[str, Exception], **kwargs) -> None:
if isinstance(text, Exception):
text = exc2text(text)
- _updater.bot.send_message(chat_id=user_id,
+ _application.bot.send_message(chat_id=user_id,
text=text,
parse_mode='HTML',
**kwargs)
def send_photo(user_id, **kwargs):
- _updater.bot.send_photo(chat_id=user_id, **kwargs)
+ _application.bot.send_photo(chat_id=user_id, **kwargs)
def send_audio(user_id, **kwargs):
- _updater.bot.send_audio(chat_id=user_id, **kwargs)
+ _application.bot.send_audio(chat_id=user_id, **kwargs)
def send_file(user_id, **kwargs):
- _updater.bot.send_document(chat_id=user_id, **kwargs)
+ _application.bot.send_document(chat_id=user_id, **kwargs)
def edit_message_text(user_id, message_id, *args, **kwargs):
- _updater.bot.edit_message_text(chat_id=user_id,
+ _application.bot.edit_message_text(chat_id=user_id,
message_id=message_id,
parse_mode='HTML',
*args, **kwargs)
def delete_message(user_id, message_id):
- _updater.bot.delete_message(chat_id=user_id, message_id=message_id)
+ _application.bot.delete_message(chat_id=user_id, message_id=message_id)
def set_database(_db: BotDatabase):
diff --git a/src/home/telegram/config.py b/src/home/telegram/config.py
new file mode 100644
index 0000000..7a46087
--- /dev/null
+++ b/src/home/telegram/config.py
@@ -0,0 +1,75 @@
+from ..config import ConfigUnit
+from typing import Optional, Union
+from abc import ABC
+from enum import Enum
+
+
+class TelegramUserListType(Enum):
+ USERS = 'users'
+ NOTIFY = 'notify_users'
+
+
+class TelegramUserIdsConfig(ConfigUnit):
+ NAME = 'telegram_user_ids'
+
+ @staticmethod
+ def schema() -> Optional[dict]:
+ return {
+ 'roottype': 'dict',
+ 'type': 'integer'
+ }
+
+
+_user_ids_config = TelegramUserIdsConfig()
+
+
+def _user_id_mapper(user: Union[str, int]) -> int:
+ if isinstance(user, int):
+ return user
+ return _user_ids_config[user]
+
+
+class TelegramChatsConfig(ConfigUnit):
+ NAME = 'telegram_chats'
+
+ @staticmethod
+ def schema() -> Optional[dict]:
+ return {
+ 'type': 'dict',
+ 'schema': {
+ 'id': {'type': 'string', 'required': True},
+ 'token': {'type': 'string', 'required': True},
+ }
+ }
+
+
+class TelegramBotConfig(ConfigUnit, ABC):
+ @staticmethod
+ def schema() -> Optional[dict]:
+ return {
+ 'bot': {
+ 'type': 'dict',
+ 'schema': {
+ 'token': {'type': 'string', 'required': True},
+ TelegramUserListType.USERS: {**TelegramBotConfig._userlist_schema(), 'required': True},
+ TelegramUserListType.NOTIFY: TelegramBotConfig._userlist_schema(),
+ }
+ }
+ }
+
+ @staticmethod
+ def _userlist_schema() -> dict:
+ return {'type': 'list', 'schema': {'type': ['string', 'int']}}
+
+ @staticmethod
+ def custom_validator(data):
+ for ult in TelegramUserListType:
+ users = data['bot'][ult.value]
+ for user in users:
+ if isinstance(user, str):
+ if user not in _user_ids_config:
+ raise ValueError(f'user {user} not found in {TelegramUserIdsConfig.NAME}')
+
+ def get_user_ids(self,
+ ult: TelegramUserListType = TelegramUserListType.USERS) -> list[int]:
+ return list(map(_user_id_mapper, self['bot'][ult.value])) \ No newline at end of file
diff --git a/src/home/temphum/__init__.py b/src/home/temphum/__init__.py
index 55a7e1f..46d14e6 100644
--- a/src/home/temphum/__init__.py
+++ b/src/home/temphum/__init__.py
@@ -1,18 +1 @@
-from .base import SensorType, TempHumSensor
-from .si7021 import Si7021
-from .dht12 import DHT12
-
-__all__ = [
- 'SensorType',
- 'TempHumSensor',
- 'create_sensor'
-]
-
-
-def create_sensor(type: SensorType, bus: int) -> TempHumSensor:
- if type == SensorType.Si7021:
- return Si7021(bus)
- elif type == SensorType.DHT12:
- return DHT12(bus)
- else:
- raise ValueError('unexpected sensor type')
+from .base import SensorType, BaseSensor
diff --git a/src/home/temphum/base.py b/src/home/temphum/base.py
index e774433..602cab7 100644
--- a/src/home/temphum/base.py
+++ b/src/home/temphum/base.py
@@ -1,25 +1,19 @@
-import smbus
-
-from abc import abstractmethod, ABC
+from abc import ABC
from enum import Enum
-class TempHumSensor:
- @abstractmethod
+class BaseSensor(ABC):
+ def __init__(self, bus: int):
+ super().__init__()
+ self.bus = smbus.SMBus(bus)
+
def humidity(self) -> float:
pass
- @abstractmethod
def temperature(self) -> float:
pass
-class I2CTempHumSensor(TempHumSensor, ABC):
- def __init__(self, bus: int):
- super().__init__()
- self.bus = smbus.SMBus(bus)
-
-
class SensorType(Enum):
Si7021 = 'si7021'
- DHT12 = 'dht12'
+ DHT12 = 'dht12' \ No newline at end of file
diff --git a/src/home/temphum/dht12.py b/src/home/temphum/dht12.py
deleted file mode 100644
index d495766..0000000
--- a/src/home/temphum/dht12.py
+++ /dev/null
@@ -1,22 +0,0 @@
-from .base import I2CTempHumSensor
-
-
-class DHT12(I2CTempHumSensor):
- i2c_addr = 0x5C
-
- def _measure(self):
- raw = self.bus.read_i2c_block_data(self.i2c_addr, 0, 5)
- if (raw[0] + raw[1] + raw[2] + raw[3]) & 0xff != raw[4]:
- raise ValueError("checksum error")
- return raw
-
- def temperature(self) -> float:
- raw = self._measure()
- temp = raw[2] + (raw[3] & 0x7f) * 0.1
- if raw[3] & 0x80:
- temp *= -1
- return temp
-
- def humidity(self) -> float:
- raw = self._measure()
- return raw[0] + raw[1] * 0.1
diff --git a/src/home/temphum/i2c.py b/src/home/temphum/i2c.py
new file mode 100644
index 0000000..7d8e2e3
--- /dev/null
+++ b/src/home/temphum/i2c.py
@@ -0,0 +1,52 @@
+import abc
+import smbus
+
+from .base import BaseSensor, SensorType
+
+
+class I2CSensor(BaseSensor, abc.ABC):
+ def __init__(self, bus: int):
+ super().__init__()
+ self.bus = smbus.SMBus(bus)
+
+
+class DHT12(I2CSensor):
+ i2c_addr = 0x5C
+
+ def _measure(self):
+ raw = self.bus.read_i2c_block_data(self.i2c_addr, 0, 5)
+ if (raw[0] + raw[1] + raw[2] + raw[3]) & 0xff != raw[4]:
+ raise ValueError("checksum error")
+ return raw
+
+ def temperature(self) -> float:
+ raw = self._measure()
+ temp = raw[2] + (raw[3] & 0x7f) * 0.1
+ if raw[3] & 0x80:
+ temp *= -1
+ return temp
+
+ def humidity(self) -> float:
+ raw = self._measure()
+ return raw[0] + raw[1] * 0.1
+
+
+class Si7021(I2CSensor):
+ i2c_addr = 0x40
+
+ def temperature(self) -> float:
+ raw = self.bus.read_i2c_block_data(self.i2c_addr, 0xE3, 2)
+ return 175.72 * (raw[0] << 8 | raw[1]) / 65536.0 - 46.85
+
+ def humidity(self) -> float:
+ raw = self.bus.read_i2c_block_data(self.i2c_addr, 0xE5, 2)
+ return 125.0 * (raw[0] << 8 | raw[1]) / 65536.0 - 6.0
+
+
+def create_sensor(type: SensorType, bus: int) -> BaseSensor:
+ if type == SensorType.Si7021:
+ return Si7021(bus)
+ elif type == SensorType.DHT12:
+ return DHT12(bus)
+ else:
+ raise ValueError('unexpected sensor type')
diff --git a/src/home/temphum/si7021.py b/src/home/temphum/si7021.py
deleted file mode 100644
index 6289e15..0000000
--- a/src/home/temphum/si7021.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from .base import I2CTempHumSensor
-
-
-class Si7021(I2CTempHumSensor):
- i2c_addr = 0x40
-
- def temperature(self) -> float:
- raw = self.bus.read_i2c_block_data(self.i2c_addr, 0xE3, 2)
- return 175.72 * (raw[0] << 8 | raw[1]) / 65536.0 - 46.85
-
- def humidity(self) -> float:
- raw = self.bus.read_i2c_block_data(self.i2c_addr, 0xE5, 2)
- return 125.0 * (raw[0] << 8 | raw[1]) / 65536.0 - 6.0
diff --git a/src/home/util.py b/src/home/util.py
index 93a9d8f..35505bc 100644
--- a/src/home/util.py
+++ b/src/home/util.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import json
import socket
import time
@@ -6,17 +8,57 @@ import traceback
import logging
import string
import random
+import re
from enum import Enum
from datetime import datetime
from typing import Tuple, Optional, List
from zlib import adler32
-Addr = Tuple[str, int] # network address type (host, port)
-
logger = logging.getLogger(__name__)
+def validate_ipv4_or_hostname(address: str, raise_exception: bool = False) -> bool:
+ if re.match(r'^(\d{1,3}\.){3}\d{1,3}$', address):
+ parts = address.split('.')
+ if all(0 <= int(part) < 256 for part in parts):
+ return True
+ else:
+ if raise_exception:
+ raise ValueError(f"invalid IPv4 address: {address}")
+ return False
+
+ if re.match(r'^[a-zA-Z0-9.-]+$', address):
+ return True
+ else:
+ if raise_exception:
+ raise ValueError(f"invalid hostname: {address}")
+ return False
+
+
+class Addr:
+ host: str
+ port: int
+
+ def __init__(self, host: str, port: int):
+ self.host = host
+ self.port = port
+
+ @staticmethod
+ def fromstring(addr: str) -> Addr:
+ if addr.count(':') != 1:
+ raise ValueError('invalid host:port format')
+
+ host, port = addr.split(':')
+ validate_ipv4_or_hostname(host, raise_exception=True)
+
+ port = int(port)
+ if not 0 <= port <= 65535:
+ raise ValueError(f'invalid port {port}')
+
+ return Addr(host, port)
+
+
# https://stackoverflow.com/questions/312443/how-do-you-split-a-list-into-evenly-sized-chunks
def chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
@@ -45,21 +87,6 @@ def ipv4_valid(ip: str) -> bool:
return False
-def parse_addr(addr: str) -> Addr:
- if addr.count(':') != 1:
- raise ValueError('invalid host:port format')
-
- host, port = addr.split(':')
- if not ipv4_valid(host):
- raise ValueError('invalid ipv4 address')
-
- port = int(port)
- if not 0 <= port <= 65535:
- raise ValueError('invalid port')
-
- return host, port
-
-
def strgen(n: int):
return ''.join(random.choices(string.ascii_letters + string.digits, k=n))
@@ -193,4 +220,11 @@ def filesize_fmt(num, suffix="B") -> str:
class HashableEnum(Enum):
def hash(self) -> int:
- return adler32(self.name.encode()) \ No newline at end of file
+ return adler32(self.name.encode())
+
+
+def next_tick_gen(freq):
+ t = time.time()
+ while True:
+ t += freq
+ yield max(t - time.time(), 0) \ No newline at end of file