1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
|
import re
from typing import Callable, Optional, Union

import paho.mqtt.client as mqtt

from .mqtt import MqttBase
from .payload.esp import (
    OTAPayload,
    OTAResultPayload,
    DiagnosticsPayload,
    InitialDiagnosticsPayload,
)
class MqttEspDevice:
    """One ESP device reachable over MQTT.

    Attributes:
        id: device identifier; MqttEspBase uses it as the topic segment
            right after ``hk/`` (e.g. ``hk/<id>/esp/...``).
        secret: secret included in OTA requests by MqttEspBase.push_ota,
            or None when not configured.
    """

    id: str
    secret: Optional[str]

    def __init__(self, id: str, secret: Optional[str] = None):
        # Plain attribute storage; no validation is performed.
        self.id, self.secret = id, secret
class MqttEspBase(MqttBase):
    """MQTT client for ESP devices using ``hk/<device_id>/<TOPIC_LEAF>/<subtopic>`` topics.

    Optionally subscribes to every configured device's topic subtree on
    connect, decodes ``stat`` / ``stat1`` / ``otares`` payloads and hands them
    to a user-supplied callback, and can publish an OTA request to a device.

    NOTE(review): ``self._client`` and ``self._logger`` are used here but not
    defined in this class; presumably they are provided by ``MqttBase``.
    """

    _devices: list[MqttEspDevice]
    _message_callback: Optional[Callable]
    _ota_publish_callback: Optional[Callable]

    # Topic segment following the device id. Both the subscriptions and the
    # topic regex read it through self/cls, so a subclass can override it.
    TOPIC_LEAF = 'esp'

    def __init__(self,
                 devices: Union[MqttEspDevice, list[MqttEspDevice]],
                 subscribe_to_updates=True):
        """Create the client.

        :param devices: a single device or a list of devices to handle.
        :param subscribe_to_updates: if true, subscribe to each device's
            ``hk/<id>/<TOPIC_LEAF>/#`` subtree when the connection is made.
        """
        super().__init__(clean_session=True)
        # Accept a bare device for convenience; always store a list.
        if not isinstance(devices, list):
            devices = [devices]
        self._devices = devices
        self._message_callback = None
        self._ota_publish_callback = None
        self._subscribe_to_updates = subscribe_to_updates
        # Message id of the pending OTA publish, matched in on_publish().
        self._ota_mid = None

    def on_connect(self, client: mqtt.Client, userdata, flags, rc):
        """Run the base connect handler, then subscribe (QoS 1) to each
        device's topic subtree if subscribe_to_updates was requested."""
        super().on_connect(client, userdata, flags, rc)
        if not self._subscribe_to_updates:
            return
        for device in self._devices:
            topic = f'hk/{device.id}/{self.TOPIC_LEAF}/#'
            self._logger.debug(f"subscribing to {topic}")
            client.subscribe(topic, qos=1)

    def on_publish(self, client: mqtt.Client, userdata, mid):
        """Invoke the OTA publish callback once the OTA message's mid is reported."""
        if self._ota_mid is None or mid != self._ota_mid:
            return
        callback = self._ota_publish_callback
        # Forget the mid so a later, unrelated publish that happens to reuse
        # the same message id does not fire the OTA callback a second time.
        self._ota_mid = None
        if callback:
            callback()

    def set_message_callback(self, callback: Callable):
        """Register ``callback(device_id, message)`` for decoded device messages."""
        self._message_callback = callback

    def on_message(self, client: mqtt.Client, userdata, msg):
        """Decode a device message and forward it to the message callback.

        Messages whose topic does not match get_mqtt_topics(), or that come
        from a device not configured on this instance, are ignored.

        :return: True when a decoded message was delivered to the callback,
            otherwise None. Exceptions are logged, not propagated.
        """
        try:
            match = re.match(self.get_mqtt_topics(), msg.topic)
            self._logger.debug(f'topic: {msg.topic}')
            if not match:
                return

            device_id = match.group(1)
            subtopic = match.group(2)

            # Ignore unknown devices quietly rather than letting a failed
            # lookup surface as a logged exception below.
            if not any(d.id == device_id for d in self._devices):
                return

            message = None
            if subtopic == 'stat':
                message = DiagnosticsPayload.unpack(msg.payload)
            elif subtopic == 'stat1':
                message = InitialDiagnosticsPayload.unpack(msg.payload)
            elif subtopic == 'otares':
                message = OTAResultPayload.unpack(msg.payload)

            if message and self._message_callback:
                self._message_callback(device_id, message)
                return True

        except Exception as e:
            # Boundary of a network-loop callback: log instead of raising.
            self._logger.exception(str(e))

    def push_ota(self,
                 device_id,
                 filename: str,
                 publish_callback: Callable,
                 qos: int):
        """Publish an OTA request to ``hk/<id>/<TOPIC_LEAF>/admin/ota``.

        :param device_id: id of a configured device.
        :param filename: passed to OTAPayload as ``filename``.
        :param publish_callback: called with no arguments from on_publish()
            when the OTA message's mid is reported.
        :param qos: MQTT QoS level for the publish.
        :raises StopIteration: if no configured device has ``device_id``.
        :raises AssertionError: if the device has no secret configured.
        """
        device = next(d for d in self._devices if d.id == device_id)
        # Explicit check (not `assert`) so it still runs under `python -O`;
        # AssertionError is kept for compatibility with existing callers.
        if device.secret is None:
            raise AssertionError('device secret not specified')

        self._ota_publish_callback = publish_callback
        payload = OTAPayload(secret=device.secret, filename=filename)
        publish_result = self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/admin/ota',
                                              payload=payload.pack(),
                                              qos=qos)
        # NOTE(review): if the network loop runs in another thread, on_publish
        # may fire before _ota_mid is assigned here and the callback would be
        # missed; confirm against the paho-mqtt version in use.
        self._ota_mid = publish_result.mid
        self._client.loop_write()

    @classmethod
    def get_mqtt_topics(cls, additional_topics: Optional[list[str]] = None):
        """Return the regex matching ``hk/<device_id>/<TOPIC_LEAF>/<subtopic>``.

        Group 1 captures the device id, group 2 the subtopic. The subtopics
        are ``stat``, ``stat1``, ``otares`` plus any ``additional_topics``,
        which are inserted verbatim (not regex-escaped).
        """
        subtopics = ['stat', 'stat1', 'otares']
        if additional_topics:
            subtopics.extend(additional_topics)
        return rf'^hk/(.*?)/{cls.TOPIC_LEAF}/(' + '|'.join(subtopics) + ')$'
|