diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/home/config/config.py | 11 | ||||
-rw-r--r-- | src/home/mqtt/__init__.py | 2 | ||||
-rw-r--r-- | src/home/mqtt/_config.py | 103 | ||||
-rwxr-xr-x | src/mqtt_node_util.py | 17 | ||||
-rwxr-xr-x | src/relay_mqtt_bot.py | 2 | ||||
-rwxr-xr-x | src/test_new_config.py | 6 |
6 files changed, 124 insertions, 17 deletions
diff --git a/src/home/config/config.py b/src/home/config/config.py index e89cc82..d526bb2 100644 --- a/src/home/config/config.py +++ b/src/home/config/config.py @@ -2,6 +2,7 @@ import toml import yaml import logging import os +import pprint from cerberus import Validator, DocumentError from typing import Optional, Any, MutableMapping, Union @@ -120,7 +121,15 @@ class ConfigUnit: # pprint(self._data) if not result: # pprint(v.errors) - raise DocumentError(f'{self.__class__.__name__}: failed to validate data: {v.errors}') + raise DocumentError(f'{self.__class__.__name__}: failed to validate data:\n{pprint.pformat(v.errors)}') + try: + self.custom_validator(self._data) + except Exception as e: + raise DocumentError(f'{self.__class__.__name__}: {str(e)}') + + @staticmethod + def custom_validator(data): + pass def __getitem__(self, key): return self._data[key] diff --git a/src/home/mqtt/__init__.py b/src/home/mqtt/__init__.py index 83169fb..707d59c 100644 --- a/src/home/mqtt/__init__.py +++ b/src/home/mqtt/__init__.py @@ -2,6 +2,6 @@ from ._mqtt import Mqtt from ._node import MqttNode from ._module import MqttModule from ._wrapper import MqttWrapper -from ._config import MqttConfig, MqttCreds +from ._config import MqttConfig, MqttCreds, MqttNodesConfig from ._payload import MqttPayload, MqttPayloadCustomField from ._util import get_modules as get_mqtt_modules
\ No newline at end of file diff --git a/src/home/mqtt/_config.py b/src/home/mqtt/_config.py index 88aaa94..3f9dd09 100644 --- a/src/home/mqtt/_config.py +++ b/src/home/mqtt/_config.py @@ -1,5 +1,5 @@ from ..config import ConfigUnit -from typing import Optional +from typing import Optional, Union from ..util import Addr from collections import namedtuple @@ -59,3 +59,104 @@ class MqttConfig(ConfigUnit): def server_creds(self) -> MqttCreds: return self.creds_by_name(self['default_server_creds']) + + +class MqttNodesConfig(ConfigUnit): + NAME = 'mqtt_nodes' + + @staticmethod + def schema() -> Optional[dict]: + return { + 'common': { + 'type': 'dict', + 'schema': { + 'temphum': { + 'type': 'dict', + 'schema': { + 'interval': {'type': 'integer'} + } + }, + 'password': {'type': 'string'} + } + }, + 'nodes': { + 'type': 'dict', + 'required': True, + 'keysrules': {'type': 'string'}, + 'valuesrules': { + 'type': 'dict', + 'schema': { + 'type': {'type': 'string', 'required': True, 'allowed': ['esp8266', 'linux', 'none'],}, + 'board': {'type': 'string', 'allowed': ['nodemcu', 'd1_mini_lite', 'esp12e']}, + 'temphum': { + 'type': 'dict', + 'schema': { + 'module': {'type': 'string', 'required': True, 'allowed': ['si7021', 'dht12']}, + 'interval': {'type': 'integer'}, + 'i2c_bus': {'type': 'integer'}, + 'tcpserver': { + 'type': 'dict', + 'schema': { + 'port': {'type': 'integer', 'required': True} + } + } + } + }, + 'relay': { + 'type': 'dict', + 'schema': {} + }, + 'password': {'type': 'string'} + } + } + } + } + + @staticmethod + def custom_validator(data): + for name, node in data['nodes'].items(): + if 'temphum' in node: + if node['type'] == 'linux': + if 'i2c_bus' not in node['temphum']: + raise KeyError(f'nodes.{name}.temphum: i2c_bus is missing but required for type=linux') + if node['type'] in ('esp8266',) and 'board' not in node: + raise KeyError(f'nodes.{name}: board is missing but required for type={node["type"]}') + + def get_node(self, name: str) -> dict: + node = self['nodes'][name] + if node['type'] == 'none': + return node + + try: + if 'password' not in node: + node['password'] = self['common']['password'] + except KeyError: + pass + + try: + if 'temphum' in node: + for ckey, cval in self['common']['temphum'].items(): + if ckey not in node['temphum']: + node['temphum'][ckey] = cval + except KeyError: + pass + + return node + + def get_nodes(self, + filters: Optional[Union[list[str], tuple[str]]] = None, + only_names=False) -> Union[dict, list[str]]: + if filters: + for f in filters: + if f not in ('temphum', 'relay'): + raise ValueError(f'{self.__class__.__name__}::get_node(): invalid filter {f}') + reslist = [] + resdict = {} + for name in self['nodes'].keys(): + node = self.get_node(name) + if (not filters) or ('temphum' in filters and 'temphum' in node) or ('relay' in filters and 'relay' in node): + if only_names: + reslist.append(name) + else: + resdict[name] = node + return reslist if only_names else resdict diff --git a/src/mqtt_node_util.py b/src/mqtt_node_util.py index 49179b1..43830f9 100755 --- a/src/mqtt_node_util.py +++ b/src/mqtt_node_util.py @@ -6,32 +6,34 @@ from argparse import ArgumentParser, ArgumentError from home.config import config from home.mqtt import MqttNode, MqttWrapper, get_mqtt_modules +from home.mqtt import MqttNodesConfig mqtt_node: Optional[MqttNode] = None mqtt: Optional[MqttWrapper] = None if __name__ == '__main__': + nodes_config = MqttNodesConfig() + parser = ArgumentParser() - parser.add_argument('--node-id', type=str, required=True) + parser.add_argument('--node-id', type=str, required=True, choices=nodes_config.get_nodes(only_names=True)) parser.add_argument('--modules', type=str, choices=get_mqtt_modules(), nargs='*', help='mqtt modules to include') parser.add_argument('--switch-relay', choices=[0, 1], type=int, help='send relay state') parser.add_argument('--push-ota', type=str, metavar='OTA_FILENAME', help='push ota, argument receives filename') - parser.add_argument('--node-secret', type=str, - help='node admin password') config.load_app(parser=parser, no_config=True) arg = parser.parse_args() - if (arg.switch_relay is not None or arg.node_secret is not None) and 'relay' not in arg.modules: + if arg.switch_relay is not None and 'relay' not in arg.modules: raise ArgumentError(None, '--relay is only allowed when \'relay\' module included in --modules') mqtt = MqttWrapper(randomize_client_id=True, client_id='mqtt_node_util') - mqtt_node = MqttNode(node_id=arg.node_id, node_secret=arg.node_secret) + mqtt_node = MqttNode(node_id=arg.node_id, + node_secret=nodes_config.get_node(arg.node_id)['password']) mqtt.add_node(mqtt_node) @@ -43,8 +45,6 @@ if __name__ == '__main__': for m in arg.modules: module_instance = mqtt_node.load_module(m) if m == 'relay' and arg.switch_relay is not None: - if not arg.node_secret: - raise ArgumentError(None, '--switch-relay requires --node-secret') module_instance.switchpower(arg.switch_relay == 1) mqtt.configure_tls() @@ -54,9 +54,6 @@ if __name__ == '__main__': if arg.push_ota: if not os.path.exists(arg.push_ota): raise OSError(f'--push-ota: file \"{arg.push_ota}\" does not exists') - if not arg.node_secret: - raise ArgumentError(None, 'pushing OTA requires --node-secret') - ota_module.push_ota(arg.push_ota, 1) while True: diff --git a/src/relay_mqtt_bot.py b/src/relay_mqtt_bot.py index e7fa613..f6a1532 100755 --- a/src/relay_mqtt_bot.py +++ b/src/relay_mqtt_bot.py @@ -90,7 +90,7 @@ def markup(ctx: Optional[bot.Context]) -> Optional[ReplyKeyboardMarkup]: if __name__ == '__main__': devices = [] - mqtt = MqttWrapper() + mqtt = MqttWrapper(client_id='relay_mqtt_bot') for device_id, data in config['relays'].items(): mqtt_node = MqttNode(node_id=device_id, node_secret=data['secret']) relay_nodes[device_id] = mqtt_node.load_module('relay') diff --git a/src/test_new_config.py b/src/test_new_config.py index 442a03a..939281a 100755 --- a/src/test_new_config.py +++ b/src/test_new_config.py @@ -1,11 +1,11 @@ #!/usr/bin/env python3 from home.config import config, LinuxBoardsConfig, ServicesListConfig -from home.mqtt import MqttConfig +from home.mqtt import MqttConfig, MqttNodesConfig from pprint import pprint if __name__ == '__main__': config.load_app(name=False) - c = MqttConfig() - print(c.creds())
\ No newline at end of file + c = MqttNodesConfig() + pprint(c.get_nodes(filters=('temphum',), only_names=False))
\ No newline at end of file |