summaryrefslogtreecommitdiff
path: root/include/py/homekit/mqtt/_wrapper.py
blob: 68af09335ec1402dd13ace329481db700c5710dc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from typing import Callable

import paho.mqtt.client as mqtt

from ._mqtt import Mqtt
from ._node import MqttNode
from ..util import strgen


class MqttWrapper(Mqtt):
    """MQTT client that routes traffic to and from registered MqttNode objects.

    Every topic handled here has the form ``<topic_prefix>/<node_id>/<subtopic>``:
    :meth:`subscribe` and :meth:`publish` build such topics, and
    :meth:`on_message` splits incoming ones and hands ``<subtopic>`` to the
    matching node(s).
    """

    _nodes: list[MqttNode]
    _connect_callbacks: list[Callable[[], None]]
    _disconnect_callbacks: list[Callable[[], None]]
    _topic_prefix: str

    def __init__(self,
                 client_id: str,
                 topic_prefix='hk',
                 randomize_client_id=False,
                 clean_session=True):
        """Create the wrapper; nodes and callbacks are registered afterwards.

        :param client_id: MQTT client id handed to the base ``Mqtt`` client.
        :param topic_prefix: first topic level shared by every node.
        :param randomize_client_id: if true, ``'_' + strgen(6)`` is appended to
            ``client_id`` (presumably a random suffix so that several instances
            can connect with distinct ids -- confirm against ``strgen``).
        :param clean_session: forwarded to the base ``Mqtt`` client.
        """
        if randomize_client_id:
            client_id += '_' + strgen(6)
        super().__init__(clean_session=clean_session,
                         client_id=client_id)
        self._nodes = []
        self._connect_callbacks = []
        self._disconnect_callbacks = []
        self._topic_prefix = topic_prefix

    def _run_callbacks(self, callbacks):
        """Call each zero-argument callback in order, logging (not raising) failures."""
        for f in callbacks:
            try:
                f()
            except Exception as e:
                self._logger.exception(e)

    def on_connect(self, client: mqtt.Client, userdata, flags, rc):
        """Run the base handler, then notify every node and every connect callback.

        An exception raised by one node or callback is logged and does not
        prevent the remaining ones from being notified.
        """
        super().on_connect(client, userdata, flags, rc)
        for node in self._nodes:
            try:
                node.on_connect(self)
            except Exception as e:
                self._logger.exception(e)
        self._run_callbacks(self._connect_callbacks)

    def on_disconnect(self, client: mqtt.Client, userdata, rc):
        """Run the base handler, then notify every node and every disconnect callback.

        An exception raised by one node or callback is logged and does not
        prevent the remaining ones from being notified.
        """
        super().on_disconnect(client, userdata, rc)
        for node in self._nodes:
            try:
                node.on_disconnect()
            except Exception as e:
                self._logger.exception(e)
        self._run_callbacks(self._disconnect_callbacks)

    def on_message(self, client: mqtt.Client, userdata, msg):
        """Dispatch an incoming message to the node(s) it is addressed to.

        ``msg.topic`` must look like ``<topic_prefix>/<node_id>/<subtopic>``;
        any other topic is ignored. Every node whose ``id`` equals
        ``<node_id>``, or is the wildcard ``'+'``, receives
        ``on_message(<subtopic>, msg.payload)``. Errors are logged, never
        raised, and one failing node does not stop delivery to the others.
        """
        try:
            topic = msg.topic
            prefix = f'{self._topic_prefix}/'
            if not topic.startswith(prefix):
                return
            # partition() instead of find(): find() returns -1 when no '/'
            # follows the node id, which used to chop the id's last character.
            topic_node, sep, subtopic = topic[len(prefix):].partition('/')
            if not sep:
                return
            for node in self._nodes:
                if node.id not in ('+', topic_node):
                    continue
                # The subtopic is relative to the node id actually present in
                # the topic, so wildcard ('+') nodes get the same relative
                # subtopic as named ones (slicing by len(node.id) was only
                # correct for one-character ids).
                try:
                    node.on_message(subtopic, msg.payload)
                except Exception as e:
                    self._logger.exception(e)
        except Exception as e:
            self._logger.exception(e)

    def add_connect_callback(self, f: Callable[[], None]):
        """Register a zero-argument callable run after every (re)connect."""
        self._connect_callbacks.append(f)

    def add_disconnect_callback(self, f: Callable[[], None]):
        """Register a zero-argument callable run after every disconnect."""
        self._disconnect_callbacks.append(f)

    def add_node(self, node: MqttNode):
        """Register a node; if already connected, call its ``on_connect`` right away.

        The immediate call mirrors what :meth:`on_connect` would have done had
        the node been registered before the connection was established.
        """
        self._nodes.append(node)
        if self._connected:
            node.on_connect(self)

    def subscribe(self,
                  node_id: str,
                  topic: str,
                  qos: int):
        """Subscribe to ``<topic_prefix>/<node_id>/<topic>`` with the given QoS."""
        self._client.subscribe(f'{self._topic_prefix}/{node_id}/{topic}', qos)

    def publish(self,
                node_id: str,
                topic: str,
                payload: bytes,
                qos: int):
        """Publish ``payload`` to ``<topic_prefix>/<node_id>/<topic>`` with the given QoS."""
        self._client.publish(f'{self._topic_prefix}/{node_id}/{topic}', payload, qos)
        # NOTE(review): presumably forces the queued packet onto the socket
        # immediately instead of waiting for the network loop -- confirm this
        # is safe when a background loop thread is also running.
        self._client.loop_write()