summaryrefslogtreecommitdiff
path: root/src/home/mqtt/_node.py
blob: c76610faef4d14a056723205027ad266747e5bd3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import paho.mqtt.client as mqtt

from .mqtt import MqttBase
from typing import List
from ._module import MqttModule


class MqttNode(MqttBase):
    """MQTT client for a single node that routes traffic to pluggable modules.

    Every topic this class subscribes or publishes to lives under the
    ``hk/<node_id>/`` prefix. Modules pass topics *relative* to that prefix
    (see :meth:`subscribe_module` and :meth:`publish`), and incoming messages
    are dispatched to the module that subscribed to the matching relative
    topic.
    """

    _modules: List[MqttModule]
    # relative topic -> module whose handle_payload() receives its messages
    _module_subscriptions: dict[str, MqttModule]
    # relative topic -> QoS it was subscribed with; kept so subscriptions can
    # be restored after a reconnect (the session is not persistent)
    _module_subscription_qos: dict[str, int]
    _node_id: str
    # _devices: list[MqttEspDevice]
    # _message_callback: Optional[callable]
    # _ota_publish_callback: Optional[callable]

    def __init__(self,
                 node_id: str,
                 # devices: Union[MqttEspDevice, list[MqttEspDevice]],
                 subscribe_to_updates=True):
        """Create a node client.

        :param node_id: identifier used to build the ``hk/<node_id>/`` topic prefix.
        :param subscribe_to_updates: currently unused; kept for backward
            compatibility with existing callers.
        """
        super().__init__(clean_session=True)
        self._modules = []
        self._module_subscriptions = {}
        self._module_subscription_qos = {}
        self._node_id = node_id
        # if not isinstance(devices, list):
        #     devices = [devices]
        # self._devices = devices
        # self._message_callback = None
        # self._ota_publish_callback = None
        # self._subscribe_to_updates = subscribe_to_updates
        # self._ota_mid = None

    def _full_topic(self, topic: str) -> str:
        """Return the absolute topic for a topic relative to this node."""
        return f'hk/{self._node_id}/{topic}'

    def _init_module(self, module: MqttModule):
        """Run a module's init hook against this node and mark it initialized."""
        module.init(self)
        module.set_initialized()

    def on_connect(self, client: mqtt.Client, userdata, flags, rc):
        """Restore known subscriptions, then initialize modules not yet initialized."""
        super().on_connect(client, userdata, flags, rc)
        # The client is created with clean_session=True, so the broker discards
        # our subscriptions whenever the connection drops. Modules are only
        # initialized once, so re-subscribe their topics here on every
        # (re)connect; otherwise they would stop receiving messages.
        for topic, qos in self._module_subscription_qos.items():
            self._client.subscribe(self._full_topic(topic), qos)
        for module in self._modules:
            if not module.is_initialized():
                self._init_module(module)

    def on_publish(self, client: mqtt.Client, userdata, mid):
        pass  # FIXME
        # if self._ota_mid is not None and mid == self._ota_mid and self._ota_publish_callback:
        #     self._ota_publish_callback()

    def on_message(self, client: mqtt.Client, userdata, msg):
        """Dispatch an incoming message to the module subscribed to its topic.

        Messages outside the ``hk/<node_id>/`` prefix, or on relative topics no
        module subscribed to, are ignored. Exceptions raised while handling a
        message are logged rather than propagated into the MQTT client.
        """
        try:
            prefix = self._full_topic('')
            # Only strip the prefix when it is actually present; blindly
            # slicing a foreign topic could produce a bogus relative topic
            # that happens to match a module subscription.
            if not msg.topic.startswith(prefix):
                return
            actual_topic = msg.topic[len(prefix):]

            module = self._module_subscriptions.get(actual_topic)
            if module is not None:
                module.handle_payload(self, actual_topic, msg.payload)

        except Exception as e:
            self._logger.exception(str(e))

    # def push_ota(self,
    #              device_id,
    #              filename: str,
    #              publish_callback: callable,
    #              qos: int):
    #     device = next(d for d in self._devices if d.id == device_id)
    #     assert device.secret is not None, 'device secret not specified'
    #
    #     self._ota_publish_callback = publish_callback
    #     payload = OtaPayload(secret=device.secret, filename=filename)
    #     publish_result = self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/admin/ota',
    #                                           payload=payload.pack(),
    #                                           qos=qos)
    #     self._ota_mid = publish_result.mid
    #     self._client.loop_write()
    #
    # @classmethod
    # def get_mqtt_topics(cls, additional_topics: Optional[list[str]] = None):
    #     return rf'^hk/(.*?)/{cls.TOPIC_LEAF}/(stat|stat1|otares'+('|'+('|'.join(additional_topics)) if additional_topics else '')+')$'

    def add_module(self, module: MqttModule):
        """Register a module; initialize it immediately if already connected.

        Modules added while disconnected are initialized from :meth:`on_connect`.
        """
        self._modules.append(module)
        if self._connected:
            self._init_module(module)

    def subscribe_module(self, topic: str, module: MqttModule, qos: int = 1):
        """Subscribe to ``hk/<node_id>/<topic>`` and route its messages to ``module``.

        The subscription is remembered so it is restored after a reconnect.
        """
        self._module_subscriptions[topic] = module
        self._module_subscription_qos[topic] = qos
        self._client.subscribe(self._full_topic(topic), qos)

    def publish(self, topic: str, payload: bytes, qos: int = 1):
        """Publish ``payload`` to ``hk/<node_id>/<topic>`` and flush the write buffer."""
        self._client.publish(self._full_topic(topic), payload, qos)
        self._client.loop_write()