from typing import Callable

import paho.mqtt.client as mqtt

from ._mqtt import Mqtt
from ._node import MqttNode
from ..util import strgen


class MqttWrapper(Mqtt):
    """Mqtt client whose topics are namespaced as ``<prefix>/<node_id>/<topic>``.

    Registered MqttNode objects are notified on connect/disconnect and receive
    the messages addressed to their node id; a node whose id is ``'+'`` receives
    messages for every node id. Extra zero-argument callbacks can be registered
    for connect/disconnect events.

    ``self._logger``, ``self._connected`` and ``self._client`` are used here but
    not defined here — presumably provided by the ``Mqtt`` base class.
    """

    _nodes: list[MqttNode]
    _connect_callbacks: list[Callable[[], None]]
    _disconnect_callbacks: list[Callable[[], None]]
    _topic_prefix: str

    def __init__(self,
                 client_id: str,
                 topic_prefix='hk',
                 randomize_client_id=False,
                 clean_session=True):
        """
        :param client_id: MQTT client id forwarded to the base class.
        :param topic_prefix: first topic segment used by subscribe()/publish()
            and expected by on_message().
        :param randomize_client_id: if true, append ``'_' + strgen(6)`` to
            client_id (presumably a random suffix so several instances can
            connect at once).
        :param clean_session: forwarded to the base class.
        """
        if randomize_client_id:
            client_id += '_' + strgen(6)
        super().__init__(clean_session=clean_session, client_id=client_id)
        self._nodes = []
        self._connect_callbacks = []
        self._disconnect_callbacks = []
        self._topic_prefix = topic_prefix

    def _call_safely(self, f: Callable, *args) -> None:
        """Call ``f(*args)``; log any exception instead of propagating it.

        Used so that one failing node/callback does not prevent the remaining
        ones from being notified.
        """
        try:
            f(*args)
        except Exception as e:
            self._logger.exception(e)

    def on_connect(self, client: mqtt.Client, userdata, flags, rc):
        """Run the base handler, then notify all nodes, then all connect callbacks."""
        super().on_connect(client, userdata, flags, rc)
        for node in self._nodes:
            self._call_safely(node.on_connect, self)
        for f in self._connect_callbacks:
            self._call_safely(f)

    def on_disconnect(self, client: mqtt.Client, userdata, rc):
        """Run the base handler, then notify all nodes, then all disconnect callbacks."""
        super().on_disconnect(client, userdata, rc)
        for node in self._nodes:
            self._call_safely(node.on_disconnect)
        for f in self._disconnect_callbacks:
            self._call_safely(f)

    def on_message(self, client: mqtt.Client, userdata, msg):
        """Dispatch an incoming message to the nodes it is addressed to.

        A topic ``<prefix>/<node_id>/<rest>`` is delivered as
        ``node.on_message(<rest>, msg.payload)`` to the node whose id equals
        ``<node_id>`` and to every wildcard (``'+'``) node. ``<rest>`` is the
        inverse of the topic construction in subscribe(), so a wildcard node
        sees the same relative topic it subscribed with (and, like any node,
        does not see the node id itself). Topics that do not start with
        ``<prefix>/`` are ignored.
        """
        try:
            prefix = f'{self._topic_prefix}/'
            topic = msg.topic
            if not topic.startswith(prefix):
                return
            # partition() handles a topic with no segment after the node id;
            # a find('/') == -1 slice would have cut the node id short.
            topic_node, _, relative_topic = topic[len(prefix):].partition('/')
        except Exception as e:
            self._logger.exception(e)
            return

        for node in self._nodes:
            if node.id in ('+', topic_node):
                self._call_safely(node.on_message, relative_topic, msg.payload)

    def add_connect_callback(self, f: Callable[[], None]):
        """Register a zero-argument callable invoked after every connect."""
        self._connect_callbacks.append(f)

    def add_disconnect_callback(self, f: Callable[[], None]):
        """Register a zero-argument callable invoked after every disconnect."""
        self._disconnect_callbacks.append(f)

    def add_node(self, node: MqttNode):
        """Register a node; if already connected, call its on_connect immediately."""
        self._nodes.append(node)
        if self._connected:
            node.on_connect(self)

    def subscribe(self, node_id: str, topic: str, qos: int):
        """Subscribe to ``<prefix>/<node_id>/<topic>`` (node_id may be ``'+'``)."""
        self._client.subscribe(f'{self._topic_prefix}/{node_id}/{topic}', qos)

    def publish(self, node_id: str, topic: str, payload: bytes, qos: int):
        """Publish payload to ``<prefix>/<node_id>/<topic>`` and flush the socket.

        NOTE(review): loop_write() is called directly here; if the network
        loop also runs in a background thread (loop_start), confirm this
        concurrent write is safe with the paho version in use.
        """
        self._client.publish(f'{self._topic_prefix}/{node_id}/{topic}', payload, qos)
        self._client.loop_write()