"""MQTT client helpers for ESP devices: status/diagnostics decoding and OTA pushes."""
import re
from typing import Callable, Optional, Union

import paho.mqtt.client as mqtt

from .mqtt import MqttBase
from .payload.esp import (
    OTAPayload,
    OTAResultPayload,
    DiagnosticsPayload,
    InitialDiagnosticsPayload,
)


class MqttEspDevice:
    """An ESP device addressed over MQTT.

    ``id`` is the device's segment in topics of the form ``hk/<id>/esp/...``.
    ``secret`` is only needed for :meth:`MqttEspBase.push_ota`, which refuses
    to send an OTA command without it.
    """
    id: str
    secret: Optional[str]

    def __init__(self, id: str, secret: Optional[str] = None):
        self.id = id
        self.secret = secret


class MqttEspBase(MqttBase):
    """Base MQTT client for one or more ESP devices.

    Subscribes to ``hk/<device_id>/esp/#`` for each configured device (unless
    ``subscribe_to_updates`` is False), decodes ``stat`` / ``stat1`` / ``otares``
    messages into payload objects and forwards them to the message callback,
    and can publish OTA commands to a device.
    """
    _devices: list[MqttEspDevice]
    _message_callback: Optional[Callable]
    _ota_publish_callback: Optional[Callable[[], None]]
    _ota_mid: Optional[int]
    _subscribe_to_updates: bool

    TOPIC_LEAF = 'esp'

    # Subtopic (second capture group of get_mqtt_topics()) -> payload class used to decode it.
    _PAYLOAD_TYPES = {
        'stat': DiagnosticsPayload,
        'stat1': InitialDiagnosticsPayload,
        'otares': OTAResultPayload,
    }

    def __init__(self,
                 devices: Union[MqttEspDevice, list[MqttEspDevice]],
                 subscribe_to_updates: bool = True):
        """Create the client.

        Args:
            devices: a single device or a list of devices handled by this client.
            subscribe_to_updates: subscribe to each device's ``esp`` topics on connect.
        """
        super().__init__(clean_session=True)
        if not isinstance(devices, list):
            devices = [devices]
        self._devices = devices
        self._message_callback = None
        self._ota_publish_callback = None
        self._subscribe_to_updates = subscribe_to_updates
        self._ota_mid = None

    def on_connect(self, client: mqtt.Client, userdata, flags, rc):
        """Run the base connect handler, then subscribe to every device's ``esp`` topics (QoS 1)."""
        super().on_connect(client, userdata, flags, rc)
        if self._subscribe_to_updates:
            for device in self._devices:
                topic = f'hk/{device.id}/{self.TOPIC_LEAF}/#'
                self._logger.debug(f"subscribing to {topic}")
                client.subscribe(topic, qos=1)

    def on_publish(self, client: mqtt.Client, userdata, mid):
        """Invoke the pending OTA publish callback once the OTA message is published."""
        if self._ota_mid is None or mid != self._ota_mid:
            return
        # One-shot: MQTT message ids are 16-bit and get reused, so forget the OTA
        # mid once matched; otherwise a later unrelated publish that happens to
        # reuse the same mid would fire the callback again.
        self._ota_mid = None
        if self._ota_publish_callback:
            self._ota_publish_callback()

    def set_message_callback(self, callback: Callable):
        """Set the handler called as ``callback(device_id, message)`` for decoded messages."""
        self._message_callback = callback

    def on_message(self, client: mqtt.Client, userdata, msg):
        """Decode an ESP status/OTA message and pass it to the message callback.

        Returns:
            True when a decoded payload was delivered to the message callback;
            None otherwise (topic not matched, unknown device, subtopic without a
            decoder, empty decode result, no callback set, or an error occurred).
        """
        try:
            match = re.match(self.get_mqtt_topics(), msg.topic)
            self._logger.debug(f'topic: {msg.topic}')
            if not match:
                return

            device_id, subtopic = match.group(1), match.group(2)

            # Only handle devices this client was configured with. An explicit
            # check avoids a bare StopIteration being logged as an error.
            if not any(d.id == device_id for d in self._devices):
                self._logger.debug(f'ignoring message from unknown device {device_id}')
                return

            # Subtopics added by subclasses via additional_topics have no decoder here.
            payload_type = self._PAYLOAD_TYPES.get(subtopic)
            if payload_type is None:
                return
            message = payload_type.unpack(msg.payload)

            if message and self._message_callback:
                self._message_callback(device_id, message)
                return True
        except Exception as e:
            # Keep a failure on one message (bad payload, callback error) from
            # propagating out of the MQTT callback; log it with its traceback.
            self._logger.exception(str(e))

    def push_ota(self, device_id, filename: str, publish_callback: Callable[[], None], qos: int):
        """Publish an OTA command to ``hk/<device_id>/esp/admin/ota``.

        Args:
            device_id: id of one of the configured devices.
            filename: passed to ``OTAPayload`` together with the device secret.
            publish_callback: called with no arguments once the OTA message is published.
            qos: MQTT QoS level for the publish.

        Raises:
            ValueError: if ``device_id`` is not configured, or the device has no secret.
        """
        device = next((d for d in self._devices if d.id == device_id), None)
        if device is None:
            raise ValueError(f'unknown device: {device_id}')
        # Explicit check rather than assert, which is stripped under `python -O`.
        if device.secret is None:
            raise ValueError('device secret not specified')

        # Disarm any previous OTA mid before swapping the callback, so a late
        # on_publish for an earlier OTA cannot trigger the new callback.
        self._ota_mid = None
        self._ota_publish_callback = publish_callback

        payload = OTAPayload(secret=device.secret, filename=filename)
        publish_result = self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/admin/ota',
                                              payload=payload.pack(),
                                              qos=qos)
        # NOTE(review): the mid is only known after publish() returns; if the
        # network loop runs in another thread and on_publish fires before this
        # assignment, the callback is missed — confirm how MqttBase runs the loop.
        self._ota_mid = publish_result.mid
        # Presumably forces the outgoing packet to be written now rather than
        # waiting for the next network-loop iteration.
        self._client.loop_write()

    @classmethod
    def get_mqtt_topics(cls, additional_topics: Optional[list[str]] = None) -> str:
        """Return the regex matching ``hk/<device_id>/esp/<subtopic>`` topics.

        Group 1 captures the device id, group 2 the subtopic. ``additional_topics``
        are appended to the subtopic alternation as-is (they are regex fragments,
        not escaped).
        """
        subtopics = ['stat', 'stat1', 'otares']
        if additional_topics:
            subtopics.extend(additional_topics)
        return rf'^hk/(.*?)/{cls.TOPIC_LEAF}/(' + '|'.join(subtopics) + ')$'