summaryrefslogtreecommitdiff
path: root/src/home/mqtt/esp.py
blob: 56ced83c942e3ffae93a366ad580f39ced51c1d1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import re
import paho.mqtt.client as mqtt

from .mqtt import MqttBase
from typing import Optional, Union
from .payload.esp import (
    OTAPayload,
    OTAResultPayload,
    DiagnosticsPayload,
    InitialDiagnosticsPayload
)


class MqttEspDevice:
    """Plain record describing one ESP device: its id and an optional secret."""

    # Device identifier.
    id: str
    # Optional per-device secret; None when no secret was supplied.
    secret: Optional[str]

    def __init__(self, id: str, secret: Optional[str] = None):
        """Store the device id and, if given, its secret (defaults to None)."""
        self.id, self.secret = id, secret


class MqttEspBase(MqttBase):
    """MQTT endpoint for a set of ESP devices.

    Subscribes to ``hk/<device id>/<TOPIC_LEAF>/#`` for every configured
    device, decodes the ``stat``, ``stat1`` and ``otares`` subtopics into
    payload objects and hands them to a user-supplied callback. It can also
    publish an OTA request to a device's ``admin/ota`` topic.

    NOTE(review): ``self._logger`` and ``self._client`` are not set here; they
    are presumably provided by ``MqttBase`` -- confirm against the base class.
    """

    _devices: list[MqttEspDevice]
    _message_callback: Optional[callable]
    _ota_publish_callback: Optional[callable]

    # Topic path segment placed between the device id and the subtopic.
    TOPIC_LEAF = 'esp'

    def __init__(self,
                 devices: Union[MqttEspDevice, list[MqttEspDevice]],
                 subscribe_to_updates=True):
        """Initialise the client state.

        :param devices: a single device or a list of devices to work with.
        :param subscribe_to_updates: when true, ``on_connect`` subscribes to
            every device's topic tree.
        """
        super().__init__(clean_session=True)
        if not isinstance(devices, list):
            # Accept a single device for convenience.
            devices = [devices]
        self._devices = devices
        self._message_callback = None
        self._ota_publish_callback = None
        self._subscribe_to_updates = subscribe_to_updates
        # Message id of the last OTA publish; None until push_ota() is called.
        self._ota_mid = None

    def on_connect(self, client: mqtt.Client, userdata, flags, rc):
        """Run the base handler, then subscribe to each device's topics (QoS 1)
        if subscriptions were requested in the constructor."""
        super().on_connect(client, userdata, flags, rc)

        if self._subscribe_to_updates:
            for device in self._devices:
                topic = f'hk/{device.id}/{self.TOPIC_LEAF}/#'
                self._logger.debug(f"subscribing to {topic}")
                client.subscribe(topic, qos=1)

    def on_publish(self, client: mqtt.Client, userdata, mid):
        """Invoke the OTA publish callback when ``mid`` matches the message id
        recorded by the last ``push_ota`` call."""
        if self._ota_mid is not None and mid == self._ota_mid and self._ota_publish_callback:
            self._ota_publish_callback()

    def set_message_callback(self, callback: callable):
        """Register ``callback(device_id, message)`` for decoded messages."""
        self._message_callback = callback

    def on_message(self, client: mqtt.Client, userdata, msg):
        """Decode an incoming message and pass it to the message callback.

        Messages whose topic does not match ``get_mqtt_topics()`` or whose
        device id is not among the configured devices are ignored. Any
        exception raised while handling the message is logged, not propagated.

        :return: ``True`` when a message was decoded and delivered to the
            callback, otherwise ``None``.
        """
        try:
            match = re.match(self.get_mqtt_topics(), msg.topic)
            self._logger.debug(f'topic: {msg.topic}')
            if not match:
                return

            device_id = match.group(1)
            subtopic = match.group(2)

            # Ignore unknown devices explicitly, rather than letting a bare
            # next() raise StopIteration into the generic handler below, which
            # would log a misleading empty-message traceback.
            if not any(d.id == device_id for d in self._devices):
                self._logger.debug(f'ignoring message from unknown device {device_id}')
                return

            message = None
            if subtopic == 'stat':
                message = DiagnosticsPayload.unpack(msg.payload)
            elif subtopic == 'stat1':
                message = InitialDiagnosticsPayload.unpack(msg.payload)
            elif subtopic == 'otares':
                message = OTAResultPayload.unpack(msg.payload)

            if message and self._message_callback:
                self._message_callback(device_id, message)
                return True

        except Exception as e:
            self._logger.exception(str(e))

    def push_ota(self,
                 device_id,
                 filename: str,
                 publish_callback: callable,
                 qos: int):
        """Publish an OTA request for ``device_id``.

        :param device_id: id of one of the configured devices.
        :param filename: value passed to ``OTAPayload`` as ``filename``.
        :param publish_callback: called without arguments from ``on_publish``
            once the message id of this publish is reported.
        :param qos: QoS level used for the publish.
        :raises StopIteration: if ``device_id`` is not a configured device.
        :raises AssertionError: if the device has no secret.
        """
        device = next(d for d in self._devices if d.id == device_id)
        if device.secret is None:
            # Raised explicitly (not via ``assert``) so the check survives
            # ``python -O`` while keeping the exception type unchanged.
            raise AssertionError('device secret not specified')

        self._ota_publish_callback = publish_callback
        payload = OTAPayload(secret=device.secret, filename=filename)
        publish_result = self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/admin/ota',
                                              payload=payload.pack(),
                                              qos=qos)
        # Remember the message id so on_publish() can recognise this publish.
        self._ota_mid = publish_result.mid
        self._client.loop_write()

    @classmethod
    def get_mqtt_topics(cls, additional_topics: Optional[list[str]] = None):
        """Return the regex that matches this class's device topics.

        Group 1 captures the device id, group 2 the subtopic. The subtopic
        alternatives are ``stat``, ``stat1`` and ``otares``, followed by any
        ``additional_topics`` (inserted verbatim into the pattern).
        """
        subtopics = ['stat', 'stat1', 'otares']
        if additional_topics:
            subtopics.extend(additional_topics)
        return rf'^hk/(.*?)/{cls.TOPIC_LEAF}/(' + '|'.join(subtopics) + ')$'