diff options
Diffstat (limited to 'src/home/mqtt')
-rw-r--r-- | src/home/mqtt/__init__.py | 10 | ||||
-rw-r--r-- | src/home/mqtt/_module.py | 43 | ||||
-rw-r--r-- | src/home/mqtt/_node.py | 111 | ||||
-rw-r--r-- | src/home/mqtt/_wrapper.py | 55 | ||||
-rw-r--r-- | src/home/mqtt/module/diagnostics.py | 7 | ||||
-rw-r--r-- | src/home/mqtt/module/ota.py | 15 | ||||
-rw-r--r-- | src/home/mqtt/module/relay.py | 8 | ||||
-rw-r--r-- | src/home/mqtt/module/temphum.py | 13 | ||||
-rw-r--r-- | src/home/mqtt/mqtt.py | 20 | ||||
-rw-r--r-- | src/home/mqtt/relay.py | 59 | ||||
-rw-r--r-- | src/home/mqtt/util.py | 31 |
11 files changed, 176 insertions, 196 deletions
diff --git a/src/home/mqtt/__init__.py b/src/home/mqtt/__init__.py index c95061f..8633437 100644 --- a/src/home/mqtt/__init__.py +++ b/src/home/mqtt/__init__.py @@ -1,9 +1,5 @@ -from .mqtt import MqttBase, MqttPayload, MqttPayloadCustomField +from .mqtt import Mqtt, MqttPayload, MqttPayloadCustomField from ._node import MqttNode from ._module import MqttModule -from .util import ( - poll_tick, - get_modules as get_mqtt_modules, - import_module as import_mqtt_module, - add_module as add_mqtt_module -)
\ No newline at end of file +from ._wrapper import MqttWrapper +from .util import get_modules as get_mqtt_modules
\ No newline at end of file diff --git a/src/home/mqtt/_module.py b/src/home/mqtt/_module.py index 840534e..80f27bb 100644 --- a/src/home/mqtt/_module.py +++ b/src/home/mqtt/_module.py @@ -2,6 +2,10 @@ from __future__ import annotations import abc import logging +import threading + +from time import sleep +from ..util import next_tick_gen from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: @@ -10,16 +14,29 @@ if TYPE_CHECKING: class MqttModule(abc.ABC): - tick_interval: int + _tick_interval: int _initialized: bool + _connected: bool + _ticker: Optional[threading.Thread] + _mqtt_node_ref: Optional[MqttNode] def __init__(self, tick_interval=0): - self.tick_interval = tick_interval + self._tick_interval = tick_interval self._initialized = False + self._ticker = None self._logger = logging.getLogger(self.__class__.__name__) + self._connected = False + self._mqtt_node_ref = None - def init(self, mqtt: MqttNode): - pass + def on_connect(self, mqtt: MqttNode): + self._connected = True + self._mqtt_node_ref = mqtt + if self._tick_interval: + self._start_ticker() + + def on_disconnect(self, mqtt: MqttNode): + self._connected = False + self._mqtt_node_ref = None def is_initialized(self): return self._initialized @@ -30,8 +47,24 @@ class MqttModule(abc.ABC): def unset_initialized(self): self._initialized = False - def tick(self, mqtt: MqttNode): + def tick(self): pass + def _tick(self): + g = next_tick_gen(self._tick_interval) + while self._connected: + sleep(next(g)) + if not self._connected: + break + self.tick() + + def _start_ticker(self): + if not self._ticker or not self._ticker.is_alive(): + name_part = f'{self._mqtt_node_ref.id}/' if self._mqtt_node_ref else '' + self._ticker = None + self._ticker = threading.Thread(target=self._tick, + name=f'mqtt:{self.__class__.__name__}/{name_part}ticker') + self._ticker.start() + def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: pass diff --git a/src/home/mqtt/_node.py b/src/home/mqtt/_node.py index f34da0c..ddf5ba2 100644 --- a/src/home/mqtt/_node.py +++ b/src/home/mqtt/_node.py @@ -1,103 +1,80 @@ -import paho.mqtt.client as mqtt +import logging +import importlib -from .mqtt import MqttBase -from typing import List, Optional -from ._module import MqttModule +from typing import List, TYPE_CHECKING, Optional from ._payload import MqttPayload +from ._module import MqttModule +if TYPE_CHECKING: + from ._wrapper import MqttWrapper +else: + MqttWrapper = None -class MqttNode(MqttBase): +class MqttNode: _modules: List[MqttModule] _module_subscriptions: dict[str, MqttModule] _node_id: str _payload_callbacks: list[callable] - # _devices: list[MqttEspDevice] - # _message_callback: Optional[callable] - # _ota_publish_callback: Optional[callable] + _wrapper: Optional[MqttWrapper] - def __init__(self, - node_id: str, - # devices: Union[MqttEspDevice, list[MqttEspDevice]] - ): - super().__init__(clean_session=True) + def __init__(self, node_id: str): self._modules = [] self._module_subscriptions = {} self._node_id = node_id self._payload_callbacks = [] - # if not isinstance(devices, list): - # devices = [devices] - # self._devices = devices - # self._message_callback = None - # self._ota_publish_callback = None - # self._ota_mid = None + self._logger = logging.getLogger(self.__class__.__name__) + self._wrapper = None - def on_connect(self, client: mqtt.Client, userdata, flags, rc): - super().on_connect(client, userdata, flags, rc) + def on_connect(self, wrapper: MqttWrapper): + self._wrapper = wrapper for module in self._modules: if not module.is_initialized(): - module.init(self) + module.on_connect(self) module.set_initialized() - def on_disconnect(self, client: mqtt.Client, userdata, rc): - super().on_disconnect(client, userdata, rc) + def on_disconnect(self): + self._wrapper = None for module in self._modules: module.unset_initialized() - def on_publish(self, client: mqtt.Client, userdata, mid): - pass # FIXME - # if self._ota_mid is not None and mid == self._ota_mid and self._ota_publish_callback: - # self._ota_publish_callback() - - def on_message(self, client: mqtt.Client, userdata, msg): - try: - topic = msg.topic - actual_topic = topic[len(f'hk/{self._node_id}/'):] - - if actual_topic in self._module_subscriptions: - payload = self._module_subscriptions[actual_topic].handle_payload(self, actual_topic, msg.payload) - if isinstance(payload, MqttPayload): - for f in self._payload_callbacks: - f(payload) + def on_message(self, topic, payload): + if topic in self._module_subscriptions: + payload = self._module_subscriptions[topic].handle_payload(self, topic, payload) + if isinstance(payload, MqttPayload): + for f in self._payload_callbacks: + f(payload) - except Exception as e: - self._logger.exception(str(e)) - - # def push_ota(self, - # device_id, - # filename: str, - # publish_callback: callable, - # qos: int): - # device = next(d for d in self._devices if d.id == device_id) - # assert device.secret is not None, 'device secret not specified' - # - # self._ota_publish_callback = publish_callback - # payload = OtaPayload(secret=device.secret, filename=filename) - # publish_result = self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/admin/ota', - # payload=payload.pack(), - # qos=qos) - # self._ota_mid = publish_result.mid - # self._client.loop_write() - # - # @classmethod - # def get_mqtt_topics(cls, additional_topics: Optional[list[str]] = None): - # return rf'^hk/(.*?)/{cls.TOPIC_LEAF}/(stat|stat1|otares'+('|'+('|'.join(additional_topics)) if additional_topics else '')+')$' + def load_module(self, module_name: str, *args, **kwargs) -> MqttModule: + module = importlib.import_module(f'..module.{module_name}', __name__) + if not hasattr(module, 'MODULE_NAME'): + raise RuntimeError(f'MODULE_NAME not found in module {module}') + cl = getattr(module, getattr(module, 'MODULE_NAME')) + instance = cl(*args, **kwargs) + self.add_module(instance) + return instance def add_module(self, module: MqttModule): self._modules.append(module) - if self._connected: - module.init(self) + if self._wrapper and self._wrapper._connected: + module.on_connect(self) module.set_initialized() def subscribe_module(self, topic: str, module: MqttModule, qos: int = 1): + if not self._wrapper or not self._wrapper._connected: + raise RuntimeError('not connected') + self._module_subscriptions[topic] = module - self._client.subscribe(f'hk/{self._node_id}/{topic}', qos) + self._wrapper.subscribe(self.id, topic, qos) def publish(self, topic: str, payload: bytes, qos: int = 1): - self._client.publish(f'hk/{self._node_id}/{topic}', payload, qos) - self._client.loop_write() + self._wrapper.publish(self.id, topic, payload, qos) def add_payload_callback(self, callback: callable): - self._payload_callbacks.append(callback)
\ No newline at end of file + self._payload_callbacks.append(callback) + + @property + def id(self) -> str: + return self._node_id diff --git a/src/home/mqtt/_wrapper.py b/src/home/mqtt/_wrapper.py new file mode 100644 index 0000000..41f9d89 --- /dev/null +++ b/src/home/mqtt/_wrapper.py @@ -0,0 +1,55 @@ +import paho.mqtt.client as mqtt + +from .mqtt import Mqtt +from ._node import MqttNode +from ..config import config +from ..util import strgen + + +class MqttWrapper(Mqtt): + _nodes: list[MqttNode] + + def __init__(self, topic_prefix='hk', randomize_client_id=False): + client_id = config['mqtt']['client_id'] + if randomize_client_id: + client_id += '_'+strgen(6) + super().__init__(clean_session=True, client_id=client_id) + self._nodes = [] + self._topic_prefix = topic_prefix + + def on_connect(self, client: mqtt.Client, userdata, flags, rc): + super().on_connect(client, userdata, flags, rc) + for node in self._nodes: + node.on_connect(self) + + def on_disconnect(self, client: mqtt.Client, userdata, rc): + super().on_disconnect(client, userdata, rc) + for node in self._nodes: + node.on_disconnect() + + def on_message(self, client: mqtt.Client, userdata, msg): + try: + topic = msg.topic + for node in self._nodes: + node.on_message(topic[len(f'{self._topic_prefix}/{node.id}/'):], msg.payload) + except Exception as e: + self._logger.exception(str(e)) + + def add_node(self, node: MqttNode): + self._nodes.append(node) + if self._connected: + node.on_connect(self) + + def subscribe(self, + node_id: str, + topic: str, + qos: int): + self._client.subscribe(f'{self._topic_prefix}/{node_id}/{topic}', qos) + + def publish(self, + node_id: str, + topic: str, + payload: bytes, + qos: int): + self._client.publish(f'{self._topic_prefix}/{node_id}/{topic}', payload, qos) + self._client.loop_write() diff --git a/src/home/mqtt/module/diagnostics.py b/src/home/mqtt/module/diagnostics.py index c31cce2..fa6cc8e 100644 --- a/src/home/mqtt/module/diagnostics.py +++ b/src/home/mqtt/module/diagnostics.py @@ -48,14 +48,17 @@ class DiagnosticsPayload(MqttPayload): class MqttDiagnosticsModule(MqttModule): - def init(self, mqtt: MqttNode): + def on_connect(self, mqtt: MqttNode): + super().on_connect(mqtt) for topic in ('diag', 'd1ag', 'stat', 'stat1'): mqtt.subscribe_module(topic, self) def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: + message = None if topic in ('stat', 'diag'): message = DiagnosticsPayload.unpack(payload) elif topic in ('stat1', 'd1ag'): message = InitialDiagnosticsPayload.unpack(payload) - self._logger.debug(message) + if message: + self._logger.debug(message) return message diff --git a/src/home/mqtt/module/ota.py b/src/home/mqtt/module/ota.py index 5a1a309..e71cccc 100644 --- a/src/home/mqtt/module/ota.py +++ b/src/home/mqtt/module/ota.py @@ -42,18 +42,15 @@ class OtaPayload(MqttPayload): class MqttOtaModule(MqttModule): _ota_request: Optional[tuple[str, str, int]] - _mqtt_ref: Optional[MqttNode] def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._ota_request = None - self._mqtt_ref = None - def init(self, mqtt: MqttNode): + def on_connect(self, mqtt: MqttNode): + super().on_connect(mqtt) mqtt.subscribe_module("otares", self) - self._mqtt_ref = mqtt - if self._ota_request is not None: secret, filename, qos = self._ota_request self._ota_request = None @@ -67,9 +64,9 @@ class MqttOtaModule(MqttModule): def do_push_ota(self, secret: str, filename: str, qos: int): payload = OtaPayload(secret=secret, filename=filename) - self._mqtt_ref.publish('ota', - payload=payload.pack(), - qos=qos) + self._mqtt_node_ref.publish('ota', + payload=payload.pack(), + qos=qos) def push_ota(self, secret: str, @@ -78,4 +75,4 @@ class MqttOtaModule(MqttModule): if not self._initialized: self._ota_request = (secret, filename, qos) else: - self.do_push_ota(secret, filename, qos)
\ No newline at end of file + self.do_push_ota(secret, filename, qos) diff --git a/src/home/mqtt/module/relay.py b/src/home/mqtt/module/relay.py index bf22bfe..ae88ddb 100644 --- a/src/home/mqtt/module/relay.py +++ b/src/home/mqtt/module/relay.py @@ -58,16 +58,16 @@ class MqttRelayState: class MqttRelayModule(MqttModule): - def init(self, mqtt: MqttNode): + def on_connect(self, mqtt: MqttNode): + super().on_connect(mqtt) mqtt.subscribe_module('relay/switch', self) mqtt.subscribe_module('relay/status', self) - @staticmethod - def switchpower(mqtt: MqttNode, + def switchpower(self, enable: bool, secret: str): payload = MqttPowerSwitchPayload(secret=secret, state=enable) - mqtt.publish('relay/switch', payload=payload.pack()) + self._mqtt_node_ref.publish('relay/switch', payload=payload.pack()) def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: message = None diff --git a/src/home/mqtt/module/temphum.py b/src/home/mqtt/module/temphum.py index 0e43f1b..9cdfedb 100644 --- a/src/home/mqtt/module/temphum.py +++ b/src/home/mqtt/module/temphum.py @@ -4,6 +4,7 @@ from .._module import MqttModule from .._payload import MqttPayload from ...util import HashableEnum from typing import Optional +from ...temphum import BaseSensor two_digits_precision = lambda x: round(x, 2) @@ -44,9 +45,17 @@ class MqttTempHumNodes(HashableEnum): class MqttTempHumModule(MqttModule): - def init(self, mqtt: MqttNode): + def __init__(self, sensor: Optional[BaseSensor] = None, *args, **kwargs): + super().__init__(*args, **kwargs) + self._sensor = sensor + + def on_connect(self, mqtt: MqttNode): + super().on_connect(mqtt) mqtt.subscribe_module('temphum/data', self) + def tick(self): + pass + def handle_payload(self, mqtt: MqttNode, topic: str, @@ -54,4 +63,4 @@ class MqttTempHumModule(MqttModule): if topic == 'temphum/data': message = MqttTemphumDataPayload.unpack(payload) self._logger.debug(message) - return message
\ No newline at end of file + return message diff --git a/src/home/mqtt/mqtt.py b/src/home/mqtt/mqtt.py index fad5d26..ba32889 100644 --- a/src/home/mqtt/mqtt.py +++ b/src/home/mqtt/mqtt.py @@ -5,6 +5,7 @@ import logging from ..config import config from ._payload import * +from typing import Optional def username_and_password() -> Tuple[str, str]: @@ -13,11 +14,13 @@ def username_and_password() -> Tuple[str, str]: return username, password -class MqttBase: +class Mqtt: _connected: bool - def __init__(self, clean_session=True): - self._client = mqtt.Client(client_id=config['mqtt']['client_id'], + def __init__(self, + clean_session=True, + client_id: Optional[str] = None): + self._client = mqtt.Client(client_id=config['mqtt']['client_id'] if not client_id else client_id, protocol=mqtt.MQTTv311, clean_session=clean_session) self._client.on_connect = self.on_connect @@ -81,14 +84,3 @@ class MqttBase: def on_publish(self, client: mqtt.Client, userdata, mid): self._logger.debug(f'publish done, mid={mid}') - - -class MqttEspDevice: - id: str - secret: Optional[str] - - def __init__(self, - node_id: str, - secret: Optional[str] = None): - self.id = node_id - self.secret = secret diff --git a/src/home/mqtt/relay.py b/src/home/mqtt/relay.py deleted file mode 100644 index cf657f7..0000000 --- a/src/home/mqtt/relay.py +++ /dev/null @@ -1,59 +0,0 @@ -#!/usr/bin/env python3 -import paho.mqtt.client as mqtt -import re -import logging - -from .mqtt import MQTTBase - - -class MQTTRelayClient(MQTTBase): - _home_id: str - - def __init__(self, home_id: str): - super().__init__(clean_session=True) - self._home_id = home_id - - def on_connect(self, client: mqtt.Client, userdata, flags, rc): - super().on_connect(client, userdata, flags, rc) - - topic = f'home/{self._home_id}/#' - self._logger.info(f"subscribing to {topic}") - - client.subscribe(topic, qos=1) - - def on_message(self, client: mqtt.Client, userdata, msg): - try: - match = re.match(r'^home/(.*?)/relay/(stat|power)(?:/(.+))?$', msg.topic) - self._logger.info(f'topic: {msg.topic}') - if not match: - return - - name = match.group(1) - subtopic = match.group(2) - - if name != self._home_id: - return - - if subtopic == 'stat': - stat_name, stat_value = match.group(3).split('/') - self._logger.info(f'stat: {stat_name} = {stat_value}') - - except Exception as e: - self._logger.exception(str(e)) - - -class MQTTRelayController(MQTTBase): - _home_id: str - - def __init__(self, home_id: str): - super().__init__(clean_session=True) - self._home_id = home_id - - def set_power(self, enable: bool): - self._client.publish(f'home/{self._home_id}/relay/power', - payload=int(enable), - qos=1) - self._client.loop_write() - - def send_stat(self, stat: dict): - pass diff --git a/src/home/mqtt/util.py b/src/home/mqtt/util.py index 91b6baf..390d463 100644 --- a/src/home/mqtt/util.py +++ b/src/home/mqtt/util.py @@ -1,38 +1,15 @@ -import time import os import re -import importlib -from ._node import MqttNode -from . import MqttModule from typing import List -def poll_tick(freq): - t = time.time() - while True: - t += freq - yield max(t - time.time(), 0) - - def get_modules() -> List[str]: modules = [] - for name in os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'module')): + modules_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'module') + for name in os.listdir(modules_dir): + if os.path.isdir(os.path.join(modules_dir, name)): + continue name = re.sub(r'\.py$', '', name) modules.append(name) return modules - - -def import_module(module: str): - return importlib.import_module( - f'..module.{module}', __name__) - - -def add_module(mqtt_node: MqttNode, module: str) -> MqttModule: - module = import_module(module) - if not hasattr(module, 'MODULE_NAME'): - raise RuntimeError(f'MODULE_NAME not found in module {module}') - cl = getattr(module, getattr(module, 'MODULE_NAME')) - instance = cl() - mqtt_node.add_module(instance) - return instance
\ No newline at end of file |