summaryrefslogtreecommitdiff
path: root/src/home/mqtt/esp.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/home/mqtt/esp.py')
-rw-r--r--src/home/mqtt/esp.py106
1 files changed, 106 insertions, 0 deletions
diff --git a/src/home/mqtt/esp.py b/src/home/mqtt/esp.py
new file mode 100644
index 0000000..56ced83
--- /dev/null
+++ b/src/home/mqtt/esp.py
@@ -0,0 +1,106 @@
+import re
+import paho.mqtt.client as mqtt
+
+from .mqtt import MqttBase
+from typing import Optional, Union
+from .payload.esp import (
+ OTAPayload,
+ OTAResultPayload,
+ DiagnosticsPayload,
+ InitialDiagnosticsPayload
+)
+
+
+class MqttEspDevice:
+ id: str
+ secret: Optional[str]
+
+ def __init__(self, id: str, secret: Optional[str] = None):
+ self.id = id
+ self.secret = secret
+
+
+class MqttEspBase(MqttBase):
+ _devices: list[MqttEspDevice]
+ _message_callback: Optional[callable]
+ _ota_publish_callback: Optional[callable]
+
+ TOPIC_LEAF = 'esp'
+
+ def __init__(self,
+ devices: Union[MqttEspDevice, list[MqttEspDevice]],
+ subscribe_to_updates=True):
+ super().__init__(clean_session=True)
+ if not isinstance(devices, list):
+ devices = [devices]
+ self._devices = devices
+ self._message_callback = None
+ self._ota_publish_callback = None
+ self._subscribe_to_updates = subscribe_to_updates
+ self._ota_mid = None
+
+ def on_connect(self, client: mqtt.Client, userdata, flags, rc):
+ super().on_connect(client, userdata, flags, rc)
+
+ if self._subscribe_to_updates:
+ for device in self._devices:
+ topic = f'hk/{device.id}/{self.TOPIC_LEAF}/#'
+ self._logger.debug(f"subscribing to {topic}")
+ client.subscribe(topic, qos=1)
+
+ def on_publish(self, client: mqtt.Client, userdata, mid):
+ if self._ota_mid is not None and mid == self._ota_mid and self._ota_publish_callback:
+ self._ota_publish_callback()
+
+ def set_message_callback(self, callback: callable):
+ self._message_callback = callback
+
+ def on_message(self, client: mqtt.Client, userdata, msg):
+ try:
+ match = re.match(self.get_mqtt_topics(), msg.topic)
+ self._logger.debug(f'topic: {msg.topic}')
+ if not match:
+ return
+
+ device_id = match.group(1)
+ subtopic = match.group(2)
+
+ # try:
+ next(d for d in self._devices if d.id == device_id)
+ # except StopIteration:h
+ # return
+
+ message = None
+ if subtopic == 'stat':
+ message = DiagnosticsPayload.unpack(msg.payload)
+ elif subtopic == 'stat1':
+ message = InitialDiagnosticsPayload.unpack(msg.payload)
+ elif subtopic == 'otares':
+ message = OTAResultPayload.unpack(msg.payload)
+
+ if message and self._message_callback:
+ self._message_callback(device_id, message)
+ return True
+
+ except Exception as e:
+ self._logger.exception(str(e))
+
+ def push_ota(self,
+ device_id,
+ filename: str,
+ publish_callback: callable,
+ qos: int):
+ device = next(d for d in self._devices if d.id == device_id)
+ assert device.secret is not None, 'device secret not specified'
+
+ self._ota_publish_callback = publish_callback
+ payload = OTAPayload(secret=device.secret, filename=filename)
+ publish_result = self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/admin/ota',
+ payload=payload.pack(),
+ qos=qos)
+ self._ota_mid = publish_result.mid
+ self._client.loop_write()
+
+ @classmethod
+ def get_mqtt_topics(cls, additional_topics: Optional[list[str]] = None):
+ return rf'^hk/(.*?)/{cls.TOPIC_LEAF}/(stat|stat1|otares'+('|'+('|'.join(additional_topics)) if additional_topics else '')+')$' \ No newline at end of file