summaryrefslogtreecommitdiff
path: root/src/home/mqtt/_wrapper.py
blob: 3c2774c96a3c352cf8613477d794be18bf9b6ae3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import paho.mqtt.client as mqtt

from ._mqtt import Mqtt
from ._node import MqttNode
from ..util import strgen


class MqttWrapper(Mqtt):
    """MQTT client that routes broker traffic to a set of registered nodes.

    Every topic built or parsed by this class has the form
    ``<topic_prefix>/<node_id>/<topic>``. Nodes are registered with
    :meth:`add_node`; connect/disconnect events and incoming messages are
    forwarded to them.

    ``self._client``, ``self._connected`` and ``self._logger`` are not set
    here; they are expected to be provided by the ``Mqtt`` base class.
    """

    _nodes: list[MqttNode]

    def __init__(self,
                 client_id: str,
                 topic_prefix: str = 'hk',
                 randomize_client_id: bool = False,
                 clean_session: bool = True) -> None:
        """Create the wrapper.

        :param client_id: client id passed on to the ``Mqtt`` base class.
        :param topic_prefix: first level of every topic handled by this
            wrapper.
        :param randomize_client_id: when true, ``'_'`` plus ``strgen(6)`` is
            appended to ``client_id`` (presumably a random suffix, so that
            several instances can share one base id -- confirm against
            ``strgen``).
        :param clean_session: passed on to the ``Mqtt`` base class.
        """
        if randomize_client_id:
            client_id += '_'+strgen(6)
        super().__init__(clean_session=clean_session,
                         client_id=client_id)
        self._nodes = []
        self._topic_prefix = topic_prefix

    def _full_topic(self, node_id: str, topic: str) -> str:
        """Return the absolute topic ``<prefix>/<node_id>/<topic>``."""
        return f'{self._topic_prefix}/{node_id}/{topic}'

    def on_connect(self, client: mqtt.Client, userdata, flags, rc) -> None:
        """Run the base-class handler, then notify every registered node."""
        super().on_connect(client, userdata, flags, rc)
        for node in self._nodes:
            node.on_connect(self)

    def on_disconnect(self, client: mqtt.Client, userdata, rc) -> None:
        """Run the base-class handler, then notify every registered node."""
        super().on_disconnect(client, userdata, rc)
        for node in self._nodes:
            node.on_disconnect()

    def on_message(self, client: mqtt.Client, userdata, msg) -> None:
        """Dispatch an incoming message to the nodes it is addressed to.

        The node id is the topic level right after the prefix. A node
        receives the message when its id equals that level or when its id is
        the wildcard ``'+'``. The node is handed the topic relative to
        ``<prefix>/<node_id>/`` (an empty string when the topic has no
        further levels) together with the raw payload.

        Any exception raised while dispatching is logged and swallowed.
        """
        try:
            topic = msg.topic
            # Drop '<prefix>/' and split the remainder at the first '/'.
            # partition() copes with a topic that has no further '/', where
            # an index from str.find() would be -1 and chop the last
            # character off the node id.
            remainder = topic[len(self._topic_prefix)+1:]
            topic_node, _, node_topic = remainder.partition('/')
            for node in self._nodes:
                if node.id in ('+', topic_node):
                    # The relative topic is derived from the node id found
                    # in the topic itself, not from node.id, so wildcard
                    # nodes also get a correctly trimmed topic.
                    node.on_message(node_topic, msg.payload)
        except Exception as e:
            self._logger.exception(str(e))

    def add_node(self, node: MqttNode) -> None:
        """Register ``node``.

        If ``self._connected`` is already true, the node's ``on_connect`` is
        called right away, since it missed the wrapper's own connect event.
        """
        self._nodes.append(node)
        if self._connected:
            node.on_connect(self)

    def subscribe(self,
                  node_id: str,
                  topic: str,
                  qos: int) -> None:
        """Subscribe to ``<prefix>/<node_id>/<topic>`` with the given QoS."""
        self._client.subscribe(self._full_topic(node_id, topic), qos)

    def publish(self,
                node_id: str,
                topic: str,
                payload: bytes,
                qos: int) -> None:
        """Publish ``payload`` to ``<prefix>/<node_id>/<topic>``."""
        self._client.publish(self._full_topic(node_id, topic), payload, qos)
        # NOTE(review): presumably forces the queued packet onto the socket
        # immediately instead of waiting for the network loop -- confirm.
        self._client.loop_write()