diff options
Diffstat (limited to 'src/home/mqtt/_node.py')
-rw-r--r-- | src/home/mqtt/_node.py | 95 |
1 files changed, 95 insertions, 0 deletions
diff --git a/src/home/mqtt/_node.py b/src/home/mqtt/_node.py new file mode 100644 index 0000000..688b30b --- /dev/null +++ b/src/home/mqtt/_node.py @@ -0,0 +1,95 @@ +import paho.mqtt.client as mqtt + +from .mqtt import MqttBase +from typing import List +from ._module import MqttModule +from ._payload import MqttPayload + + +class MqttNode(MqttBase): + _modules: List[MqttModule] + _module_subscriptions: dict[str, MqttModule] + _node_id: str + _payload_callbacks: list[callable] + # _devices: list[MqttEspDevice] + # _message_callback: Optional[callable] + # _ota_publish_callback: Optional[callable] + + def __init__(self, + node_id: str, + # devices: Union[MqttEspDevice, list[MqttEspDevice]] + ): + super().__init__(clean_session=True) + self._modules = [] + self._module_subscriptions = {} + self._node_id = node_id + self._payload_callbacks = [] + # if not isinstance(devices, list): + # devices = [devices] + # self._devices = devices + # self._message_callback = None + # self._ota_publish_callback = None + # self._ota_mid = None + + def on_connect(self, client: mqtt.Client, userdata, flags, rc): + super().on_connect(client, userdata, flags, rc) + for module in self._modules: + if not module.is_initialized(): + module.init(self) + module.set_initialized() + + def on_publish(self, client: mqtt.Client, userdata, mid): + pass # FIXME + # if self._ota_mid is not None and mid == self._ota_mid and self._ota_publish_callback: + # self._ota_publish_callback() + + def on_message(self, client: mqtt.Client, userdata, msg): + try: + topic = msg.topic + actual_topic = topic[len(f'hk/{self._node_id}/'):] + + if actual_topic in self._module_subscriptions: + payload = self._module_subscriptions[actual_topic].handle_payload(self, actual_topic, msg.payload) + if isinstance(payload, MqttPayload): + for f in self._payload_callbacks: + f(payload) + + except Exception as e: + self._logger.exception(str(e)) + + # def push_ota(self, + # device_id, + # filename: str, + # publish_callback: callable, + # qos: int): + # device = next(d for d in self._devices if d.id == device_id) + # assert device.secret is not None, 'device secret not specified' + # + # self._ota_publish_callback = publish_callback + # payload = OtaPayload(secret=device.secret, filename=filename) + # publish_result = self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/admin/ota', + # payload=payload.pack(), + # qos=qos) + # self._ota_mid = publish_result.mid + # self._client.loop_write() + # + # @classmethod + # def get_mqtt_topics(cls, additional_topics: Optional[list[str]] = None): + # return rf'^hk/(.*?)/{cls.TOPIC_LEAF}/(stat|stat1|otares'+('|'+('|'.join(additional_topics)) if additional_topics else '')+')$' + + def add_module(self, module: MqttModule): + self._modules.append(module) + if self._connected: + module.init(self) + module.set_initialized() + + def subscribe_module(self, topic: str, module: MqttModule, qos: int = 1): + self._module_subscriptions[topic] = module + self._client.subscribe(f'hk/{self._node_id}/{topic}', qos) + + def publish(self, topic: str, payload: bytes, qos: int = 1): + self._client.publish(f'hk/{self._node_id}/{topic}', payload, qos) + self._client.loop_write() + + def add_payload_callback(self, callback: callable): + self._payload_callbacks.append(callback)
\ No newline at end of file |