summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEvgeny Zinoviev <me@ch1p.io>2023-06-10 21:54:56 +0300
committerEvgeny Zinoviev <me@ch1p.io>2023-06-10 21:55:01 +0300
commit327a5298359027099631c3c9967b7585928cd367 (patch)
treefb6358ceb3182c285bce3cff392654b0538c2c5c
parentf29e139cbb7e4a4d539cba6e894ef4a6acd312d6 (diff)
port relay_mqtt_http_proxy to new config scheme; config: support addr types & normalization
-rw-r--r--src/home/config/_configs.py8
-rw-r--r--src/home/config/config.py61
-rw-r--r--src/home/inverter/config.py4
-rw-r--r--src/home/mqtt/_config.py8
-rw-r--r--src/home/mqtt/_wrapper.py5
-rw-r--r--src/home/telegram/config.py12
-rw-r--r--src/home/util.py33
-rwxr-xr-xsrc/inverter_bot.py4
-rwxr-xr-xsrc/relay_mqtt_bot.py4
-rwxr-xr-xsrc/relay_mqtt_http_proxy.py89
-rwxr-xr-xsrc/test_new_config.py12
11 files changed, 159 insertions, 81 deletions
diff --git a/src/home/config/_configs.py b/src/home/config/_configs.py
index 3a1aae5..1628cba 100644
--- a/src/home/config/_configs.py
+++ b/src/home/config/_configs.py
@@ -5,8 +5,8 @@ from typing import Optional
class ServicesListConfig(ConfigUnit):
NAME = 'services_list'
- @staticmethod
- def schema() -> Optional[dict]:
+ @classmethod
+ def schema(cls) -> Optional[dict]:
return {
'type': 'list',
'empty': False,
@@ -19,8 +19,8 @@ class ServicesListConfig(ConfigUnit):
class LinuxBoardsConfig(ConfigUnit):
NAME = 'linux_boards'
- @staticmethod
- def schema() -> Optional[dict]:
+ @classmethod
+ def schema(cls) -> Optional[dict]:
return {
'type': 'dict',
'schema': {
diff --git a/src/home/config/config.py b/src/home/config/config.py
index aef9ee7..dc00d2e 100644
--- a/src/home/config/config.py
+++ b/src/home/config/config.py
@@ -1,10 +1,10 @@
import yaml
import logging
import os
-import pprint
+import cerberus
+import cerberus.errors
from abc import ABC
-from cerberus import Validator, DocumentError
from typing import Optional, Any, MutableMapping, Union
from argparse import ArgumentParser
from enum import Enum, auto
@@ -12,11 +12,20 @@ from os.path import join, isdir, isfile
from ..util import Addr
+class MyValidator(cerberus.Validator):
+ def _normalize_coerce_addr(self, value):
+ return Addr.fromstring(value)
+
+
+MyValidator.types_mapping['addr'] = cerberus.TypeDefinition('Addr', (Addr,), ())
+
+
CONFIG_DIRECTORIES = (
join(os.environ['HOME'], '.config', 'homekit'),
'/etc/homekit'
)
+
class RootSchemaType(Enum):
DEFAULT = auto()
DICT = auto()
@@ -95,10 +104,19 @@ class ConfigUnit(BaseConfigUnit):
raise IOError(f'\'{name}.yaml\' not found')
- @staticmethod
- def schema() -> Optional[dict]:
+ @classmethod
+ def schema(cls) -> Optional[dict]:
return None
+ @classmethod
+ def _addr_schema(cls, required=False, **kwargs):
+ return {
+ 'type': 'addr',
+ 'coerce': Addr.fromstring,
+ 'required': required,
+ **kwargs
+ }
+
def validate(self):
schema = self.schema()
if not schema:
@@ -109,7 +127,7 @@ class ConfigUnit(BaseConfigUnit):
schema['logging'] = {
'type': 'dict',
'schema': {
- 'logging': {'type': 'bool'}
+ 'logging': {'type': 'boolean'}
}
}
@@ -125,27 +143,27 @@ class ConfigUnit(BaseConfigUnit):
except KeyError:
pass
+ v = MyValidator()
+
if rst == RootSchemaType.DICT:
- v = Validator({'document': {
- 'type': 'dict',
- 'keysrules': {'type': 'string'},
- 'valuesrules': schema
- }})
- result = v.validate({'document': self._data})
+ normalized = v.validated({'document': self._data},
+ {'document': {
+ 'type': 'dict',
+ 'keysrules': {'type': 'string'},
+ 'valuesrules': schema
+ }})['document']
elif rst == RootSchemaType.LIST:
- v = Validator({'document': schema})
- result = v.validate({'document': self._data})
+ v = MyValidator()
+ normalized = v.validated({'document': self._data}, {'document': schema})['document']
else:
- v = Validator(schema)
- result = v.validate(self._data)
- # pprint.pprint(self._data)
- if not result:
- # pprint.pprint(v.errors)
- raise DocumentError(f'{self.__class__.__name__}: failed to validate data:\n{pprint.pformat(v.errors)}')
+ normalized = v.validated(self._data, schema)
+
+ self._data = normalized
+
try:
self.custom_validator(self._data)
except Exception as e:
- raise DocumentError(f'{self.__class__.__name__}: {str(e)}')
+ raise cerberus.DocumentError(f'{self.__class__.__name__}: {str(e)}')
@staticmethod
def custom_validator(data):
@@ -238,7 +256,7 @@ class Config:
no_config=False):
global app_config
- if issubclass(name, AppConfigUnit) or name == AppConfigUnit:
+ if not isinstance(name, str) and not isinstance(name, bool) and issubclass(name, AppConfigUnit) or name == AppConfigUnit:
self.app_name = name.NAME
self.app_config = name()
app_config = self.app_config
@@ -278,6 +296,7 @@ class Config:
if not no_config:
self.app_config.load_from(path)
+ self.app_config.validate()
setup_logging(self.app_config.logging_is_verbose(),
self.app_config.logging_get_file(),
diff --git a/src/home/inverter/config.py b/src/home/inverter/config.py
index 62b8859..e284dfe 100644
--- a/src/home/inverter/config.py
+++ b/src/home/inverter/config.py
@@ -5,8 +5,8 @@ from typing import Optional
class InverterdConfig(ConfigUnit):
NAME = 'inverterd'
- @staticmethod
- def schema() -> Optional[dict]:
+ @classmethod
+ def schema(cls) -> Optional[dict]:
return {
'remote_addr': {'type': 'string'},
'local_addr': {'type': 'string'},
diff --git a/src/home/mqtt/_config.py b/src/home/mqtt/_config.py
index f9047b4..9ba9443 100644
--- a/src/home/mqtt/_config.py
+++ b/src/home/mqtt/_config.py
@@ -9,8 +9,8 @@ MqttCreds = namedtuple('MqttCreds', 'username, password')
class MqttConfig(ConfigUnit):
NAME = 'mqtt'
- @staticmethod
- def schema() -> Optional[dict]:
+ @classmethod
+ def schema(cls) -> Optional[dict]:
addr_schema = {
'type': 'dict',
'required': True,
@@ -64,8 +64,8 @@ class MqttConfig(ConfigUnit):
class MqttNodesConfig(ConfigUnit):
NAME = 'mqtt_nodes'
- @staticmethod
- def schema() -> Optional[dict]:
+ @classmethod
+ def schema(cls) -> Optional[dict]:
return {
'common': {
'type': 'dict',
diff --git a/src/home/mqtt/_wrapper.py b/src/home/mqtt/_wrapper.py
index f858f88..3c2774c 100644
--- a/src/home/mqtt/_wrapper.py
+++ b/src/home/mqtt/_wrapper.py
@@ -2,7 +2,6 @@ import paho.mqtt.client as mqtt
from ._mqtt import Mqtt
from ._node import MqttNode
-from ..config import config
from ..util import strgen
@@ -34,8 +33,10 @@ class MqttWrapper(Mqtt):
def on_message(self, client: mqtt.Client, userdata, msg):
try:
topic = msg.topic
+ topic_node = topic[len(self._topic_prefix)+1:topic.find('/', len(self._topic_prefix)+1)]
for node in self._nodes:
- node.on_message(topic[len(f'{self._topic_prefix}/{node.id}/'):], msg.payload)
+ if node.id in ('+', topic_node):
+ node.on_message(topic[len(f'{self._topic_prefix}/{node.id}/'):], msg.payload)
except Exception as e:
self._logger.exception(str(e))
diff --git a/src/home/telegram/config.py b/src/home/telegram/config.py
index 7a46087..4c7d74b 100644
--- a/src/home/telegram/config.py
+++ b/src/home/telegram/config.py
@@ -12,8 +12,8 @@ class TelegramUserListType(Enum):
class TelegramUserIdsConfig(ConfigUnit):
NAME = 'telegram_user_ids'
- @staticmethod
- def schema() -> Optional[dict]:
+ @classmethod
+ def schema(cls) -> Optional[dict]:
return {
'roottype': 'dict',
'type': 'integer'
@@ -32,8 +32,8 @@ def _user_id_mapper(user: Union[str, int]) -> int:
class TelegramChatsConfig(ConfigUnit):
NAME = 'telegram_chats'
- @staticmethod
- def schema() -> Optional[dict]:
+ @classmethod
+ def schema(cls) -> Optional[dict]:
return {
'type': 'dict',
'schema': {
@@ -44,8 +44,8 @@ class TelegramChatsConfig(ConfigUnit):
class TelegramBotConfig(ConfigUnit, ABC):
- @staticmethod
- def schema() -> Optional[dict]:
+ @classmethod
+ def schema(cls) -> Optional[dict]:
return {
'bot': {
'type': 'dict',
diff --git a/src/home/util.py b/src/home/util.py
index 35505bc..1e12243 100644
--- a/src/home/util.py
+++ b/src/home/util.py
@@ -12,7 +12,7 @@ import re
from enum import Enum
from datetime import datetime
-from typing import Tuple, Optional, List
+from typing import Optional, List
from zlib import adler32
logger = logging.getLogger(__name__)
@@ -38,26 +38,43 @@ def validate_ipv4_or_hostname(address: str, raise_exception: bool = False) -> bo
class Addr:
host: str
- port: int
+ port: Optional[int]
- def __init__(self, host: str, port: int):
+ def __init__(self, host: str, port: Optional[int] = None):
self.host = host
self.port = port
@staticmethod
def fromstring(addr: str) -> Addr:
- if addr.count(':') != 1:
+ colons = addr.count(':')
+ if colons != 1:
raise ValueError('invalid host:port format')
- host, port = addr.split(':')
+ if not colons:
+ host = addr
+ port= None
+ else:
+ host, port = addr.split(':')
+
validate_ipv4_or_hostname(host, raise_exception=True)
- port = int(port)
- if not 0 <= port <= 65535:
- raise ValueError(f'invalid port {port}')
+ if port is not None:
+ port = int(port)
+ if not 0 <= port <= 65535:
+ raise ValueError(f'invalid port {port}')
return Addr(host, port)
+ def __str__(self):
+ buf = self.host
+ if self.port is not None:
+ buf += ':'+str(self.port)
+ return buf
+
+ def __iter__(self):
+ yield self.host
+ yield self.port
+
# https://stackoverflow.com/questions/312443/how-do-you-split-a-list-into-evenly-sized-chunks
def chunks(lst, n):
diff --git a/src/inverter_bot.py b/src/inverter_bot.py
index ecf01fc..d35e606 100755
--- a/src/inverter_bot.py
+++ b/src/inverter_bot.py
@@ -55,8 +55,8 @@ logger = logging.getLogger(__name__)
class InverterBotConfig(AppConfigUnit, TelegramBotConfig):
NAME = 'inverter_bot'
- @staticmethod
- def schema() -> Optional[dict]:
+ @classmethod
+ def schema(cls) -> Optional[dict]:
acmode_item_schema = {
'thresholds': {
'type': 'list',
diff --git a/src/relay_mqtt_bot.py b/src/relay_mqtt_bot.py
index 9de8c7e..020dc08 100755
--- a/src/relay_mqtt_bot.py
+++ b/src/relay_mqtt_bot.py
@@ -32,8 +32,8 @@ class RelayMqttBotConfig(AppConfigUnit, TelegramBotConfig):
super().__init__()
self._strings = Translation('mqtt_nodes')
- @staticmethod
- def schema() -> Optional[dict]:
+ @classmethod
+ def schema(cls) -> Optional[dict]:
return {
**super(TelegramBotConfig).schema(),
'relay_nodes': {
diff --git a/src/relay_mqtt_http_proxy.py b/src/relay_mqtt_http_proxy.py
index 2bc2c4a..e13c04a 100755
--- a/src/relay_mqtt_http_proxy.py
+++ b/src/relay_mqtt_http_proxy.py
@@ -1,24 +1,69 @@
#!/usr/bin/env python3
+import logging
+
from home import http
-from home.config import config
-from home.mqtt import MqttPayload, MqttWrapper, MqttNode, MqttModule
-from home.mqtt.module.relay import MqttRelayState, MqttRelayModule
+from home.config import config, AppConfigUnit
+from home.mqtt import MqttPayload, MqttWrapper, MqttNode, MqttModule, MqttNodesConfig
+from home.mqtt.module.relay import MqttRelayState, MqttRelayModule, MqttPowerStatusPayload
from home.mqtt.module.diagnostics import InitialDiagnosticsPayload, DiagnosticsPayload
from typing import Optional, Union
+
+logger = logging.getLogger(__name__)
mqtt: Optional[MqttWrapper] = None
mqtt_nodes: dict[str, MqttNode] = {}
relay_modules: dict[str, Union[MqttRelayModule, MqttModule]] = {}
relay_states: dict[str, MqttRelayState] = {}
+mqtt_nodes_config = MqttNodesConfig()
+
+
+class RelayMqttHttpProxyConfig(AppConfigUnit):
+ NAME = 'relay_mqtt_http_proxy'
+
+ @classmethod
+ def schema(cls) -> Optional[dict]:
+ return {
+ 'relay_nodes': {
+ 'type': 'list',
+ 'required': True,
+ 'schema': {
+ 'type': 'string'
+ }
+ },
+ 'listen_addr': cls._addr_schema(required=True)
+ }
+
+ @staticmethod
+ def custom_validator(data):
+ relay_node_names = mqtt_nodes_config.get_nodes(filters=('relay',), only_names=True)
+ for node in data['relay_nodes']:
+ if node not in relay_node_names:
+ raise ValueError(f'unknown relay node "{node}"')
+
def on_mqtt_message(node: MqttNode,
message: MqttPayload):
+ try:
+ is_legacy = mqtt_nodes_config[node.id]['relay']['legacy_topics']
+ logger.debug(f'on_mqtt_message: relay {node.id} uses legacy topic names')
+ except KeyError:
+ is_legacy = False
+ kwargs = {}
+
if isinstance(message, InitialDiagnosticsPayload) or isinstance(message, DiagnosticsPayload):
- kwargs = dict(rssi=message.rssi, enabled=message.flags.state)
- if device_id not in relay_states:
- relay_states[device_id] = MqttRelayState()
- relay_states[device_id].update(**kwargs)
+ kwargs['rssi'] = message.rssi
+ if is_legacy:
+ kwargs['enabled'] = message.flags.state
+
+ if not is_legacy and isinstance(message, MqttPowerStatusPayload):
+ kwargs['enabled'] = message.opened
+
+ if len(kwargs):
+ logger.debug(f'on_mqtt_message: {node.id}: going to update relay state: {str(kwargs)}')
+ if node.id not in relay_states:
+ relay_states[node.id] = MqttRelayState()
+ relay_states[node.id].update(**kwargs)
class RelayMqttHttpProxy(http.HTTPServer):
@@ -44,8 +89,7 @@ class RelayMqttHttpProxy(http.HTTPServer):
cur_state = False
enable = not cur_state
- if not node.secret:
- node.secret = node_secret
+ node.secret = node_secret
relay_module.switchpower(enable)
return self.ok()
@@ -60,20 +104,29 @@ class RelayMqttHttpProxy(http.HTTPServer):
if __name__ == '__main__':
- config.load_app('relay_mqtt_http_proxy')
-
- mqtt = MqttWrapper()
- for device_id, data in config['relays'].items():
- mqtt_node = MqttNode(node_id=device_id)
- relay_modules[device_id] = mqtt_node.load_module('relay')
- mqtt_nodes[device_id] = mqtt_node
+ config.load_app(RelayMqttHttpProxyConfig)
+
+ mqtt = MqttWrapper(client_id='relay_mqtt_http_proxy',
+ randomize_client_id=True)
+ for node_id in config.app_config['relay_nodes']:
+ node_data = mqtt_nodes_config.get_node(node_id)
+ mqtt_node = MqttNode(node_id=node_id)
+ module_kwargs = {}
+ try:
+ if node_data['relay']['legacy_topics']:
+ module_kwargs['legacy_topics'] = True
+ except KeyError:
+ pass
+ relay_modules[node_id] = mqtt_node.load_module('relay', **module_kwargs)
+ if 'legacy_topics' in module_kwargs:
+ mqtt_node.load_module('diagnostics')
mqtt_node.add_payload_callback(on_mqtt_message)
mqtt.add_node(mqtt_node)
- mqtt_node.add_payload_callback(on_mqtt_message)
+ mqtt_nodes[node_id] = mqtt_node
mqtt.connect_and_loop(loop_forever=False)
- proxy = RelayMqttHttpProxy(config.get_addr('server.listen'))
+ proxy = RelayMqttHttpProxy(config.app_config['listen_addr'])
try:
proxy.run()
except KeyboardInterrupt:
diff --git a/src/test_new_config.py b/src/test_new_config.py
deleted file mode 100755
index db9eae3..0000000
--- a/src/test_new_config.py
+++ /dev/null
@@ -1,12 +0,0 @@
-#!/usr/bin/env python3
-from home.config import config
-from home.mqtt import MqttNodesConfig
-from home.telegram.config import TelegramUserIdsConfig
-from pprint import pprint
-
-
-if __name__ == '__main__':
- config.load_app(name=False)
-
- c = TelegramUserIdsConfig()
- pprint(c.get()) \ No newline at end of file