diff options
Diffstat (limited to 'src/home/mqtt/esp.py')
-rw-r--r-- | src/home/mqtt/esp.py | 106 |
1 files changed, 106 insertions, 0 deletions
diff --git a/src/home/mqtt/esp.py b/src/home/mqtt/esp.py new file mode 100644 index 0000000..56ced83 --- /dev/null +++ b/src/home/mqtt/esp.py @@ -0,0 +1,106 @@ +import re +import paho.mqtt.client as mqtt + +from .mqtt import MqttBase +from typing import Optional, Union +from .payload.esp import ( + OTAPayload, + OTAResultPayload, + DiagnosticsPayload, + InitialDiagnosticsPayload +) + + +class MqttEspDevice: + id: str + secret: Optional[str] + + def __init__(self, id: str, secret: Optional[str] = None): + self.id = id + self.secret = secret + + +class MqttEspBase(MqttBase): + _devices: list[MqttEspDevice] + _message_callback: Optional[callable] + _ota_publish_callback: Optional[callable] + + TOPIC_LEAF = 'esp' + + def __init__(self, + devices: Union[MqttEspDevice, list[MqttEspDevice]], + subscribe_to_updates=True): + super().__init__(clean_session=True) + if not isinstance(devices, list): + devices = [devices] + self._devices = devices + self._message_callback = None + self._ota_publish_callback = None + self._subscribe_to_updates = subscribe_to_updates + self._ota_mid = None + + def on_connect(self, client: mqtt.Client, userdata, flags, rc): + super().on_connect(client, userdata, flags, rc) + + if self._subscribe_to_updates: + for device in self._devices: + topic = f'hk/{device.id}/{self.TOPIC_LEAF}/#' + self._logger.debug(f"subscribing to {topic}") + client.subscribe(topic, qos=1) + + def on_publish(self, client: mqtt.Client, userdata, mid): + if self._ota_mid is not None and mid == self._ota_mid and self._ota_publish_callback: + self._ota_publish_callback() + + def set_message_callback(self, callback: callable): + self._message_callback = callback + + def on_message(self, client: mqtt.Client, userdata, msg): + try: + match = re.match(self.get_mqtt_topics(), msg.topic) + self._logger.debug(f'topic: {msg.topic}') + if not match: + return + + device_id = match.group(1) + subtopic = match.group(2) + + # try: + next(d for d in self._devices if d.id == device_id) + # except StopIteration:h + # return + + message = None + if subtopic == 'stat': + message = DiagnosticsPayload.unpack(msg.payload) + elif subtopic == 'stat1': + message = InitialDiagnosticsPayload.unpack(msg.payload) + elif subtopic == 'otares': + message = OTAResultPayload.unpack(msg.payload) + + if message and self._message_callback: + self._message_callback(device_id, message) + return True + + except Exception as e: + self._logger.exception(str(e)) + + def push_ota(self, + device_id, + filename: str, + publish_callback: callable, + qos: int): + device = next(d for d in self._devices if d.id == device_id) + assert device.secret is not None, 'device secret not specified' + + self._ota_publish_callback = publish_callback + payload = OTAPayload(secret=device.secret, filename=filename) + publish_result = self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/admin/ota', + payload=payload.pack(), + qos=qos) + self._ota_mid = publish_result.mid + self._client.loop_write() + + @classmethod + def get_mqtt_topics(cls, additional_topics: Optional[list[str]] = None): + return rf'^hk/(.*?)/{cls.TOPIC_LEAF}/(stat|stat1|otares'+('|'+('|'.join(additional_topics)) if additional_topics else '')+')$'
\ No newline at end of file |