diff options
-rw-r--r-- | src/home/config/_configs.py | 8 | ||||
-rw-r--r-- | src/home/config/config.py | 61 | ||||
-rw-r--r-- | src/home/inverter/config.py | 4 | ||||
-rw-r--r-- | src/home/mqtt/_config.py | 8 | ||||
-rw-r--r-- | src/home/mqtt/_wrapper.py | 5 | ||||
-rw-r--r-- | src/home/telegram/config.py | 12 | ||||
-rw-r--r-- | src/home/util.py | 33 | ||||
-rwxr-xr-x | src/inverter_bot.py | 4 | ||||
-rwxr-xr-x | src/relay_mqtt_bot.py | 4 | ||||
-rwxr-xr-x | src/relay_mqtt_http_proxy.py | 89 | ||||
-rwxr-xr-x | src/test_new_config.py | 12 |
11 files changed, 159 insertions, 81 deletions
diff --git a/src/home/config/_configs.py b/src/home/config/_configs.py index 3a1aae5..1628cba 100644 --- a/src/home/config/_configs.py +++ b/src/home/config/_configs.py @@ -5,8 +5,8 @@ from typing import Optional class ServicesListConfig(ConfigUnit): NAME = 'services_list' - @staticmethod - def schema() -> Optional[dict]: + @classmethod + def schema(cls) -> Optional[dict]: return { 'type': 'list', 'empty': False, @@ -19,8 +19,8 @@ class ServicesListConfig(ConfigUnit): class LinuxBoardsConfig(ConfigUnit): NAME = 'linux_boards' - @staticmethod - def schema() -> Optional[dict]: + @classmethod + def schema(cls) -> Optional[dict]: return { 'type': 'dict', 'schema': { diff --git a/src/home/config/config.py b/src/home/config/config.py index aef9ee7..dc00d2e 100644 --- a/src/home/config/config.py +++ b/src/home/config/config.py @@ -1,10 +1,10 @@ import yaml import logging import os -import pprint +import cerberus +import cerberus.errors from abc import ABC -from cerberus import Validator, DocumentError from typing import Optional, Any, MutableMapping, Union from argparse import ArgumentParser from enum import Enum, auto @@ -12,11 +12,20 @@ from os.path import join, isdir, isfile from ..util import Addr +class MyValidator(cerberus.Validator): + def _normalize_coerce_addr(self, value): + return Addr.fromstring(value) + + +MyValidator.types_mapping['addr'] = cerberus.TypeDefinition('Addr', (Addr,), ()) + + CONFIG_DIRECTORIES = ( join(os.environ['HOME'], '.config', 'homekit'), '/etc/homekit' ) + class RootSchemaType(Enum): DEFAULT = auto() DICT = auto() @@ -95,10 +104,19 @@ class ConfigUnit(BaseConfigUnit): raise IOError(f'\'{name}.yaml\' not found') - @staticmethod - def schema() -> Optional[dict]: + @classmethod + def schema(cls) -> Optional[dict]: return None + @classmethod + def _addr_schema(cls, required=False, **kwargs): + return { + 'type': 'addr', + 'coerce': Addr.fromstring, + 'required': required, + **kwargs + } + def validate(self): schema = self.schema() if not schema: @@ -109,7 +127,7 @@ class ConfigUnit(BaseConfigUnit): schema['logging'] = { 'type': 'dict', 'schema': { - 'logging': {'type': 'bool'} + 'logging': {'type': 'boolean'} } } @@ -125,27 +143,27 @@ class ConfigUnit(BaseConfigUnit): except KeyError: pass + v = MyValidator() + if rst == RootSchemaType.DICT: - v = Validator({'document': { - 'type': 'dict', - 'keysrules': {'type': 'string'}, - 'valuesrules': schema - }}) - result = v.validate({'document': self._data}) + normalized = v.validated({'document': self._data}, + {'document': { + 'type': 'dict', + 'keysrules': {'type': 'string'}, + 'valuesrules': schema + }})['document'] elif rst == RootSchemaType.LIST: - v = Validator({'document': schema}) - result = v.validate({'document': self._data}) + v = MyValidator() + normalized = v.validated({'document': self._data}, {'document': schema})['document'] else: - v = Validator(schema) - result = v.validate(self._data) - # pprint.pprint(self._data) - if not result: - # pprint.pprint(v.errors) - raise DocumentError(f'{self.__class__.__name__}: failed to validate data:\n{pprint.pformat(v.errors)}') + normalized = v.validated(self._data, schema) + + self._data = normalized + try: self.custom_validator(self._data) except Exception as e: - raise DocumentError(f'{self.__class__.__name__}: {str(e)}') + raise cerberus.DocumentError(f'{self.__class__.__name__}: {str(e)}') @staticmethod def custom_validator(data): @@ -238,7 +256,7 @@ class Config: no_config=False): global app_config - if issubclass(name, AppConfigUnit) or name == AppConfigUnit: + if not isinstance(name, str) and not isinstance(name, bool) and issubclass(name, AppConfigUnit) or name == AppConfigUnit: self.app_name = name.NAME self.app_config = name() app_config = self.app_config @@ -278,6 +296,7 @@ class Config: if not no_config: self.app_config.load_from(path) + self.app_config.validate() setup_logging(self.app_config.logging_is_verbose(), self.app_config.logging_get_file(), diff --git a/src/home/inverter/config.py b/src/home/inverter/config.py index 62b8859..e284dfe 100644 --- a/src/home/inverter/config.py +++ b/src/home/inverter/config.py @@ -5,8 +5,8 @@ from typing import Optional class InverterdConfig(ConfigUnit): NAME = 'inverterd' - @staticmethod - def schema() -> Optional[dict]: + @classmethod + def schema(cls) -> Optional[dict]: return { 'remote_addr': {'type': 'string'}, 'local_addr': {'type': 'string'}, diff --git a/src/home/mqtt/_config.py b/src/home/mqtt/_config.py index f9047b4..9ba9443 100644 --- a/src/home/mqtt/_config.py +++ b/src/home/mqtt/_config.py @@ -9,8 +9,8 @@ MqttCreds = namedtuple('MqttCreds', 'username, password') class MqttConfig(ConfigUnit): NAME = 'mqtt' - @staticmethod - def schema() -> Optional[dict]: + @classmethod + def schema(cls) -> Optional[dict]: addr_schema = { 'type': 'dict', 'required': True, @@ -64,8 +64,8 @@ class MqttConfig(ConfigUnit): class MqttNodesConfig(ConfigUnit): NAME = 'mqtt_nodes' - @staticmethod - def schema() -> Optional[dict]: + @classmethod + def schema(cls) -> Optional[dict]: return { 'common': { 'type': 'dict', diff --git a/src/home/mqtt/_wrapper.py b/src/home/mqtt/_wrapper.py index f858f88..3c2774c 100644 --- a/src/home/mqtt/_wrapper.py +++ b/src/home/mqtt/_wrapper.py @@ -2,7 +2,6 @@ import paho.mqtt.client as mqtt from ._mqtt import Mqtt from ._node import MqttNode -from ..config import config from ..util import strgen @@ -34,8 +33,10 @@ class MqttWrapper(Mqtt): def on_message(self, client: mqtt.Client, userdata, msg): try: topic = msg.topic + topic_node = topic[len(self._topic_prefix)+1:topic.find('/', len(self._topic_prefix)+1)] for node in self._nodes: - node.on_message(topic[len(f'{self._topic_prefix}/{node.id}/'):], msg.payload) + if node.id in ('+', topic_node): + node.on_message(topic[len(f'{self._topic_prefix}/{node.id}/'):], msg.payload) except Exception as e: self._logger.exception(str(e)) diff --git a/src/home/telegram/config.py b/src/home/telegram/config.py index 7a46087..4c7d74b 100644 --- a/src/home/telegram/config.py +++ b/src/home/telegram/config.py @@ -12,8 +12,8 @@ class TelegramUserListType(Enum): class TelegramUserIdsConfig(ConfigUnit): NAME = 'telegram_user_ids' - @staticmethod - def schema() -> Optional[dict]: + @classmethod + def schema(cls) -> Optional[dict]: return { 'roottype': 'dict', 'type': 'integer' @@ -32,8 +32,8 @@ def _user_id_mapper(user: Union[str, int]) -> int: class TelegramChatsConfig(ConfigUnit): NAME = 'telegram_chats' - @staticmethod - def schema() -> Optional[dict]: + @classmethod + def schema(cls) -> Optional[dict]: return { 'type': 'dict', 'schema': { @@ -44,8 +44,8 @@ class TelegramChatsConfig(ConfigUnit): class TelegramBotConfig(ConfigUnit, ABC): - @staticmethod - def schema() -> Optional[dict]: + @classmethod + def schema(cls) -> Optional[dict]: return { 'bot': { 'type': 'dict', diff --git a/src/home/util.py b/src/home/util.py index 35505bc..1e12243 100644 --- a/src/home/util.py +++ b/src/home/util.py @@ -12,7 +12,7 @@ import re from enum import Enum from datetime import datetime -from typing import Tuple, Optional, List +from typing import Optional, List from zlib import adler32 logger = logging.getLogger(__name__) @@ -38,26 +38,43 @@ def validate_ipv4_or_hostname(address: str, raise_exception: bool = False) -> bo class Addr: host: str - port: int + port: Optional[int] - def __init__(self, host: str, port: int): + def __init__(self, host: str, port: Optional[int] = None): self.host = host self.port = port @staticmethod def fromstring(addr: str) -> Addr: - if addr.count(':') != 1: + colons = addr.count(':') + if colons != 1: raise ValueError('invalid host:port format') - host, port = addr.split(':') + if not colons: + host = addr + port= None + else: + host, port = addr.split(':') + validate_ipv4_or_hostname(host, raise_exception=True) - port = int(port) - if not 0 <= port <= 65535: - raise ValueError(f'invalid port {port}') + if port is not None: + port = int(port) + if not 0 <= port <= 65535: + raise ValueError(f'invalid port {port}') return Addr(host, port) + def __str__(self): + buf = self.host + if self.port is not None: + buf += ':'+str(self.port) + return buf + + def __iter__(self): + yield self.host + yield self.port + # https://stackoverflow.com/questions/312443/how-do-you-split-a-list-into-evenly-sized-chunks def chunks(lst, n): diff --git a/src/inverter_bot.py b/src/inverter_bot.py index ecf01fc..d35e606 100755 --- a/src/inverter_bot.py +++ b/src/inverter_bot.py @@ -55,8 +55,8 @@ logger = logging.getLogger(__name__) class InverterBotConfig(AppConfigUnit, TelegramBotConfig): NAME = 'inverter_bot' - @staticmethod - def schema() -> Optional[dict]: + @classmethod + def schema(cls) -> Optional[dict]: acmode_item_schema = { 'thresholds': { 'type': 'list', diff --git a/src/relay_mqtt_bot.py b/src/relay_mqtt_bot.py index 9de8c7e..020dc08 100755 --- a/src/relay_mqtt_bot.py +++ b/src/relay_mqtt_bot.py @@ -32,8 +32,8 @@ class RelayMqttBotConfig(AppConfigUnit, TelegramBotConfig): super().__init__() self._strings = Translation('mqtt_nodes') - @staticmethod - def schema() -> Optional[dict]: + @classmethod + def schema(cls) -> Optional[dict]: return { **super(TelegramBotConfig).schema(), 'relay_nodes': { diff --git a/src/relay_mqtt_http_proxy.py b/src/relay_mqtt_http_proxy.py index 2bc2c4a..e13c04a 100755 --- a/src/relay_mqtt_http_proxy.py +++ b/src/relay_mqtt_http_proxy.py @@ -1,24 +1,69 @@ #!/usr/bin/env python3 +import logging + from home import http -from home.config import config -from home.mqtt import MqttPayload, MqttWrapper, MqttNode, MqttModule -from home.mqtt.module.relay import MqttRelayState, MqttRelayModule +from home.config import config, AppConfigUnit +from home.mqtt import MqttPayload, MqttWrapper, MqttNode, MqttModule, MqttNodesConfig +from home.mqtt.module.relay import MqttRelayState, MqttRelayModule, MqttPowerStatusPayload from home.mqtt.module.diagnostics import InitialDiagnosticsPayload, DiagnosticsPayload from typing import Optional, Union + +logger = logging.getLogger(__name__) mqtt: Optional[MqttWrapper] = None mqtt_nodes: dict[str, MqttNode] = {} relay_modules: dict[str, Union[MqttRelayModule, MqttModule]] = {} relay_states: dict[str, MqttRelayState] = {} +mqtt_nodes_config = MqttNodesConfig() + + +class RelayMqttHttpProxyConfig(AppConfigUnit): + NAME = 'relay_mqtt_http_proxy' + + @classmethod + def schema(cls) -> Optional[dict]: + return { + 'relay_nodes': { + 'type': 'list', + 'required': True, + 'schema': { + 'type': 'string' + } + }, + 'listen_addr': cls._addr_schema(required=True) + } + + @staticmethod + def custom_validator(data): + relay_node_names = mqtt_nodes_config.get_nodes(filters=('relay',), only_names=True) + for node in data['relay_nodes']: + if node not in relay_node_names: + raise ValueError(f'unknown relay node "{node}"') + def on_mqtt_message(node: MqttNode, message: MqttPayload): + try: + is_legacy = mqtt_nodes_config[node.id]['relay']['legacy_topics'] + logger.debug(f'on_mqtt_message: relay {node.id} uses legacy topic names') + except KeyError: + is_legacy = False + kwargs = {} + if isinstance(message, InitialDiagnosticsPayload) or isinstance(message, DiagnosticsPayload): - kwargs = dict(rssi=message.rssi, enabled=message.flags.state) - if device_id not in relay_states: - relay_states[device_id] = MqttRelayState() - relay_states[device_id].update(**kwargs) + kwargs['rssi'] = message.rssi + if is_legacy: + kwargs['enabled'] = message.flags.state + + if not is_legacy and isinstance(message, MqttPowerStatusPayload): + kwargs['enabled'] = message.opened + + if len(kwargs): + logger.debug(f'on_mqtt_message: {node.id}: going to update relay state: {str(kwargs)}') + if node.id not in relay_states: + relay_states[node.id] = MqttRelayState() + relay_states[node.id].update(**kwargs) class RelayMqttHttpProxy(http.HTTPServer): @@ -44,8 +89,7 @@ class RelayMqttHttpProxy(http.HTTPServer): cur_state = False enable = not cur_state - if not node.secret: - node.secret = node_secret + node.secret = node_secret relay_module.switchpower(enable) return self.ok() @@ -60,20 +104,29 @@ class RelayMqttHttpProxy(http.HTTPServer): if __name__ == '__main__': - config.load_app('relay_mqtt_http_proxy') - - mqtt = MqttWrapper() - for device_id, data in config['relays'].items(): - mqtt_node = MqttNode(node_id=device_id) - relay_modules[device_id] = mqtt_node.load_module('relay') - mqtt_nodes[device_id] = mqtt_node + config.load_app(RelayMqttHttpProxyConfig) + + mqtt = MqttWrapper(client_id='relay_mqtt_http_proxy', + randomize_client_id=True) + for node_id in config.app_config['relay_nodes']: + node_data = mqtt_nodes_config.get_node(node_id) + mqtt_node = MqttNode(node_id=node_id) + module_kwargs = {} + try: + if node_data['relay']['legacy_topics']: + module_kwargs['legacy_topics'] = True + except KeyError: + pass + relay_modules[node_id] = mqtt_node.load_module('relay', **module_kwargs) + if 'legacy_topics' in module_kwargs: + mqtt_node.load_module('diagnostics') mqtt_node.add_payload_callback(on_mqtt_message) mqtt.add_node(mqtt_node) - mqtt_node.add_payload_callback(on_mqtt_message) + mqtt_nodes[node_id] = mqtt_node mqtt.connect_and_loop(loop_forever=False) - proxy = RelayMqttHttpProxy(config.get_addr('server.listen')) + proxy = RelayMqttHttpProxy(config.app_config['listen_addr']) try: proxy.run() except KeyboardInterrupt: diff --git a/src/test_new_config.py b/src/test_new_config.py deleted file mode 100755 index db9eae3..0000000 --- a/src/test_new_config.py +++ /dev/null @@ -1,12 +0,0 @@ -#!/usr/bin/env python3 -from home.config import config -from home.mqtt import MqttNodesConfig -from home.telegram.config import TelegramUserIdsConfig -from pprint import pprint - - -if __name__ == '__main__': - config.load_app(name=False) - - c = TelegramUserIdsConfig() - pprint(c.get())
\ No newline at end of file |