From f29e139cbb7e4a4d539cba6e894ef4a6acd312d6 Mon Sep 17 00:00:00 2001 From: Evgeny Zinoviev Date: Wed, 31 May 2023 09:22:00 +0300 Subject: WIP: big refactoring --- src/home/mqtt/__init__.py | 11 +- src/home/mqtt/_config.py | 165 ++++++++++++++++++++++++++++ src/home/mqtt/_module.py | 70 ++++++++++++ src/home/mqtt/_mqtt.py | 86 +++++++++++++++ src/home/mqtt/_node.py | 92 ++++++++++++++++ src/home/mqtt/_payload.py | 145 +++++++++++++++++++++++++ src/home/mqtt/_util.py | 15 +++ src/home/mqtt/_wrapper.py | 59 ++++++++++ src/home/mqtt/esp.py | 106 ------------------ src/home/mqtt/module/diagnostics.py | 64 +++++++++++ src/home/mqtt/module/inverter.py | 195 ++++++++++++++++++++++++++++++++++ src/home/mqtt/module/ota.py | 77 ++++++++++++++ src/home/mqtt/module/relay.py | 92 ++++++++++++++++ src/home/mqtt/module/temphum.py | 82 ++++++++++++++ src/home/mqtt/mqtt.py | 76 ------------- src/home/mqtt/payload/__init__.py | 1 - src/home/mqtt/payload/base_payload.py | 145 ------------------------- src/home/mqtt/payload/esp.py | 78 -------------- src/home/mqtt/payload/inverter.py | 73 ------------- src/home/mqtt/payload/relay.py | 22 ---- src/home/mqtt/payload/sensors.py | 20 ---- src/home/mqtt/payload/temphum.py | 15 --- src/home/mqtt/relay.py | 71 ------------- src/home/mqtt/temphum.py | 54 ---------- src/home/mqtt/util.py | 8 -- 25 files changed, 1149 insertions(+), 673 deletions(-) create mode 100644 src/home/mqtt/_config.py create mode 100644 src/home/mqtt/_module.py create mode 100644 src/home/mqtt/_mqtt.py create mode 100644 src/home/mqtt/_node.py create mode 100644 src/home/mqtt/_payload.py create mode 100644 src/home/mqtt/_util.py create mode 100644 src/home/mqtt/_wrapper.py delete mode 100644 src/home/mqtt/esp.py create mode 100644 src/home/mqtt/module/diagnostics.py create mode 100644 src/home/mqtt/module/inverter.py create mode 100644 src/home/mqtt/module/ota.py create mode 100644 src/home/mqtt/module/relay.py create mode 100644 src/home/mqtt/module/temphum.py delete mode 100644 src/home/mqtt/mqtt.py delete mode 100644 src/home/mqtt/payload/__init__.py delete mode 100644 src/home/mqtt/payload/base_payload.py delete mode 100644 src/home/mqtt/payload/esp.py delete mode 100644 src/home/mqtt/payload/inverter.py delete mode 100644 src/home/mqtt/payload/relay.py delete mode 100644 src/home/mqtt/payload/sensors.py delete mode 100644 src/home/mqtt/payload/temphum.py delete mode 100644 src/home/mqtt/relay.py delete mode 100644 src/home/mqtt/temphum.py delete mode 100644 src/home/mqtt/util.py (limited to 'src/home/mqtt') diff --git a/src/home/mqtt/__init__.py b/src/home/mqtt/__init__.py index 982e2b6..707d59c 100644 --- a/src/home/mqtt/__init__.py +++ b/src/home/mqtt/__init__.py @@ -1,4 +1,7 @@ -from .mqtt import MqttBase -from .util import poll_tick -from .relay import MqttRelay, MqttRelayState -from .temphum import MqttTempHum \ No newline at end of file +from ._mqtt import Mqtt +from ._node import MqttNode +from ._module import MqttModule +from ._wrapper import MqttWrapper +from ._config import MqttConfig, MqttCreds, MqttNodesConfig +from ._payload import MqttPayload, MqttPayloadCustomField +from ._util import get_modules as get_mqtt_modules \ No newline at end of file diff --git a/src/home/mqtt/_config.py b/src/home/mqtt/_config.py new file mode 100644 index 0000000..f9047b4 --- /dev/null +++ b/src/home/mqtt/_config.py @@ -0,0 +1,165 @@ +from ..config import ConfigUnit +from typing import Optional, Union +from ..util import Addr +from collections import namedtuple + +MqttCreds = namedtuple('MqttCreds', 'username, password') + + +class MqttConfig(ConfigUnit): + NAME = 'mqtt' + + @staticmethod + def schema() -> Optional[dict]: + addr_schema = { + 'type': 'dict', + 'required': True, + 'schema': { + 'host': {'type': 'string', 'required': True}, + 'port': {'type': 'integer', 'required': True} + } + } + + schema = {} + for key in ('local', 'remote'): + schema[f'{key}_addr'] = addr_schema + + schema['creds'] = { + 'type': 'dict', + 'required': True, + 'keysrules': {'type': 'string'}, + 'valuesrules': { + 'type': 'dict', + 'schema': { + 'username': {'type': 'string', 'required': True}, + 'password': {'type': 'string', 'required': True}, + } + } + } + + for key in ('client', 'server'): + schema[f'default_{key}_creds'] = {'type': 'string', 'required': True} + + return schema + + def remote_addr(self) -> Addr: + return Addr(host=self['remote_addr']['host'], + port=self['remote_addr']['port']) + + def local_addr(self) -> Addr: + return Addr(host=self['local_addr']['host'], + port=self['local_addr']['port']) + + def creds_by_name(self, name: str) -> MqttCreds: + return MqttCreds(username=self['creds'][name]['username'], + password=self['creds'][name]['password']) + + def creds(self) -> MqttCreds: + return self.creds_by_name(self['default_client_creds']) + + def server_creds(self) -> MqttCreds: + return self.creds_by_name(self['default_server_creds']) + + +class MqttNodesConfig(ConfigUnit): + NAME = 'mqtt_nodes' + + @staticmethod + def schema() -> Optional[dict]: + return { + 'common': { + 'type': 'dict', + 'schema': { + 'temphum': { + 'type': 'dict', + 'schema': { + 'interval': {'type': 'integer'} + } + }, + 'password': {'type': 'string'} + } + }, + 'nodes': { + 'type': 'dict', + 'required': True, + 'keysrules': {'type': 'string'}, + 'valuesrules': { + 'type': 'dict', + 'schema': { + 'type': {'type': 'string', 'required': True, 'allowed': ['esp8266', 'linux', 'none'],}, + 'board': {'type': 'string', 'allowed': ['nodemcu', 'd1_mini_lite', 'esp12e']}, + 'temphum': { + 'type': 'dict', + 'schema': { + 'module': {'type': 'string', 'required': True, 'allowed': ['si7021', 'dht12']}, + 'interval': {'type': 'integer'}, + 'i2c_bus': {'type': 'integer'}, + 'tcpserver': { + 'type': 'dict', + 'schema': { + 'port': {'type': 'integer', 'required': True} + } + } + } + }, + 'relay': { + 'type': 'dict', + 'schema': { + 'device_type': {'type': 'string', 'allowed': ['lamp', 'pump', 'solenoid'], 'required': True}, + 'legacy_topics': {'type': 'boolean'} + } + }, + 'password': {'type': 'string'} + } + } + } + } + + @staticmethod + def custom_validator(data): + for name, node in data['nodes'].items(): + if 'temphum' in node: + if node['type'] == 'linux': + if 'i2c_bus' not in node['temphum']: + raise KeyError(f'nodes.{name}.temphum: i2c_bus is missing but required for type=linux') + if node['type'] in ('esp8266',) and 'board' not in node: + raise KeyError(f'nodes.{name}: board is missing but required for type={node["type"]}') + + def get_node(self, name: str) -> dict: + node = self['nodes'][name] + if node['type'] == 'none': + return node + + try: + if 'password' not in node: + node['password'] = self['common']['password'] + except KeyError: + pass + + try: + if 'temphum' in node: + for ckey, cval in self['common']['temphum'].items(): + if ckey not in node['temphum']: + node['temphum'][ckey] = cval + except KeyError: + pass + + return node + + def get_nodes(self, + filters: Optional[Union[list[str], tuple[str]]] = None, + only_names=False) -> Union[dict, list[str]]: + if filters: + for f in filters: + if f not in ('temphum', 'relay'): + raise ValueError(f'{self.__class__.__name__}::get_node(): invalid filter {f}') + reslist = [] + resdict = {} + for name in self['nodes'].keys(): + node = self.get_node(name) + if (not filters) or ('temphum' in filters and 'temphum' in node) or ('relay' in filters and 'relay' in node): + if only_names: + reslist.append(name) + else: + resdict[name] = node + return reslist if only_names else resdict diff --git a/src/home/mqtt/_module.py b/src/home/mqtt/_module.py new file mode 100644 index 0000000..80f27bb --- /dev/null +++ b/src/home/mqtt/_module.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import abc +import logging +import threading + +from time import sleep +from ..util import next_tick_gen + +from typing import TYPE_CHECKING, Optional +if TYPE_CHECKING: + from ._node import MqttNode + from ._payload import MqttPayload + + +class MqttModule(abc.ABC): + _tick_interval: int + _initialized: bool + _connected: bool + _ticker: Optional[threading.Thread] + _mqtt_node_ref: Optional[MqttNode] + + def __init__(self, tick_interval=0): + self._tick_interval = tick_interval + self._initialized = False + self._ticker = None + self._logger = logging.getLogger(self.__class__.__name__) + self._connected = False + self._mqtt_node_ref = None + + def on_connect(self, mqtt: MqttNode): + self._connected = True + self._mqtt_node_ref = mqtt + if self._tick_interval: + self._start_ticker() + + def on_disconnect(self, mqtt: MqttNode): + self._connected = False + self._mqtt_node_ref = None + + def is_initialized(self): + return self._initialized + + def set_initialized(self): + self._initialized = True + + def unset_initialized(self): + self._initialized = False + + def tick(self): + pass + + def _tick(self): + g = next_tick_gen(self._tick_interval) + while self._connected: + sleep(next(g)) + if not self._connected: + break + self.tick() + + def _start_ticker(self): + if not self._ticker or not self._ticker.is_alive(): + name_part = f'{self._mqtt_node_ref.id}/' if self._mqtt_node_ref else '' + self._ticker = None + self._ticker = threading.Thread(target=self._tick, + name=f'mqtt:{self.__class__.__name__}/{name_part}ticker') + self._ticker.start() + + def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: + pass diff --git a/src/home/mqtt/_mqtt.py b/src/home/mqtt/_mqtt.py new file mode 100644 index 0000000..746ae2e --- /dev/null +++ b/src/home/mqtt/_mqtt.py @@ -0,0 +1,86 @@ +import os.path +import paho.mqtt.client as mqtt +import ssl +import logging + +from ._config import MqttCreds, MqttConfig +from typing import Optional + + +class Mqtt: + _connected: bool + _is_server: bool + _mqtt_config: MqttConfig + + def __init__(self, + clean_session=True, + client_id='', + creds: Optional[MqttCreds] = None, + is_server=False): + if not client_id: + raise ValueError('client_id must not be empty') + + self._client = mqtt.Client(client_id=client_id, + protocol=mqtt.MQTTv311, + clean_session=clean_session) + self._client.on_connect = self.on_connect + self._client.on_disconnect = self.on_disconnect + self._client.on_message = self.on_message + self._client.on_log = self.on_log + self._client.on_publish = self.on_publish + self._loop_started = False + self._connected = False + self._is_server = is_server + self._mqtt_config = MqttConfig() + self._logger = logging.getLogger(self.__class__.__name__) + + if not creds: + creds = self._mqtt_config.creds() if not is_server else self._mqtt_config.server_creds() + + self._client.username_pw_set(creds.username, creds.password) + + def _configure_tls(self): + ca_certs = os.path.realpath(os.path.join( + os.path.dirname(os.path.realpath(__file__)), + '..', + '..', + '..', + 'assets', + 'mqtt_ca.crt' + )) + self._client.tls_set(ca_certs=ca_certs, + cert_reqs=ssl.CERT_REQUIRED, + tls_version=ssl.PROTOCOL_TLSv1_2) + + def connect_and_loop(self, loop_forever=True): + self._configure_tls() + addr = self._mqtt_config.local_addr() if self._is_server else self._mqtt_config.remote_addr() + self._client.connect(addr.host, addr.port, 60) + if loop_forever: + self._client.loop_forever() + else: + self._client.loop_start() + self._loop_started = True + + def disconnect(self): + self._client.disconnect() + self._client.loop_write() + self._client.loop_stop() + + def on_connect(self, client: mqtt.Client, userdata, flags, rc): + self._logger.info("Connected with result code " + str(rc)) + self._connected = True + + def on_disconnect(self, client: mqtt.Client, userdata, rc): + self._logger.info("Disconnected with result code " + str(rc)) + self._connected = False + + def on_log(self, client: mqtt.Client, userdata, level, buf): + level = mqtt.LOGGING_LEVEL[level] if level in mqtt.LOGGING_LEVEL else logging.INFO + self._logger.log(level, f'MQTT: {buf}') + + def on_message(self, client: mqtt.Client, userdata, msg): + self._logger.debug(msg.topic + ": " + str(msg.payload)) + + def on_publish(self, client: mqtt.Client, userdata, mid): + self._logger.debug(f'publish done, mid={mid}') diff --git a/src/home/mqtt/_node.py b/src/home/mqtt/_node.py new file mode 100644 index 0000000..4e259a4 --- /dev/null +++ b/src/home/mqtt/_node.py @@ -0,0 +1,92 @@ +import logging +import importlib + +from typing import List, TYPE_CHECKING, Optional +from ._payload import MqttPayload +from ._module import MqttModule +if TYPE_CHECKING: + from ._wrapper import MqttWrapper +else: + MqttWrapper = None + + +class MqttNode: + _modules: List[MqttModule] + _module_subscriptions: dict[str, MqttModule] + _node_id: str + _node_secret: str + _payload_callbacks: list[callable] + _wrapper: Optional[MqttWrapper] + + def __init__(self, + node_id: str, + node_secret: Optional[str] = None): + self._modules = [] + self._module_subscriptions = {} + self._node_id = node_id + self._node_secret = node_secret + self._payload_callbacks = [] + self._logger = logging.getLogger(self.__class__.__name__) + self._wrapper = None + + def on_connect(self, wrapper: MqttWrapper): + self._wrapper = wrapper + for module in self._modules: + if not module.is_initialized(): + module.on_connect(self) + module.set_initialized() + + def on_disconnect(self): + self._wrapper = None + for module in self._modules: + module.unset_initialized() + + def on_message(self, topic, payload): + if topic in self._module_subscriptions: + payload = self._module_subscriptions[topic].handle_payload(self, topic, payload) + if isinstance(payload, MqttPayload): + for f in self._payload_callbacks: + f(self, payload) + + def load_module(self, module_name: str, *args, **kwargs) -> MqttModule: + module = importlib.import_module(f'..module.{module_name}', __name__) + if not hasattr(module, 'MODULE_NAME'): + raise RuntimeError(f'MODULE_NAME not found in module {module}') + cl = getattr(module, getattr(module, 'MODULE_NAME')) + instance = cl(*args, **kwargs) + self.add_module(instance) + return instance + + def add_module(self, module: MqttModule): + self._modules.append(module) + if self._wrapper and self._wrapper._connected: + module.on_connect(self) + module.set_initialized() + + def subscribe_module(self, topic: str, module: MqttModule, qos: int = 1): + if not self._wrapper or not self._wrapper._connected: + raise RuntimeError('not connected') + + self._module_subscriptions[topic] = module + self._wrapper.subscribe(self.id, topic, qos) + + def publish(self, + topic: str, + payload: bytes, + qos: int = 1): + self._wrapper.publish(self.id, topic, payload, qos) + + def add_payload_callback(self, callback: callable): + self._payload_callbacks.append(callback) + + @property + def id(self) -> str: + return self._node_id + + @property + def secret(self) -> str: + return self._node_secret + + @secret.setter + def secret(self, secret: str) -> None: + self._node_secret = secret diff --git a/src/home/mqtt/_payload.py b/src/home/mqtt/_payload.py new file mode 100644 index 0000000..58eeae3 --- /dev/null +++ b/src/home/mqtt/_payload.py @@ -0,0 +1,145 @@ +import struct +import abc +import re + +from typing import Optional, Tuple + + +def pldstr(self) -> str: + attrs = [] + for field in self.__class__.__annotations__: + if hasattr(self, field): + attr = getattr(self, field) + attrs.append(f'{field}={attr}') + if attrs: + attrs_s = ' ' + attrs_s += ', '.join(attrs) + else: + attrs_s = '' + return f'<%s{attrs_s}>' % (self.__class__.__name__,) + + +class MqttPayload(abc.ABC): + FORMAT = '' + PACKER = {} + UNPACKER = {} + + def __init__(self, **kwargs): + for field in self.__class__.__annotations__: + setattr(self, field, kwargs[field]) + + def pack(self): + args = [] + bf_number = -1 + bf_arg = 0 + bf_progress = 0 + + for field, field_type in self.__class__.__annotations__.items(): + bfp = _bit_field_params(field_type) + if bfp: + n, s, b = bfp + if n != bf_number: + if bf_number != -1: + args.append(bf_arg) + bf_number = n + bf_progress = 0 + bf_arg = 0 + bf_arg |= (getattr(self, field) & (2 ** b - 1)) << bf_progress + bf_progress += b + + else: + if bf_number != -1: + args.append(bf_arg) + bf_number = -1 + bf_progress = 0 + bf_arg = 0 + + args.append(self._pack_field(field)) + + if bf_number != -1: + args.append(bf_arg) + + return struct.pack(self.FORMAT, *args) + + @classmethod + def unpack(cls, buf: bytes): + data = struct.unpack(cls.FORMAT, buf) + kwargs = {} + i = 0 + bf_number = -1 + bf_progress = 0 + + for field, field_type in cls.__annotations__.items(): + bfp = _bit_field_params(field_type) + if bfp: + n, s, b = bfp + if n != bf_number: + bf_number = n + bf_progress = 0 + kwargs[field] = (data[i] >> bf_progress) & (2 ** b - 1) + bf_progress += b + continue # don't increment i + + if bf_number != -1: + bf_number = -1 + i += 1 + + if issubclass(field_type, MqttPayloadCustomField): + kwargs[field] = field_type.unpack(data[i]) + else: + kwargs[field] = cls._unpack_field(field, data[i]) + i += 1 + + return cls(**kwargs) + + def _pack_field(self, name): + val = getattr(self, name) + if self.PACKER and name in self.PACKER: + return self.PACKER[name](val) + else: + return val + + @classmethod + def _unpack_field(cls, name, val): + if isinstance(val, MqttPayloadCustomField): + return + if cls.UNPACKER and name in cls.UNPACKER: + return cls.UNPACKER[name](val) + else: + return val + + def __str__(self): + return pldstr(self) + + +class MqttPayloadCustomField(abc.ABC): + def __init__(self, **kwargs): + for field in self.__class__.__annotations__: + setattr(self, field, kwargs[field]) + + @abc.abstractmethod + def __index__(self): + pass + + @classmethod + @abc.abstractmethod + def unpack(cls, *args, **kwargs): + pass + + def __str__(self): + return pldstr(self) + + +def bit_field(seq_no: int, total_bits: int, bits: int): + return type(f'MQTTPayloadBitField_{seq_no}_{total_bits}_{bits}', (object,), { + 'seq_no': seq_no, + 'total_bits': total_bits, + 'bits': bits + }) + + +def _bit_field_params(cl) -> Optional[Tuple[int, ...]]: + match = re.match(r'MQTTPayloadBitField_(\d+)_(\d+)_(\d)$', cl.__name__) + if match is not None: + return tuple([int(match.group(i)) for i in range(1, 4)]) + return None \ No newline at end of file diff --git a/src/home/mqtt/_util.py b/src/home/mqtt/_util.py new file mode 100644 index 0000000..390d463 --- /dev/null +++ b/src/home/mqtt/_util.py @@ -0,0 +1,15 @@ +import os +import re + +from typing import List + + +def get_modules() -> List[str]: + modules = [] + modules_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'module') + for name in os.listdir(modules_dir): + if os.path.isdir(os.path.join(modules_dir, name)): + continue + name = re.sub(r'\.py$', '', name) + modules.append(name) + return modules diff --git a/src/home/mqtt/_wrapper.py b/src/home/mqtt/_wrapper.py new file mode 100644 index 0000000..f858f88 --- /dev/null +++ b/src/home/mqtt/_wrapper.py @@ -0,0 +1,59 @@ +import paho.mqtt.client as mqtt + +from ._mqtt import Mqtt +from ._node import MqttNode +from ..config import config +from ..util import strgen + + +class MqttWrapper(Mqtt): + _nodes: list[MqttNode] + + def __init__(self, + client_id: str, + topic_prefix='hk', + randomize_client_id=False, + clean_session=True): + if randomize_client_id: + client_id += '_'+strgen(6) + super().__init__(clean_session=clean_session, + client_id=client_id) + self._nodes = [] + self._topic_prefix = topic_prefix + + def on_connect(self, client: mqtt.Client, userdata, flags, rc): + super().on_connect(client, userdata, flags, rc) + for node in self._nodes: + node.on_connect(self) + + def on_disconnect(self, client: mqtt.Client, userdata, rc): + super().on_disconnect(client, userdata, rc) + for node in self._nodes: + node.on_disconnect() + + def on_message(self, client: mqtt.Client, userdata, msg): + try: + topic = msg.topic + for node in self._nodes: + node.on_message(topic[len(f'{self._topic_prefix}/{node.id}/'):], msg.payload) + except Exception as e: + self._logger.exception(str(e)) + + def add_node(self, node: MqttNode): + self._nodes.append(node) + if self._connected: + node.on_connect(self) + + def subscribe(self, + node_id: str, + topic: str, + qos: int): + self._client.subscribe(f'{self._topic_prefix}/{node_id}/{topic}', qos) + + def publish(self, + node_id: str, + topic: str, + payload: bytes, + qos: int): + self._client.publish(f'{self._topic_prefix}/{node_id}/{topic}', payload, qos) + self._client.loop_write() diff --git a/src/home/mqtt/esp.py b/src/home/mqtt/esp.py deleted file mode 100644 index 56ced83..0000000 --- a/src/home/mqtt/esp.py +++ /dev/null @@ -1,106 +0,0 @@ -import re -import paho.mqtt.client as mqtt - -from .mqtt import MqttBase -from typing import Optional, Union -from .payload.esp import ( - OTAPayload, - OTAResultPayload, - DiagnosticsPayload, - InitialDiagnosticsPayload -) - - -class MqttEspDevice: - id: str - secret: Optional[str] - - def __init__(self, id: str, secret: Optional[str] = None): - self.id = id - self.secret = secret - - -class MqttEspBase(MqttBase): - _devices: list[MqttEspDevice] - _message_callback: Optional[callable] - _ota_publish_callback: Optional[callable] - - TOPIC_LEAF = 'esp' - - def __init__(self, - devices: Union[MqttEspDevice, list[MqttEspDevice]], - subscribe_to_updates=True): - super().__init__(clean_session=True) - if not isinstance(devices, list): - devices = [devices] - self._devices = devices - self._message_callback = None - self._ota_publish_callback = None - self._subscribe_to_updates = subscribe_to_updates - self._ota_mid = None - - def on_connect(self, client: mqtt.Client, userdata, flags, rc): - super().on_connect(client, userdata, flags, rc) - - if self._subscribe_to_updates: - for device in self._devices: - topic = f'hk/{device.id}/{self.TOPIC_LEAF}/#' - self._logger.debug(f"subscribing to {topic}") - client.subscribe(topic, qos=1) - - def on_publish(self, client: mqtt.Client, userdata, mid): - if self._ota_mid is not None and mid == self._ota_mid and self._ota_publish_callback: - self._ota_publish_callback() - - def set_message_callback(self, callback: callable): - self._message_callback = callback - - def on_message(self, client: mqtt.Client, userdata, msg): - try: - match = re.match(self.get_mqtt_topics(), msg.topic) - self._logger.debug(f'topic: {msg.topic}') - if not match: - return - - device_id = match.group(1) - subtopic = match.group(2) - - # try: - next(d for d in self._devices if d.id == device_id) - # except StopIteration:h - # return - - message = None - if subtopic == 'stat': - message = DiagnosticsPayload.unpack(msg.payload) - elif subtopic == 'stat1': - message = InitialDiagnosticsPayload.unpack(msg.payload) - elif subtopic == 'otares': - message = OTAResultPayload.unpack(msg.payload) - - if message and self._message_callback: - self._message_callback(device_id, message) - return True - - except Exception as e: - self._logger.exception(str(e)) - - def push_ota(self, - device_id, - filename: str, - publish_callback: callable, - qos: int): - device = next(d for d in self._devices if d.id == device_id) - assert device.secret is not None, 'device secret not specified' - - self._ota_publish_callback = publish_callback - payload = OTAPayload(secret=device.secret, filename=filename) - publish_result = self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/admin/ota', - payload=payload.pack(), - qos=qos) - self._ota_mid = publish_result.mid - self._client.loop_write() - - @classmethod - def get_mqtt_topics(cls, additional_topics: Optional[list[str]] = None): - return rf'^hk/(.*?)/{cls.TOPIC_LEAF}/(stat|stat1|otares'+('|'+('|'.join(additional_topics)) if additional_topics else '')+')$' \ No newline at end of file diff --git a/src/home/mqtt/module/diagnostics.py b/src/home/mqtt/module/diagnostics.py new file mode 100644 index 0000000..5db5e99 --- /dev/null +++ b/src/home/mqtt/module/diagnostics.py @@ -0,0 +1,64 @@ +from .._payload import MqttPayload, MqttPayloadCustomField +from .._node import MqttNode, MqttModule +from typing import Optional + +MODULE_NAME = 'MqttDiagnosticsModule' + + +class DiagnosticsFlags(MqttPayloadCustomField): + state: bool + config_changed_value_present: bool + config_changed: bool + + @staticmethod + def unpack(flags: int): + # _logger.debug(f'StatFlags.unpack: flags={flags}') + state = flags & 0x1 + ccvp = (flags >> 1) & 0x1 + cc = (flags >> 2) & 0x1 + # _logger.debug(f'StatFlags.unpack: state={state}') + return DiagnosticsFlags(state=(state == 1), + config_changed_value_present=(ccvp == 1), + config_changed=(cc == 1)) + + def __index__(self): + bits = 0 + bits |= (int(self.state) & 0x1) + bits |= (int(self.config_changed_value_present) & 0x1) << 1 + bits |= (int(self.config_changed) & 0x1) << 2 + return bits + + +class InitialDiagnosticsPayload(MqttPayload): + FORMAT = '=IBbIB' + + ip: int + fw_version: int + rssi: int + free_heap: int + flags: DiagnosticsFlags + + +class DiagnosticsPayload(MqttPayload): + FORMAT = '=bIB' + + rssi: int + free_heap: int + flags: DiagnosticsFlags + + +class MqttDiagnosticsModule(MqttModule): + def on_connect(self, mqtt: MqttNode): + super().on_connect(mqtt) + for topic in ('diag', 'd1ag', 'stat', 'stat1'): + mqtt.subscribe_module(topic, self) + + def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: + message = None + if topic in ('stat', 'diag'): + message = DiagnosticsPayload.unpack(payload) + elif topic in ('stat1', 'd1ag'): + message = InitialDiagnosticsPayload.unpack(payload) + if message: + self._logger.debug(message) + return message diff --git a/src/home/mqtt/module/inverter.py b/src/home/mqtt/module/inverter.py new file mode 100644 index 0000000..d927a06 --- /dev/null +++ b/src/home/mqtt/module/inverter.py @@ -0,0 +1,195 @@ +import time +import json +import datetime +try: + import inverterd +except: + pass + +from typing import Optional +from .._module import MqttModule +from .._node import MqttNode +from .._payload import MqttPayload, bit_field +try: + from home.database import InverterDatabase +except: + pass + +_mult_10 = lambda n: int(n*10) +_div_10 = lambda n: n/10 + + +MODULE_NAME = 'MqttInverterModule' + +STATUS_TOPIC = 'status' +GENERATION_TOPIC = 'generation' + + +class MqttInverterStatusPayload(MqttPayload): + # 46 bytes + FORMAT = 'IHHHHHHBHHHHHBHHHHHHHH' + + PACKER = { + 'grid_voltage': _mult_10, + 'grid_freq': _mult_10, + 'ac_output_voltage': _mult_10, + 'ac_output_freq': _mult_10, + 'battery_voltage': _mult_10, + 'battery_voltage_scc': _mult_10, + 'battery_voltage_scc2': _mult_10, + 'pv1_input_voltage': _mult_10, + 'pv2_input_voltage': _mult_10 + } + UNPACKER = { + 'grid_voltage': _div_10, + 'grid_freq': _div_10, + 'ac_output_voltage': _div_10, + 'ac_output_freq': _div_10, + 'battery_voltage': _div_10, + 'battery_voltage_scc': _div_10, + 'battery_voltage_scc2': _div_10, + 'pv1_input_voltage': _div_10, + 'pv2_input_voltage': _div_10 + } + + time: int + grid_voltage: float + grid_freq: float + ac_output_voltage: float + ac_output_freq: float + ac_output_apparent_power: int + ac_output_active_power: int + output_load_percent: int + battery_voltage: float + battery_voltage_scc: float + battery_voltage_scc2: float + battery_discharge_current: int + battery_charge_current: int + battery_capacity: int + inverter_heat_sink_temp: int + mppt1_charger_temp: int + mppt2_charger_temp: int + pv1_input_power: int + pv2_input_power: int + pv1_input_voltage: float + pv2_input_voltage: float + + # H + mppt1_charger_status: bit_field(0, 16, 2) + mppt2_charger_status: bit_field(0, 16, 2) + battery_power_direction: bit_field(0, 16, 2) + dc_ac_power_direction: bit_field(0, 16, 2) + line_power_direction: bit_field(0, 16, 2) + load_connected: bit_field(0, 16, 1) + + +class MqttInverterGenerationPayload(MqttPayload): + # 8 bytes + FORMAT = 'II' + + time: int + wh: int + + +class MqttInverterModule(MqttModule): + _status_poll_freq: int + _generation_poll_freq: int + _inverter: Optional[inverterd.Client] + _database: Optional[InverterDatabase] + _gen_prev: float + + def __init__(self, status_poll_freq=0, generation_poll_freq=0): + super().__init__(tick_interval=status_poll_freq) + self._status_poll_freq = status_poll_freq + self._generation_poll_freq = generation_poll_freq + + # this defines whether this is a publisher or a subscriber + if status_poll_freq > 0: + self._inverter = inverterd.Client() + self._inverter.connect() + self._inverter.format(inverterd.Format.SIMPLE_JSON) + self._database = None + else: + self._inverter = None + self._database = InverterDatabase() + + self._gen_prev = 0 + + def on_connect(self, mqtt: MqttNode): + super().on_connect(mqtt) + if not self._inverter: + mqtt.subscribe_module(STATUS_TOPIC, self) + mqtt.subscribe_module(GENERATION_TOPIC, self) + + def tick(self): + if not self._inverter: + return + + # read status + now = time.time() + try: + raw = self._inverter.exec('get-status') + except inverterd.InverterError as e: + self._logger.error(f'inverter error: {str(e)}') + # TODO send to server + return + + data = json.loads(raw)['data'] + status = MqttInverterStatusPayload(time=round(now), **data) + self._mqtt_node_ref.publish(STATUS_TOPIC, status.pack()) + + # read today's generation stat + now = time.time() + if self._gen_prev == 0 or now - self._gen_prev >= self._generation_poll_freq: + self._gen_prev = now + today = datetime.date.today() + try: + raw = self._inverter.exec('get-day-generated', (today.year, today.month, today.day)) + except inverterd.InverterError as e: + self._logger.error(f'inverter error: {str(e)}') + # TODO send to server + return + + data = json.loads(raw)['data'] + gen = MqttInverterGenerationPayload(time=round(now), wh=data['wh']) + self._mqtt_node_ref.publish(GENERATION_TOPIC, gen.pack()) + + def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: + home_id = 1 # legacy compat + + if topic == STATUS_TOPIC: + s = MqttInverterStatusPayload.unpack(payload) + self._database.add_status(home_id=home_id, + client_time=s.time, + grid_voltage=int(s.grid_voltage*10), + grid_freq=int(s.grid_freq * 10), + ac_output_voltage=int(s.ac_output_voltage * 10), + ac_output_freq=int(s.ac_output_freq * 10), + ac_output_apparent_power=s.ac_output_apparent_power, + ac_output_active_power=s.ac_output_active_power, + output_load_percent=s.output_load_percent, + battery_voltage=int(s.battery_voltage * 10), + battery_voltage_scc=int(s.battery_voltage_scc * 10), + battery_voltage_scc2=int(s.battery_voltage_scc2 * 10), + battery_discharge_current=s.battery_discharge_current, + battery_charge_current=s.battery_charge_current, + battery_capacity=s.battery_capacity, + inverter_heat_sink_temp=s.inverter_heat_sink_temp, + mppt1_charger_temp=s.mppt1_charger_temp, + mppt2_charger_temp=s.mppt2_charger_temp, + pv1_input_power=s.pv1_input_power, + pv2_input_power=s.pv2_input_power, + pv1_input_voltage=int(s.pv1_input_voltage * 10), + pv2_input_voltage=int(s.pv2_input_voltage * 10), + mppt1_charger_status=s.mppt1_charger_status, + mppt2_charger_status=s.mppt2_charger_status, + battery_power_direction=s.battery_power_direction, + dc_ac_power_direction=s.dc_ac_power_direction, + line_power_direction=s.line_power_direction, + load_connected=s.load_connected) + return s + + elif topic == GENERATION_TOPIC: + gen = MqttInverterGenerationPayload.unpack(payload) + self._database.add_generation(home_id, gen.time, gen.wh) + return gen diff --git a/src/home/mqtt/module/ota.py b/src/home/mqtt/module/ota.py new file mode 100644 index 0000000..cd34332 --- /dev/null +++ b/src/home/mqtt/module/ota.py @@ -0,0 +1,77 @@ +import hashlib + +from typing import Optional +from .._payload import MqttPayload +from .._node import MqttModule, MqttNode + +MODULE_NAME = 'MqttOtaModule' + + +class OtaResultPayload(MqttPayload): + FORMAT = '=BB' + result: int + error_code: int + + +class OtaPayload(MqttPayload): + secret: str + filename: str + + # structure of returned data: + # + # uint8_t[len(secret)] secret; + # uint8_t[16] md5; + # *uint8_t data + + def pack(self): + buf = bytearray(self.secret.encode()) + m = hashlib.md5() + with open(self.filename, 'rb') as fd: + content = fd.read() + m.update(content) + buf.extend(m.digest()) + buf.extend(content) + return buf + + def unpack(cls, buf: bytes): + raise RuntimeError(f'{cls.__class__.__name__}.unpack: not implemented') + # secret = buf[:12].decode() + # filename = buf[12:].decode() + # return OTAPayload(secret=secret, filename=filename) + + +class MqttOtaModule(MqttModule): + _ota_request: Optional[tuple[str, int]] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._ota_request = None + + def on_connect(self, mqtt: MqttNode): + super().on_connect(mqtt) + mqtt.subscribe_module("otares", self) + + if self._ota_request is not None: + filename, qos = self._ota_request + self._ota_request = None + self.do_push_ota(self._mqtt_node_ref.secret, filename, qos) + + def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: + if topic == 'otares': + message = OtaResultPayload.unpack(payload) + self._logger.debug(message) + return message + + def do_push_ota(self, secret: str, filename: str, qos: int): + payload = OtaPayload(secret=secret, filename=filename) + self._mqtt_node_ref.publish('ota', + payload=payload.pack(), + qos=qos) + + def push_ota(self, + filename: str, + qos: int): + if not self._initialized: + self._ota_request = (filename, qos) + else: + self.do_push_ota(filename, qos) diff --git a/src/home/mqtt/module/relay.py b/src/home/mqtt/module/relay.py new file mode 100644 index 0000000..e968031 --- /dev/null +++ b/src/home/mqtt/module/relay.py @@ -0,0 +1,92 @@ +import datetime + +from typing import Optional +from .. import MqttModule, MqttPayload, MqttNode + +MODULE_NAME = 'MqttRelayModule' + + +class MqttPowerSwitchPayload(MqttPayload): + FORMAT = '=12sB' + PACKER = { + 'state': lambda n: int(n), + 'secret': lambda s: s.encode('utf-8') + } + UNPACKER = { + 'state': lambda n: bool(n), + 'secret': lambda s: s.decode('utf-8') + } + + secret: str + state: bool + + +class MqttPowerStatusPayload(MqttPayload): + FORMAT = '=B' + PACKER = { + 'opened': lambda n: int(n), + } + UNPACKER = { + 'opened': lambda n: bool(n), + } + + opened: bool + + +class MqttRelayState: + enabled: bool + update_time: datetime.datetime + rssi: int + fw_version: int + ever_updated: bool + + def __init__(self): + self.ever_updated = False + self.enabled = False + self.rssi = 0 + + def update(self, + enabled: bool, + rssi: int, + fw_version=None): + self.ever_updated = True + self.enabled = enabled + self.rssi = rssi + self.update_time = datetime.datetime.now() + if fw_version: + self.fw_version = fw_version + + +class MqttRelayModule(MqttModule): + _legacy_topics: bool + + def __init__(self, legacy_topics=False, *args, **kwargs): + super().__init__(*args, **kwargs) + self._legacy_topics = legacy_topics + + def on_connect(self, mqtt: MqttNode): + super().on_connect(mqtt) + mqtt.subscribe_module(self._get_switch_topic(), self) + mqtt.subscribe_module('relay/status', self) + + def switchpower(self, + enable: bool): + payload = MqttPowerSwitchPayload(secret=self._mqtt_node_ref.secret, + state=enable) + self._mqtt_node_ref.publish(self._get_switch_topic(), + payload=payload.pack()) + + def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: + message = None + + if topic == self._get_switch_topic(): + message = MqttPowerSwitchPayload.unpack(payload) + elif topic == 'relay/status': + message = MqttPowerStatusPayload.unpack(payload) + + if message is not None: + self._logger.debug(message) + return message + + def _get_switch_topic(self) -> str: + return 'relay/power' if self._legacy_topics else 'relay/switch' diff --git a/src/home/mqtt/module/temphum.py b/src/home/mqtt/module/temphum.py new file mode 100644 index 0000000..fd02cca --- /dev/null +++ b/src/home/mqtt/module/temphum.py @@ -0,0 +1,82 @@ +from .._node import MqttNode +from .._module import MqttModule +from .._payload import MqttPayload +from typing import Optional +from ...temphum import BaseSensor + +two_digits_precision = lambda x: round(x, 2) + +MODULE_NAME = 'MqttTempHumModule' +DATA_TOPIC = 'temphum/data' + + +class MqttTemphumDataPayload(MqttPayload): + FORMAT = '=ddb' + UNPACKER = { + 'temp': two_digits_precision, + 'rh': two_digits_precision + } + + temp: float + rh: float + error: int + + +# class MqttTempHumNodes(HashableEnum): +# KBN_SH_HALL = auto() +# KBN_SH_BATHROOM = auto() +# KBN_SH_LIVINGROOM = auto() +# KBN_SH_BEDROOM = auto() +# +# KBN_BH_2FL = auto() +# KBN_BH_2FL_STREET = auto() +# KBN_BH_1FL_LIVINGROOM = auto() +# KBN_BH_1FL_BEDROOM = auto() +# KBN_BH_1FL_BATHROOM = auto() +# +# KBN_NH_1FL_INV = auto() +# KBN_NH_1FL_CENTER = auto() +# KBN_NH_1LF_KT = auto() +# KBN_NH_1FL_DS = auto() +# KBN_NH_1FS_EZ = auto() +# +# SPB_FLAT120_CABINET = auto() + + +class MqttTempHumModule(MqttModule): + def __init__(self, + sensor: Optional[BaseSensor] = None, + write_to_database=False, + *args, **kwargs): + if sensor is not None: + kwargs['tick_interval'] = 10 + super().__init__(*args, **kwargs) + self._sensor = sensor + + def on_connect(self, mqtt: MqttNode): + super().on_connect(mqtt) + mqtt.subscribe_module(DATA_TOPIC, self) + + def tick(self): + if not self._sensor: + return + + error = 0 + temp = 0 + rh = 0 + try: + temp = self._sensor.temperature() + rh = self._sensor.humidity() + except: + error = 1 + pld = MqttTemphumDataPayload(temp=temp, rh=rh, error=error) + self._mqtt_node_ref.publish(DATA_TOPIC, pld.pack()) + + def handle_payload(self, + mqtt: MqttNode, + topic: str, + payload: bytes) -> Optional[MqttPayload]: + if topic == DATA_TOPIC: + message = MqttTemphumDataPayload.unpack(payload) + self._logger.debug(message) + return message diff --git a/src/home/mqtt/mqtt.py b/src/home/mqtt/mqtt.py deleted file mode 100644 index 4acd4f6..0000000 --- a/src/home/mqtt/mqtt.py +++ /dev/null @@ -1,76 +0,0 @@ -import os.path -import paho.mqtt.client as mqtt -import ssl -import logging - -from typing import Tuple -from ..config import config - - -def username_and_password() -> Tuple[str, str]: - username = config['mqtt']['username'] if 'username' in config['mqtt'] else None - password = config['mqtt']['password'] if 'password' in config['mqtt'] else None - return username, password - - -class MqttBase: - def __init__(self, clean_session=True): - self._client = mqtt.Client(client_id=config['mqtt']['client_id'], - protocol=mqtt.MQTTv311, - clean_session=clean_session) - self._client.on_connect = self.on_connect - self._client.on_disconnect = self.on_disconnect - self._client.on_message = self.on_message - self._client.on_log = self.on_log - self._client.on_publish = self.on_publish - self._loop_started = False - - self._logger = logging.getLogger(self.__class__.__name__) - - username, password = username_and_password() - if username and password: - self._logger.debug(f'username={username} password={password}') - self._client.username_pw_set(username, password) - - def configure_tls(self): - ca_certs = os.path.realpath(os.path.join( - os.path.dirname(os.path.realpath(__file__)), - '..', - '..', - '..', - 'assets', - 'mqtt_ca.crt' - )) - self._client.tls_set(ca_certs=ca_certs, cert_reqs=ssl.CERT_REQUIRED, tls_version=ssl.PROTOCOL_TLSv1_2) - - def connect_and_loop(self, loop_forever=True): - host = config['mqtt']['host'] - port = config['mqtt']['port'] - - self._client.connect(host, port, 60) - if loop_forever: - self._client.loop_forever() - else: - self._client.loop_start() - self._loop_started = True - - def disconnect(self): - self._client.disconnect() - self._client.loop_write() - self._client.loop_stop() - - def on_connect(self, client: mqtt.Client, userdata, flags, rc): - self._logger.info("Connected with result code " + str(rc)) - - def on_disconnect(self, client: mqtt.Client, userdata, rc): - self._logger.info("Disconnected with result code " + str(rc)) - - def on_log(self, client: mqtt.Client, userdata, level, buf): - level = mqtt.LOGGING_LEVEL[level] if level in mqtt.LOGGING_LEVEL else logging.INFO - self._logger.log(level, f'MQTT: {buf}') - - def on_message(self, client: mqtt.Client, userdata, msg): - self._logger.debug(msg.topic + ": " + str(msg.payload)) - - def on_publish(self, client: mqtt.Client, userdata, mid): - self._logger.debug(f'publish done, mid={mid}') \ No newline at end of file diff --git a/src/home/mqtt/payload/__init__.py b/src/home/mqtt/payload/__init__.py deleted file mode 100644 index eee6709..0000000 --- a/src/home/mqtt/payload/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .base_payload import MqttPayload \ No newline at end of file diff --git a/src/home/mqtt/payload/base_payload.py b/src/home/mqtt/payload/base_payload.py deleted file mode 100644 index 1abd898..0000000 --- a/src/home/mqtt/payload/base_payload.py +++ /dev/null @@ -1,145 +0,0 @@ -import abc -import struct -import re - -from typing import Optional, Tuple - - -def pldstr(self) -> str: - attrs = [] - for field in self.__class__.__annotations__: - if hasattr(self, field): - attr = getattr(self, field) - attrs.append(f'{field}={attr}') - if attrs: - attrs_s = ' ' - attrs_s += ', '.join(attrs) - else: - attrs_s = '' - return f'<%s{attrs_s}>' % (self.__class__.__name__,) - - -class MqttPayload(abc.ABC): - FORMAT = '' - PACKER = {} - UNPACKER = {} - - def __init__(self, **kwargs): - for field in self.__class__.__annotations__: - setattr(self, field, kwargs[field]) - - def pack(self): - args = [] - bf_number = -1 - bf_arg = 0 - bf_progress = 0 - - for field, field_type in self.__class__.__annotations__.items(): - bfp = _bit_field_params(field_type) - if bfp: - n, s, b = bfp - if n != bf_number: - if bf_number != -1: - args.append(bf_arg) - bf_number = n - bf_progress = 0 - bf_arg = 0 - bf_arg |= (getattr(self, field) & (2 ** b - 1)) << bf_progress - bf_progress += b - - else: - if bf_number != -1: - args.append(bf_arg) - bf_number = -1 - bf_progress = 0 - bf_arg = 0 - - args.append(self._pack_field(field)) - - if bf_number != -1: - args.append(bf_arg) - - return struct.pack(self.FORMAT, *args) - - @classmethod - def unpack(cls, buf: bytes): - data = struct.unpack(cls.FORMAT, buf) - kwargs = {} - i = 0 - bf_number = -1 - bf_progress = 0 - - for field, field_type in cls.__annotations__.items(): - bfp = _bit_field_params(field_type) - if bfp: - n, s, b = bfp - if n != bf_number: - bf_number = n - bf_progress = 0 - kwargs[field] = (data[i] >> bf_progress) & (2 ** b - 1) - bf_progress += b - continue # don't increment i - - if bf_number != -1: - bf_number = -1 - i += 1 - - if issubclass(field_type, MqttPayloadCustomField): - kwargs[field] = field_type.unpack(data[i]) - else: - kwargs[field] = cls._unpack_field(field, data[i]) - i += 1 - - return cls(**kwargs) - - def _pack_field(self, name): - val = getattr(self, name) - if self.PACKER and name in self.PACKER: - return self.PACKER[name](val) - else: - return val - - @classmethod - def _unpack_field(cls, name, val): - if isinstance(val, MqttPayloadCustomField): - return - if cls.UNPACKER and name in cls.UNPACKER: - return cls.UNPACKER[name](val) - else: - return val - - def __str__(self): - return pldstr(self) - - -class MqttPayloadCustomField(abc.ABC): - def __init__(self, **kwargs): - for field in self.__class__.__annotations__: - setattr(self, field, kwargs[field]) - - @abc.abstractmethod - def __index__(self): - pass - - @classmethod - @abc.abstractmethod - def unpack(cls, *args, **kwargs): - pass - - def __str__(self): - return pldstr(self) - - -def bit_field(seq_no: int, total_bits: int, bits: int): - return type(f'MQTTPayloadBitField_{seq_no}_{total_bits}_{bits}', (object,), { - 'seq_no': seq_no, - 'total_bits': total_bits, - 'bits': bits - }) - - -def _bit_field_params(cl) -> Optional[Tuple[int, ...]]: - match = re.match(r'MQTTPayloadBitField_(\d+)_(\d+)_(\d)$', cl.__name__) - if match is not None: - return tuple([int(match.group(i)) for i in range(1, 4)]) - return None diff --git a/src/home/mqtt/payload/esp.py b/src/home/mqtt/payload/esp.py deleted file mode 100644 index 171cdb9..0000000 --- a/src/home/mqtt/payload/esp.py +++ /dev/null @@ -1,78 +0,0 @@ -import hashlib - -from .base_payload import MqttPayload, MqttPayloadCustomField - - -class OTAResultPayload(MqttPayload): - FORMAT = '=BB' - result: int - error_code: int - - -class OTAPayload(MqttPayload): - secret: str - filename: str - - # structure of returned data: - # - # uint8_t[len(secret)] secret; - # uint8_t[16] md5; - # *uint8_t data - - def pack(self): - buf = bytearray(self.secret.encode()) - m = hashlib.md5() - with open(self.filename, 'rb') as fd: - content = fd.read() - m.update(content) - buf.extend(m.digest()) - buf.extend(content) - return buf - - def unpack(cls, buf: bytes): - raise RuntimeError(f'{cls.__class__.__name__}.unpack: not implemented') - # secret = buf[:12].decode() - # filename = buf[12:].decode() - # return OTAPayload(secret=secret, filename=filename) - - -class DiagnosticsFlags(MqttPayloadCustomField): - state: bool - config_changed_value_present: bool - config_changed: bool - - @staticmethod - def unpack(flags: int): - # _logger.debug(f'StatFlags.unpack: flags={flags}') - state = flags & 0x1 - ccvp = (flags >> 1) & 0x1 - cc = (flags >> 2) & 0x1 - # _logger.debug(f'StatFlags.unpack: state={state}') - return DiagnosticsFlags(state=(state == 1), - config_changed_value_present=(ccvp == 1), - config_changed=(cc == 1)) - - def __index__(self): - bits = 0 - bits |= (int(self.state) & 0x1) - bits |= (int(self.config_changed_value_present) & 0x1) << 1 - bits |= (int(self.config_changed) & 0x1) << 2 - return bits - - -class InitialDiagnosticsPayload(MqttPayload): - FORMAT = '=IBbIB' - - ip: int - fw_version: int - rssi: int - free_heap: int - flags: DiagnosticsFlags - - -class DiagnosticsPayload(MqttPayload): - FORMAT = '=bIB' - - rssi: int - free_heap: int - flags: DiagnosticsFlags diff --git a/src/home/mqtt/payload/inverter.py b/src/home/mqtt/payload/inverter.py deleted file mode 100644 index 09388df..0000000 --- a/src/home/mqtt/payload/inverter.py +++ /dev/null @@ -1,73 +0,0 @@ -import struct - -from .base_payload import MqttPayload, bit_field -from typing import Tuple - -_mult_10 = lambda n: int(n*10) -_div_10 = lambda n: n/10 - - -class Status(MqttPayload): - # 46 bytes - FORMAT = 'IHHHHHHBHHHHHBHHHHHHHH' - - PACKER = { - 'grid_voltage': _mult_10, - 'grid_freq': _mult_10, - 'ac_output_voltage': _mult_10, - 'ac_output_freq': _mult_10, - 'battery_voltage': _mult_10, - 'battery_voltage_scc': _mult_10, - 'battery_voltage_scc2': _mult_10, - 'pv1_input_voltage': _mult_10, - 'pv2_input_voltage': _mult_10 - } - UNPACKER = { - 'grid_voltage': _div_10, - 'grid_freq': _div_10, - 'ac_output_voltage': _div_10, - 'ac_output_freq': _div_10, - 'battery_voltage': _div_10, - 'battery_voltage_scc': _div_10, - 'battery_voltage_scc2': _div_10, - 'pv1_input_voltage': _div_10, - 'pv2_input_voltage': _div_10 - } - - time: int - grid_voltage: float - grid_freq: float - ac_output_voltage: float - ac_output_freq: float - ac_output_apparent_power: int - ac_output_active_power: int - output_load_percent: int - battery_voltage: float - battery_voltage_scc: float - battery_voltage_scc2: float - battery_discharge_current: int - battery_charge_current: int - battery_capacity: int - inverter_heat_sink_temp: int - mppt1_charger_temp: int - mppt2_charger_temp: int - pv1_input_power: int - pv2_input_power: int - pv1_input_voltage: float - pv2_input_voltage: float - - # H - mppt1_charger_status: bit_field(0, 16, 2) - mppt2_charger_status: bit_field(0, 16, 2) - battery_power_direction: bit_field(0, 16, 2) - dc_ac_power_direction: bit_field(0, 16, 2) - line_power_direction: bit_field(0, 16, 2) - load_connected: bit_field(0, 16, 1) - - -class Generation(MqttPayload): - # 8 bytes - FORMAT = 'II' - - time: int - wh: int diff --git a/src/home/mqtt/payload/relay.py b/src/home/mqtt/payload/relay.py deleted file mode 100644 index 4902991..0000000 --- a/src/home/mqtt/payload/relay.py +++ /dev/null @@ -1,22 +0,0 @@ -from .base_payload import MqttPayload -from .esp import ( - OTAResultPayload, - OTAPayload, - InitialDiagnosticsPayload, - DiagnosticsPayload -) - - -class PowerPayload(MqttPayload): - FORMAT = '=12sB' - PACKER = { - 'state': lambda n: int(n), - 'secret': lambda s: s.encode('utf-8') - } - UNPACKER = { - 'state': lambda n: bool(n), - 'secret': lambda s: s.decode('utf-8') - } - - secret: str - state: bool diff --git a/src/home/mqtt/payload/sensors.py b/src/home/mqtt/payload/sensors.py deleted file mode 100644 index f99b307..0000000 --- a/src/home/mqtt/payload/sensors.py +++ /dev/null @@ -1,20 +0,0 @@ -from .base_payload import MqttPayload - -_mult_100 = lambda n: int(n*100) -_div_100 = lambda n: n/100 - - -class Temperature(MqttPayload): - FORMAT = 'IhH' - PACKER = { - 'temp': _mult_100, - 'rh': _mult_100, - } - UNPACKER = { - 'temp': _div_100, - 'rh': _div_100, - } - - time: int - temp: float - rh: float diff --git a/src/home/mqtt/payload/temphum.py b/src/home/mqtt/payload/temphum.py deleted file mode 100644 index c0b744e..0000000 --- a/src/home/mqtt/payload/temphum.py +++ /dev/null @@ -1,15 +0,0 @@ -from .base_payload import MqttPayload - -two_digits_precision = lambda x: round(x, 2) - - -class TempHumDataPayload(MqttPayload): - FORMAT = '=ddb' - UNPACKER = { - 'temp': two_digits_precision, - 'rh': two_digits_precision - } - - temp: float - rh: float - error: int diff --git a/src/home/mqtt/relay.py b/src/home/mqtt/relay.py deleted file mode 100644 index a90f19c..0000000 --- a/src/home/mqtt/relay.py +++ /dev/null @@ -1,71 +0,0 @@ -import paho.mqtt.client as mqtt -import re -import datetime - -from .payload.relay import ( - PowerPayload, -) -from .esp import MqttEspBase - - -class MqttRelay(MqttEspBase): - TOPIC_LEAF = 'relay' - - def set_power(self, device_id, enable: bool, secret=None): - device = next(d for d in self._devices if d.id == device_id) - secret = secret if secret else device.secret - - assert secret is not None, 'device secret not specified' - - payload = PowerPayload(secret=secret, - state=enable) - self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/power', - payload=payload.pack(), - qos=1) - self._client.loop_write() - - def on_message(self, client: mqtt.Client, userdata, msg): - if super().on_message(client, userdata, msg): - return - - try: - match = re.match(self.get_mqtt_topics(['power']), msg.topic) - if not match: - return - - device_id = match.group(1) - subtopic = match.group(2) - - message = None - if subtopic == 'power': - message = PowerPayload.unpack(msg.payload) - - if message and self._message_callback: - self._message_callback(device_id, message) - - except Exception as e: - self._logger.exception(str(e)) - - -class MqttRelayState: - enabled: bool - update_time: datetime.datetime - rssi: int - fw_version: int - ever_updated: bool - - def __init__(self): - self.ever_updated = False - self.enabled = False - self.rssi = 0 - - def update(self, - enabled: bool, - rssi: int, - fw_version=None): - self.ever_updated = True - self.enabled = enabled - self.rssi = rssi - self.update_time = datetime.datetime.now() - if fw_version: - self.fw_version = fw_version diff --git a/src/home/mqtt/temphum.py b/src/home/mqtt/temphum.py deleted file mode 100644 index 44810ef..0000000 --- a/src/home/mqtt/temphum.py +++ /dev/null @@ -1,54 +0,0 @@ -import paho.mqtt.client as mqtt -import re - -from enum import auto -from .payload.temphum import TempHumDataPayload -from .esp import MqttEspBase -from ..util import HashableEnum - - -class MqttTempHumNodes(HashableEnum): - KBN_SH_HALL = auto() - KBN_SH_BATHROOM = auto() - KBN_SH_LIVINGROOM = auto() - KBN_SH_BEDROOM = auto() - - KBN_BH_2FL = auto() - KBN_BH_2FL_STREET = auto() - KBN_BH_1FL_LIVINGROOM = auto() - KBN_BH_1FL_BEDROOM = auto() - KBN_BH_1FL_BATHROOM = auto() - - KBN_NH_1FL_INV = auto() - KBN_NH_1FL_CENTER = auto() - KBN_NH_1LF_KT = auto() - KBN_NH_1FL_DS = auto() - KBN_NH_1FS_EZ = auto() - - SPB_FLAT120_CABINET = auto() - - -class MqttTempHum(MqttEspBase): - TOPIC_LEAF = 'temphum' - - def on_message(self, client: mqtt.Client, userdata, msg): - if super().on_message(client, userdata, msg): - return - - try: - match = re.match(self.get_mqtt_topics(['data']), msg.topic) - if not match: - return - - device_id = match.group(1) - subtopic = match.group(2) - - message = None - if subtopic == 'data': - message = TempHumDataPayload.unpack(msg.payload) - - if message and self._message_callback: - self._message_callback(device_id, message) - - except Exception as e: - self._logger.exception(str(e)) diff --git a/src/home/mqtt/util.py b/src/home/mqtt/util.py deleted file mode 100644 index f71ffd8..0000000 --- a/src/home/mqtt/util.py +++ /dev/null @@ -1,8 +0,0 @@ -import time - - -def poll_tick(freq): - t = time.time() - while True: - t += freq - yield max(t - time.time(), 0) -- cgit v1.2.3 From 327a5298359027099631c3c9967b7585928cd367 Mon Sep 17 00:00:00 2001 From: Evgeny Zinoviev Date: Sat, 10 Jun 2023 21:54:56 +0300 Subject: port relay_mqtt_http_proxy to new config scheme; config: support addr types & normalization --- src/home/mqtt/_config.py | 8 ++++---- src/home/mqtt/_wrapper.py | 5 +++-- 2 files changed, 7 insertions(+), 6 deletions(-) (limited to 'src/home/mqtt') diff --git a/src/home/mqtt/_config.py b/src/home/mqtt/_config.py index f9047b4..9ba9443 100644 --- a/src/home/mqtt/_config.py +++ b/src/home/mqtt/_config.py @@ -9,8 +9,8 @@ MqttCreds = namedtuple('MqttCreds', 'username, password') class MqttConfig(ConfigUnit): NAME = 'mqtt' - @staticmethod - def schema() -> Optional[dict]: + @classmethod + def schema(cls) -> Optional[dict]: addr_schema = { 'type': 'dict', 'required': True, @@ -64,8 +64,8 @@ class MqttConfig(ConfigUnit): class MqttNodesConfig(ConfigUnit): NAME = 'mqtt_nodes' - @staticmethod - def schema() -> Optional[dict]: + @classmethod + def schema(cls) -> Optional[dict]: return { 'common': { 'type': 'dict', diff --git a/src/home/mqtt/_wrapper.py b/src/home/mqtt/_wrapper.py index f858f88..3c2774c 100644 --- a/src/home/mqtt/_wrapper.py +++ b/src/home/mqtt/_wrapper.py @@ -2,7 +2,6 @@ import paho.mqtt.client as mqtt from ._mqtt import Mqtt from ._node import MqttNode -from ..config import config from ..util import strgen @@ -34,8 +33,10 @@ class MqttWrapper(Mqtt): def on_message(self, client: mqtt.Client, userdata, msg): try: topic = msg.topic + topic_node = topic[len(self._topic_prefix)+1:topic.find('/', len(self._topic_prefix)+1)] for node in self._nodes: - node.on_message(topic[len(f'{self._topic_prefix}/{node.id}/'):], msg.payload) + if node.id in ('+', topic_node): + node.on_message(topic[len(f'{self._topic_prefix}/{node.id}/'):], msg.payload) except Exception as e: self._logger.exception(str(e)) -- cgit v1.2.3 From b0bf43e6a272d42a55158e657bd937cb82fc3d8d Mon Sep 17 00:00:00 2001 From: Evgeny Zinoviev Date: Sat, 10 Jun 2023 23:02:34 +0300 Subject: move files, rename home package to homekit --- src/home/mqtt/__init__.py | 7 -- src/home/mqtt/_config.py | 165 ------------------------------ src/home/mqtt/_module.py | 70 ------------- src/home/mqtt/_mqtt.py | 86 ---------------- src/home/mqtt/_node.py | 92 ----------------- src/home/mqtt/_payload.py | 145 --------------------------- src/home/mqtt/_util.py | 15 --- src/home/mqtt/_wrapper.py | 60 ----------- src/home/mqtt/module/diagnostics.py | 64 ------------ src/home/mqtt/module/inverter.py | 195 ------------------------------------ src/home/mqtt/module/ota.py | 77 -------------- src/home/mqtt/module/relay.py | 92 ----------------- src/home/mqtt/module/temphum.py | 82 --------------- 13 files changed, 1150 deletions(-) delete mode 100644 src/home/mqtt/__init__.py delete mode 100644 src/home/mqtt/_config.py delete mode 100644 src/home/mqtt/_module.py delete mode 100644 src/home/mqtt/_mqtt.py delete mode 100644 src/home/mqtt/_node.py delete mode 100644 src/home/mqtt/_payload.py delete mode 100644 src/home/mqtt/_util.py delete mode 100644 src/home/mqtt/_wrapper.py delete mode 100644 src/home/mqtt/module/diagnostics.py delete mode 100644 src/home/mqtt/module/inverter.py delete mode 100644 src/home/mqtt/module/ota.py delete mode 100644 src/home/mqtt/module/relay.py delete mode 100644 src/home/mqtt/module/temphum.py (limited to 'src/home/mqtt') diff --git a/src/home/mqtt/__init__.py b/src/home/mqtt/__init__.py deleted file mode 100644 index 707d59c..0000000 --- a/src/home/mqtt/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from ._mqtt import Mqtt -from ._node import MqttNode -from ._module import MqttModule -from ._wrapper import MqttWrapper -from ._config import MqttConfig, MqttCreds, MqttNodesConfig -from ._payload import MqttPayload, MqttPayloadCustomField -from ._util import get_modules as get_mqtt_modules \ No newline at end of file diff --git a/src/home/mqtt/_config.py b/src/home/mqtt/_config.py deleted file mode 100644 index 9ba9443..0000000 --- a/src/home/mqtt/_config.py +++ /dev/null @@ -1,165 +0,0 @@ -from ..config import ConfigUnit -from typing import Optional, Union -from ..util import Addr -from collections import namedtuple - -MqttCreds = namedtuple('MqttCreds', 'username, password') - - -class MqttConfig(ConfigUnit): - NAME = 'mqtt' - - @classmethod - def schema(cls) -> Optional[dict]: - addr_schema = { - 'type': 'dict', - 'required': True, - 'schema': { - 'host': {'type': 'string', 'required': True}, - 'port': {'type': 'integer', 'required': True} - } - } - - schema = {} - for key in ('local', 'remote'): - schema[f'{key}_addr'] = addr_schema - - schema['creds'] = { - 'type': 'dict', - 'required': True, - 'keysrules': {'type': 'string'}, - 'valuesrules': { - 'type': 'dict', - 'schema': { - 'username': {'type': 'string', 'required': True}, - 'password': {'type': 'string', 'required': True}, - } - } - } - - for key in ('client', 'server'): - schema[f'default_{key}_creds'] = {'type': 'string', 'required': True} - - return schema - - def remote_addr(self) -> Addr: - return Addr(host=self['remote_addr']['host'], - port=self['remote_addr']['port']) - - def local_addr(self) -> Addr: - return Addr(host=self['local_addr']['host'], - port=self['local_addr']['port']) - - def creds_by_name(self, name: str) -> MqttCreds: - return MqttCreds(username=self['creds'][name]['username'], - password=self['creds'][name]['password']) - - def creds(self) -> MqttCreds: - return self.creds_by_name(self['default_client_creds']) - - def server_creds(self) -> MqttCreds: - return self.creds_by_name(self['default_server_creds']) - - -class MqttNodesConfig(ConfigUnit): - NAME = 'mqtt_nodes' - - @classmethod - def schema(cls) -> Optional[dict]: - return { - 'common': { - 'type': 'dict', - 'schema': { - 'temphum': { - 'type': 'dict', - 'schema': { - 'interval': {'type': 'integer'} - } - }, - 'password': {'type': 'string'} - } - }, - 'nodes': { - 'type': 'dict', - 'required': True, - 'keysrules': {'type': 'string'}, - 'valuesrules': { - 'type': 'dict', - 'schema': { - 'type': {'type': 'string', 'required': True, 'allowed': ['esp8266', 'linux', 'none'],}, - 'board': {'type': 'string', 'allowed': ['nodemcu', 'd1_mini_lite', 'esp12e']}, - 'temphum': { - 'type': 'dict', - 'schema': { - 'module': {'type': 'string', 'required': True, 'allowed': ['si7021', 'dht12']}, - 'interval': {'type': 'integer'}, - 'i2c_bus': {'type': 'integer'}, - 'tcpserver': { - 'type': 'dict', - 'schema': { - 'port': {'type': 'integer', 'required': True} - } - } - } - }, - 'relay': { - 'type': 'dict', - 'schema': { - 'device_type': {'type': 'string', 'allowed': ['lamp', 'pump', 'solenoid'], 'required': True}, - 'legacy_topics': {'type': 'boolean'} - } - }, - 'password': {'type': 'string'} - } - } - } - } - - @staticmethod - def custom_validator(data): - for name, node in data['nodes'].items(): - if 'temphum' in node: - if node['type'] == 'linux': - if 'i2c_bus' not in node['temphum']: - raise KeyError(f'nodes.{name}.temphum: i2c_bus is missing but required for type=linux') - if node['type'] in ('esp8266',) and 'board' not in node: - raise KeyError(f'nodes.{name}: board is missing but required for type={node["type"]}') - - def get_node(self, name: str) -> dict: - node = self['nodes'][name] - if node['type'] == 'none': - return node - - try: - if 'password' not in node: - node['password'] = self['common']['password'] - except KeyError: - pass - - try: - if 'temphum' in node: - for ckey, cval in self['common']['temphum'].items(): - if ckey not in node['temphum']: - node['temphum'][ckey] = cval - except KeyError: - pass - - return node - - def get_nodes(self, - filters: Optional[Union[list[str], tuple[str]]] = None, - only_names=False) -> Union[dict, list[str]]: - if filters: - for f in filters: - if f not in ('temphum', 'relay'): - raise ValueError(f'{self.__class__.__name__}::get_node(): invalid filter {f}') - reslist = [] - resdict = {} - for name in self['nodes'].keys(): - node = self.get_node(name) - if (not filters) or ('temphum' in filters and 'temphum' in node) or ('relay' in filters and 'relay' in node): - if only_names: - reslist.append(name) - else: - resdict[name] = node - return reslist if only_names else resdict diff --git a/src/home/mqtt/_module.py b/src/home/mqtt/_module.py deleted file mode 100644 index 80f27bb..0000000 --- a/src/home/mqtt/_module.py +++ /dev/null @@ -1,70 +0,0 @@ -from __future__ import annotations - -import abc -import logging -import threading - -from time import sleep -from ..util import next_tick_gen - -from typing import TYPE_CHECKING, Optional -if TYPE_CHECKING: - from ._node import MqttNode - from ._payload import MqttPayload - - -class MqttModule(abc.ABC): - _tick_interval: int - _initialized: bool - _connected: bool - _ticker: Optional[threading.Thread] - _mqtt_node_ref: Optional[MqttNode] - - def __init__(self, tick_interval=0): - self._tick_interval = tick_interval - self._initialized = False - self._ticker = None - self._logger = logging.getLogger(self.__class__.__name__) - self._connected = False - self._mqtt_node_ref = None - - def on_connect(self, mqtt: MqttNode): - self._connected = True - self._mqtt_node_ref = mqtt - if self._tick_interval: - self._start_ticker() - - def on_disconnect(self, mqtt: MqttNode): - self._connected = False - self._mqtt_node_ref = None - - def is_initialized(self): - return self._initialized - - def set_initialized(self): - self._initialized = True - - def unset_initialized(self): - self._initialized = False - - def tick(self): - pass - - def _tick(self): - g = next_tick_gen(self._tick_interval) - while self._connected: - sleep(next(g)) - if not self._connected: - break - self.tick() - - def _start_ticker(self): - if not self._ticker or not self._ticker.is_alive(): - name_part = f'{self._mqtt_node_ref.id}/' if self._mqtt_node_ref else '' - self._ticker = None - self._ticker = threading.Thread(target=self._tick, - name=f'mqtt:{self.__class__.__name__}/{name_part}ticker') - self._ticker.start() - - def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: - pass diff --git a/src/home/mqtt/_mqtt.py b/src/home/mqtt/_mqtt.py deleted file mode 100644 index 746ae2e..0000000 --- a/src/home/mqtt/_mqtt.py +++ /dev/null @@ -1,86 +0,0 @@ -import os.path -import paho.mqtt.client as mqtt -import ssl -import logging - -from ._config import MqttCreds, MqttConfig -from typing import Optional - - -class Mqtt: - _connected: bool - _is_server: bool - _mqtt_config: MqttConfig - - def __init__(self, - clean_session=True, - client_id='', - creds: Optional[MqttCreds] = None, - is_server=False): - if not client_id: - raise ValueError('client_id must not be empty') - - self._client = mqtt.Client(client_id=client_id, - protocol=mqtt.MQTTv311, - clean_session=clean_session) - self._client.on_connect = self.on_connect - self._client.on_disconnect = self.on_disconnect - self._client.on_message = self.on_message - self._client.on_log = self.on_log - self._client.on_publish = self.on_publish - self._loop_started = False - self._connected = False - self._is_server = is_server - self._mqtt_config = MqttConfig() - self._logger = logging.getLogger(self.__class__.__name__) - - if not creds: - creds = self._mqtt_config.creds() if not is_server else self._mqtt_config.server_creds() - - self._client.username_pw_set(creds.username, creds.password) - - def _configure_tls(self): - ca_certs = os.path.realpath(os.path.join( - os.path.dirname(os.path.realpath(__file__)), - '..', - '..', - '..', - 'assets', - 'mqtt_ca.crt' - )) - self._client.tls_set(ca_certs=ca_certs, - cert_reqs=ssl.CERT_REQUIRED, - tls_version=ssl.PROTOCOL_TLSv1_2) - - def connect_and_loop(self, loop_forever=True): - self._configure_tls() - addr = self._mqtt_config.local_addr() if self._is_server else self._mqtt_config.remote_addr() - self._client.connect(addr.host, addr.port, 60) - if loop_forever: - self._client.loop_forever() - else: - self._client.loop_start() - self._loop_started = True - - def disconnect(self): - self._client.disconnect() - self._client.loop_write() - self._client.loop_stop() - - def on_connect(self, client: mqtt.Client, userdata, flags, rc): - self._logger.info("Connected with result code " + str(rc)) - self._connected = True - - def on_disconnect(self, client: mqtt.Client, userdata, rc): - self._logger.info("Disconnected with result code " + str(rc)) - self._connected = False - - def on_log(self, client: mqtt.Client, userdata, level, buf): - level = mqtt.LOGGING_LEVEL[level] if level in mqtt.LOGGING_LEVEL else logging.INFO - self._logger.log(level, f'MQTT: {buf}') - - def on_message(self, client: mqtt.Client, userdata, msg): - self._logger.debug(msg.topic + ": " + str(msg.payload)) - - def on_publish(self, client: mqtt.Client, userdata, mid): - self._logger.debug(f'publish done, mid={mid}') diff --git a/src/home/mqtt/_node.py b/src/home/mqtt/_node.py deleted file mode 100644 index 4e259a4..0000000 --- a/src/home/mqtt/_node.py +++ /dev/null @@ -1,92 +0,0 @@ -import logging -import importlib - -from typing import List, TYPE_CHECKING, Optional -from ._payload import MqttPayload -from ._module import MqttModule -if TYPE_CHECKING: - from ._wrapper import MqttWrapper -else: - MqttWrapper = None - - -class MqttNode: - _modules: List[MqttModule] - _module_subscriptions: dict[str, MqttModule] - _node_id: str - _node_secret: str - _payload_callbacks: list[callable] - _wrapper: Optional[MqttWrapper] - - def __init__(self, - node_id: str, - node_secret: Optional[str] = None): - self._modules = [] - self._module_subscriptions = {} - self._node_id = node_id - self._node_secret = node_secret - self._payload_callbacks = [] - self._logger = logging.getLogger(self.__class__.__name__) - self._wrapper = None - - def on_connect(self, wrapper: MqttWrapper): - self._wrapper = wrapper - for module in self._modules: - if not module.is_initialized(): - module.on_connect(self) - module.set_initialized() - - def on_disconnect(self): - self._wrapper = None - for module in self._modules: - module.unset_initialized() - - def on_message(self, topic, payload): - if topic in self._module_subscriptions: - payload = self._module_subscriptions[topic].handle_payload(self, topic, payload) - if isinstance(payload, MqttPayload): - for f in self._payload_callbacks: - f(self, payload) - - def load_module(self, module_name: str, *args, **kwargs) -> MqttModule: - module = importlib.import_module(f'..module.{module_name}', __name__) - if not hasattr(module, 'MODULE_NAME'): - raise RuntimeError(f'MODULE_NAME not found in module {module}') - cl = getattr(module, getattr(module, 'MODULE_NAME')) - instance = cl(*args, **kwargs) - self.add_module(instance) - return instance - - def add_module(self, module: MqttModule): - self._modules.append(module) - if self._wrapper and self._wrapper._connected: - module.on_connect(self) - module.set_initialized() - - def subscribe_module(self, topic: str, module: MqttModule, qos: int = 1): - if not self._wrapper or not self._wrapper._connected: - raise RuntimeError('not connected') - - self._module_subscriptions[topic] = module - self._wrapper.subscribe(self.id, topic, qos) - - def publish(self, - topic: str, - payload: bytes, - qos: int = 1): - self._wrapper.publish(self.id, topic, payload, qos) - - def add_payload_callback(self, callback: callable): - self._payload_callbacks.append(callback) - - @property - def id(self) -> str: - return self._node_id - - @property - def secret(self) -> str: - return self._node_secret - - @secret.setter - def secret(self, secret: str) -> None: - self._node_secret = secret diff --git a/src/home/mqtt/_payload.py b/src/home/mqtt/_payload.py deleted file mode 100644 index 58eeae3..0000000 --- a/src/home/mqtt/_payload.py +++ /dev/null @@ -1,145 +0,0 @@ -import struct -import abc -import re - -from typing import Optional, Tuple - - -def pldstr(self) -> str: - attrs = [] - for field in self.__class__.__annotations__: - if hasattr(self, field): - attr = getattr(self, field) - attrs.append(f'{field}={attr}') - if attrs: - attrs_s = ' ' - attrs_s += ', '.join(attrs) - else: - attrs_s = '' - return f'<%s{attrs_s}>' % (self.__class__.__name__,) - - -class MqttPayload(abc.ABC): - FORMAT = '' - PACKER = {} - UNPACKER = {} - - def __init__(self, **kwargs): - for field in self.__class__.__annotations__: - setattr(self, field, kwargs[field]) - - def pack(self): - args = [] - bf_number = -1 - bf_arg = 0 - bf_progress = 0 - - for field, field_type in self.__class__.__annotations__.items(): - bfp = _bit_field_params(field_type) - if bfp: - n, s, b = bfp - if n != bf_number: - if bf_number != -1: - args.append(bf_arg) - bf_number = n - bf_progress = 0 - bf_arg = 0 - bf_arg |= (getattr(self, field) & (2 ** b - 1)) << bf_progress - bf_progress += b - - else: - if bf_number != -1: - args.append(bf_arg) - bf_number = -1 - bf_progress = 0 - bf_arg = 0 - - args.append(self._pack_field(field)) - - if bf_number != -1: - args.append(bf_arg) - - return struct.pack(self.FORMAT, *args) - - @classmethod - def unpack(cls, buf: bytes): - data = struct.unpack(cls.FORMAT, buf) - kwargs = {} - i = 0 - bf_number = -1 - bf_progress = 0 - - for field, field_type in cls.__annotations__.items(): - bfp = _bit_field_params(field_type) - if bfp: - n, s, b = bfp - if n != bf_number: - bf_number = n - bf_progress = 0 - kwargs[field] = (data[i] >> bf_progress) & (2 ** b - 1) - bf_progress += b - continue # don't increment i - - if bf_number != -1: - bf_number = -1 - i += 1 - - if issubclass(field_type, MqttPayloadCustomField): - kwargs[field] = field_type.unpack(data[i]) - else: - kwargs[field] = cls._unpack_field(field, data[i]) - i += 1 - - return cls(**kwargs) - - def _pack_field(self, name): - val = getattr(self, name) - if self.PACKER and name in self.PACKER: - return self.PACKER[name](val) - else: - return val - - @classmethod - def _unpack_field(cls, name, val): - if isinstance(val, MqttPayloadCustomField): - return - if cls.UNPACKER and name in cls.UNPACKER: - return cls.UNPACKER[name](val) - else: - return val - - def __str__(self): - return pldstr(self) - - -class MqttPayloadCustomField(abc.ABC): - def __init__(self, **kwargs): - for field in self.__class__.__annotations__: - setattr(self, field, kwargs[field]) - - @abc.abstractmethod - def __index__(self): - pass - - @classmethod - @abc.abstractmethod - def unpack(cls, *args, **kwargs): - pass - - def __str__(self): - return pldstr(self) - - -def bit_field(seq_no: int, total_bits: int, bits: int): - return type(f'MQTTPayloadBitField_{seq_no}_{total_bits}_{bits}', (object,), { - 'seq_no': seq_no, - 'total_bits': total_bits, - 'bits': bits - }) - - -def _bit_field_params(cl) -> Optional[Tuple[int, ...]]: - match = re.match(r'MQTTPayloadBitField_(\d+)_(\d+)_(\d)$', cl.__name__) - if match is not None: - return tuple([int(match.group(i)) for i in range(1, 4)]) - return None \ No newline at end of file diff --git a/src/home/mqtt/_util.py b/src/home/mqtt/_util.py deleted file mode 100644 index 390d463..0000000 --- a/src/home/mqtt/_util.py +++ /dev/null @@ -1,15 +0,0 @@ -import os -import re - -from typing import List - - -def get_modules() -> List[str]: - modules = [] - modules_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'module') - for name in os.listdir(modules_dir): - if os.path.isdir(os.path.join(modules_dir, name)): - continue - name = re.sub(r'\.py$', '', name) - modules.append(name) - return modules diff --git a/src/home/mqtt/_wrapper.py b/src/home/mqtt/_wrapper.py deleted file mode 100644 index 3c2774c..0000000 --- a/src/home/mqtt/_wrapper.py +++ /dev/null @@ -1,60 +0,0 @@ -import paho.mqtt.client as mqtt - -from ._mqtt import Mqtt -from ._node import MqttNode -from ..util import strgen - - -class MqttWrapper(Mqtt): - _nodes: list[MqttNode] - - def __init__(self, - client_id: str, - topic_prefix='hk', - randomize_client_id=False, - clean_session=True): - if randomize_client_id: - client_id += '_'+strgen(6) - super().__init__(clean_session=clean_session, - client_id=client_id) - self._nodes = [] - self._topic_prefix = topic_prefix - - def on_connect(self, client: mqtt.Client, userdata, flags, rc): - super().on_connect(client, userdata, flags, rc) - for node in self._nodes: - node.on_connect(self) - - def on_disconnect(self, client: mqtt.Client, userdata, rc): - super().on_disconnect(client, userdata, rc) - for node in self._nodes: - node.on_disconnect() - - def on_message(self, client: mqtt.Client, userdata, msg): - try: - topic = msg.topic - topic_node = topic[len(self._topic_prefix)+1:topic.find('/', len(self._topic_prefix)+1)] - for node in self._nodes: - if node.id in ('+', topic_node): - node.on_message(topic[len(f'{self._topic_prefix}/{node.id}/'):], msg.payload) - except Exception as e: - self._logger.exception(str(e)) - - def add_node(self, node: MqttNode): - self._nodes.append(node) - if self._connected: - node.on_connect(self) - - def subscribe(self, - node_id: str, - topic: str, - qos: int): - self._client.subscribe(f'{self._topic_prefix}/{node_id}/{topic}', qos) - - def publish(self, - node_id: str, - topic: str, - payload: bytes, - qos: int): - self._client.publish(f'{self._topic_prefix}/{node_id}/{topic}', payload, qos) - self._client.loop_write() diff --git a/src/home/mqtt/module/diagnostics.py b/src/home/mqtt/module/diagnostics.py deleted file mode 100644 index 5db5e99..0000000 --- a/src/home/mqtt/module/diagnostics.py +++ /dev/null @@ -1,64 +0,0 @@ -from .._payload import MqttPayload, MqttPayloadCustomField -from .._node import MqttNode, MqttModule -from typing import Optional - -MODULE_NAME = 'MqttDiagnosticsModule' - - -class DiagnosticsFlags(MqttPayloadCustomField): - state: bool - config_changed_value_present: bool - config_changed: bool - - @staticmethod - def unpack(flags: int): - # _logger.debug(f'StatFlags.unpack: flags={flags}') - state = flags & 0x1 - ccvp = (flags >> 1) & 0x1 - cc = (flags >> 2) & 0x1 - # _logger.debug(f'StatFlags.unpack: state={state}') - return DiagnosticsFlags(state=(state == 1), - config_changed_value_present=(ccvp == 1), - config_changed=(cc == 1)) - - def __index__(self): - bits = 0 - bits |= (int(self.state) & 0x1) - bits |= (int(self.config_changed_value_present) & 0x1) << 1 - bits |= (int(self.config_changed) & 0x1) << 2 - return bits - - -class InitialDiagnosticsPayload(MqttPayload): - FORMAT = '=IBbIB' - - ip: int - fw_version: int - rssi: int - free_heap: int - flags: DiagnosticsFlags - - -class DiagnosticsPayload(MqttPayload): - FORMAT = '=bIB' - - rssi: int - free_heap: int - flags: DiagnosticsFlags - - -class MqttDiagnosticsModule(MqttModule): - def on_connect(self, mqtt: MqttNode): - super().on_connect(mqtt) - for topic in ('diag', 'd1ag', 'stat', 'stat1'): - mqtt.subscribe_module(topic, self) - - def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: - message = None - if topic in ('stat', 'diag'): - message = DiagnosticsPayload.unpack(payload) - elif topic in ('stat1', 'd1ag'): - message = InitialDiagnosticsPayload.unpack(payload) - if message: - self._logger.debug(message) - return message diff --git a/src/home/mqtt/module/inverter.py b/src/home/mqtt/module/inverter.py deleted file mode 100644 index d927a06..0000000 --- a/src/home/mqtt/module/inverter.py +++ /dev/null @@ -1,195 +0,0 @@ -import time -import json -import datetime -try: - import inverterd -except: - pass - -from typing import Optional -from .._module import MqttModule -from .._node import MqttNode -from .._payload import MqttPayload, bit_field -try: - from home.database import InverterDatabase -except: - pass - -_mult_10 = lambda n: int(n*10) -_div_10 = lambda n: n/10 - - -MODULE_NAME = 'MqttInverterModule' - -STATUS_TOPIC = 'status' -GENERATION_TOPIC = 'generation' - - -class MqttInverterStatusPayload(MqttPayload): - # 46 bytes - FORMAT = 'IHHHHHHBHHHHHBHHHHHHHH' - - PACKER = { - 'grid_voltage': _mult_10, - 'grid_freq': _mult_10, - 'ac_output_voltage': _mult_10, - 'ac_output_freq': _mult_10, - 'battery_voltage': _mult_10, - 'battery_voltage_scc': _mult_10, - 'battery_voltage_scc2': _mult_10, - 'pv1_input_voltage': _mult_10, - 'pv2_input_voltage': _mult_10 - } - UNPACKER = { - 'grid_voltage': _div_10, - 'grid_freq': _div_10, - 'ac_output_voltage': _div_10, - 'ac_output_freq': _div_10, - 'battery_voltage': _div_10, - 'battery_voltage_scc': _div_10, - 'battery_voltage_scc2': _div_10, - 'pv1_input_voltage': _div_10, - 'pv2_input_voltage': _div_10 - } - - time: int - grid_voltage: float - grid_freq: float - ac_output_voltage: float - ac_output_freq: float - ac_output_apparent_power: int - ac_output_active_power: int - output_load_percent: int - battery_voltage: float - battery_voltage_scc: float - battery_voltage_scc2: float - battery_discharge_current: int - battery_charge_current: int - battery_capacity: int - inverter_heat_sink_temp: int - mppt1_charger_temp: int - mppt2_charger_temp: int - pv1_input_power: int - pv2_input_power: int - pv1_input_voltage: float - pv2_input_voltage: float - - # H - mppt1_charger_status: bit_field(0, 16, 2) - mppt2_charger_status: bit_field(0, 16, 2) - battery_power_direction: bit_field(0, 16, 2) - dc_ac_power_direction: bit_field(0, 16, 2) - line_power_direction: bit_field(0, 16, 2) - load_connected: bit_field(0, 16, 1) - - -class MqttInverterGenerationPayload(MqttPayload): - # 8 bytes - FORMAT = 'II' - - time: int - wh: int - - -class MqttInverterModule(MqttModule): - _status_poll_freq: int - _generation_poll_freq: int - _inverter: Optional[inverterd.Client] - _database: Optional[InverterDatabase] - _gen_prev: float - - def __init__(self, status_poll_freq=0, generation_poll_freq=0): - super().__init__(tick_interval=status_poll_freq) - self._status_poll_freq = status_poll_freq - self._generation_poll_freq = generation_poll_freq - - # this defines whether this is a publisher or a subscriber - if status_poll_freq > 0: - self._inverter = inverterd.Client() - self._inverter.connect() - self._inverter.format(inverterd.Format.SIMPLE_JSON) - self._database = None - else: - self._inverter = None - self._database = InverterDatabase() - - self._gen_prev = 0 - - def on_connect(self, mqtt: MqttNode): - super().on_connect(mqtt) - if not self._inverter: - mqtt.subscribe_module(STATUS_TOPIC, self) - mqtt.subscribe_module(GENERATION_TOPIC, self) - - def tick(self): - if not self._inverter: - return - - # read status - now = time.time() - try: - raw = self._inverter.exec('get-status') - except inverterd.InverterError as e: - self._logger.error(f'inverter error: {str(e)}') - # TODO send to server - return - - data = json.loads(raw)['data'] - status = MqttInverterStatusPayload(time=round(now), **data) - self._mqtt_node_ref.publish(STATUS_TOPIC, status.pack()) - - # read today's generation stat - now = time.time() - if self._gen_prev == 0 or now - self._gen_prev >= self._generation_poll_freq: - self._gen_prev = now - today = datetime.date.today() - try: - raw = self._inverter.exec('get-day-generated', (today.year, today.month, today.day)) - except inverterd.InverterError as e: - self._logger.error(f'inverter error: {str(e)}') - # TODO send to server - return - - data = json.loads(raw)['data'] - gen = MqttInverterGenerationPayload(time=round(now), wh=data['wh']) - self._mqtt_node_ref.publish(GENERATION_TOPIC, gen.pack()) - - def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: - home_id = 1 # legacy compat - - if topic == STATUS_TOPIC: - s = MqttInverterStatusPayload.unpack(payload) - self._database.add_status(home_id=home_id, - client_time=s.time, - grid_voltage=int(s.grid_voltage*10), - grid_freq=int(s.grid_freq * 10), - ac_output_voltage=int(s.ac_output_voltage * 10), - ac_output_freq=int(s.ac_output_freq * 10), - ac_output_apparent_power=s.ac_output_apparent_power, - ac_output_active_power=s.ac_output_active_power, - output_load_percent=s.output_load_percent, - battery_voltage=int(s.battery_voltage * 10), - battery_voltage_scc=int(s.battery_voltage_scc * 10), - battery_voltage_scc2=int(s.battery_voltage_scc2 * 10), - battery_discharge_current=s.battery_discharge_current, - battery_charge_current=s.battery_charge_current, - battery_capacity=s.battery_capacity, - inverter_heat_sink_temp=s.inverter_heat_sink_temp, - mppt1_charger_temp=s.mppt1_charger_temp, - mppt2_charger_temp=s.mppt2_charger_temp, - pv1_input_power=s.pv1_input_power, - pv2_input_power=s.pv2_input_power, - pv1_input_voltage=int(s.pv1_input_voltage * 10), - pv2_input_voltage=int(s.pv2_input_voltage * 10), - mppt1_charger_status=s.mppt1_charger_status, - mppt2_charger_status=s.mppt2_charger_status, - battery_power_direction=s.battery_power_direction, - dc_ac_power_direction=s.dc_ac_power_direction, - line_power_direction=s.line_power_direction, - load_connected=s.load_connected) - return s - - elif topic == GENERATION_TOPIC: - gen = MqttInverterGenerationPayload.unpack(payload) - self._database.add_generation(home_id, gen.time, gen.wh) - return gen diff --git a/src/home/mqtt/module/ota.py b/src/home/mqtt/module/ota.py deleted file mode 100644 index cd34332..0000000 --- a/src/home/mqtt/module/ota.py +++ /dev/null @@ -1,77 +0,0 @@ -import hashlib - -from typing import Optional -from .._payload import MqttPayload -from .._node import MqttModule, MqttNode - -MODULE_NAME = 'MqttOtaModule' - - -class OtaResultPayload(MqttPayload): - FORMAT = '=BB' - result: int - error_code: int - - -class OtaPayload(MqttPayload): - secret: str - filename: str - - # structure of returned data: - # - # uint8_t[len(secret)] secret; - # uint8_t[16] md5; - # *uint8_t data - - def pack(self): - buf = bytearray(self.secret.encode()) - m = hashlib.md5() - with open(self.filename, 'rb') as fd: - content = fd.read() - m.update(content) - buf.extend(m.digest()) - buf.extend(content) - return buf - - def unpack(cls, buf: bytes): - raise RuntimeError(f'{cls.__class__.__name__}.unpack: not implemented') - # secret = buf[:12].decode() - # filename = buf[12:].decode() - # return OTAPayload(secret=secret, filename=filename) - - -class MqttOtaModule(MqttModule): - _ota_request: Optional[tuple[str, int]] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._ota_request = None - - def on_connect(self, mqtt: MqttNode): - super().on_connect(mqtt) - mqtt.subscribe_module("otares", self) - - if self._ota_request is not None: - filename, qos = self._ota_request - self._ota_request = None - self.do_push_ota(self._mqtt_node_ref.secret, filename, qos) - - def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: - if topic == 'otares': - message = OtaResultPayload.unpack(payload) - self._logger.debug(message) - return message - - def do_push_ota(self, secret: str, filename: str, qos: int): - payload = OtaPayload(secret=secret, filename=filename) - self._mqtt_node_ref.publish('ota', - payload=payload.pack(), - qos=qos) - - def push_ota(self, - filename: str, - qos: int): - if not self._initialized: - self._ota_request = (filename, qos) - else: - self.do_push_ota(filename, qos) diff --git a/src/home/mqtt/module/relay.py b/src/home/mqtt/module/relay.py deleted file mode 100644 index e968031..0000000 --- a/src/home/mqtt/module/relay.py +++ /dev/null @@ -1,92 +0,0 @@ -import datetime - -from typing import Optional -from .. import MqttModule, MqttPayload, MqttNode - -MODULE_NAME = 'MqttRelayModule' - - -class MqttPowerSwitchPayload(MqttPayload): - FORMAT = '=12sB' - PACKER = { - 'state': lambda n: int(n), - 'secret': lambda s: s.encode('utf-8') - } - UNPACKER = { - 'state': lambda n: bool(n), - 'secret': lambda s: s.decode('utf-8') - } - - secret: str - state: bool - - -class MqttPowerStatusPayload(MqttPayload): - FORMAT = '=B' - PACKER = { - 'opened': lambda n: int(n), - } - UNPACKER = { - 'opened': lambda n: bool(n), - } - - opened: bool - - -class MqttRelayState: - enabled: bool - update_time: datetime.datetime - rssi: int - fw_version: int - ever_updated: bool - - def __init__(self): - self.ever_updated = False - self.enabled = False - self.rssi = 0 - - def update(self, - enabled: bool, - rssi: int, - fw_version=None): - self.ever_updated = True - self.enabled = enabled - self.rssi = rssi - self.update_time = datetime.datetime.now() - if fw_version: - self.fw_version = fw_version - - -class MqttRelayModule(MqttModule): - _legacy_topics: bool - - def __init__(self, legacy_topics=False, *args, **kwargs): - super().__init__(*args, **kwargs) - self._legacy_topics = legacy_topics - - def on_connect(self, mqtt: MqttNode): - super().on_connect(mqtt) - mqtt.subscribe_module(self._get_switch_topic(), self) - mqtt.subscribe_module('relay/status', self) - - def switchpower(self, - enable: bool): - payload = MqttPowerSwitchPayload(secret=self._mqtt_node_ref.secret, - state=enable) - self._mqtt_node_ref.publish(self._get_switch_topic(), - payload=payload.pack()) - - def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: - message = None - - if topic == self._get_switch_topic(): - message = MqttPowerSwitchPayload.unpack(payload) - elif topic == 'relay/status': - message = MqttPowerStatusPayload.unpack(payload) - - if message is not None: - self._logger.debug(message) - return message - - def _get_switch_topic(self) -> str: - return 'relay/power' if self._legacy_topics else 'relay/switch' diff --git a/src/home/mqtt/module/temphum.py b/src/home/mqtt/module/temphum.py deleted file mode 100644 index fd02cca..0000000 --- a/src/home/mqtt/module/temphum.py +++ /dev/null @@ -1,82 +0,0 @@ -from .._node import MqttNode -from .._module import MqttModule -from .._payload import MqttPayload -from typing import Optional -from ...temphum import BaseSensor - -two_digits_precision = lambda x: round(x, 2) - -MODULE_NAME = 'MqttTempHumModule' -DATA_TOPIC = 'temphum/data' - - -class MqttTemphumDataPayload(MqttPayload): - FORMAT = '=ddb' - UNPACKER = { - 'temp': two_digits_precision, - 'rh': two_digits_precision - } - - temp: float - rh: float - error: int - - -# class MqttTempHumNodes(HashableEnum): -# KBN_SH_HALL = auto() -# KBN_SH_BATHROOM = auto() -# KBN_SH_LIVINGROOM = auto() -# KBN_SH_BEDROOM = auto() -# -# KBN_BH_2FL = auto() -# KBN_BH_2FL_STREET = auto() -# KBN_BH_1FL_LIVINGROOM = auto() -# KBN_BH_1FL_BEDROOM = auto() -# KBN_BH_1FL_BATHROOM = auto() -# -# KBN_NH_1FL_INV = auto() -# KBN_NH_1FL_CENTER = auto() -# KBN_NH_1LF_KT = auto() -# KBN_NH_1FL_DS = auto() -# KBN_NH_1FS_EZ = auto() -# -# SPB_FLAT120_CABINET = auto() - - -class MqttTempHumModule(MqttModule): - def __init__(self, - sensor: Optional[BaseSensor] = None, - write_to_database=False, - *args, **kwargs): - if sensor is not None: - kwargs['tick_interval'] = 10 - super().__init__(*args, **kwargs) - self._sensor = sensor - - def on_connect(self, mqtt: MqttNode): - super().on_connect(mqtt) - mqtt.subscribe_module(DATA_TOPIC, self) - - def tick(self): - if not self._sensor: - return - - error = 0 - temp = 0 - rh = 0 - try: - temp = self._sensor.temperature() - rh = self._sensor.humidity() - except: - error = 1 - pld = MqttTemphumDataPayload(temp=temp, rh=rh, error=error) - self._mqtt_node_ref.publish(DATA_TOPIC, pld.pack()) - - def handle_payload(self, - mqtt: MqttNode, - topic: str, - payload: bytes) -> Optional[MqttPayload]: - if topic == DATA_TOPIC: - message = MqttTemphumDataPayload.unpack(payload) - self._logger.debug(message) - return message -- cgit v1.2.3