summaryrefslogtreecommitdiff
path: root/py_include/homekit/mqtt/_node.py
blob: 4e259a434d213bebbc940b9adc09548283e1cc92 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import logging
import importlib

from typing import List, TYPE_CHECKING, Optional
from ._payload import MqttPayload
from ._module import MqttModule
# MqttWrapper is imported only for static type checkers. At runtime the name
# is bound to None so that annotations which mention it (evaluated when the
# class body below executes) still resolve.
# NOTE(review): presumably this avoids a circular import with ._wrapper — confirm.
if TYPE_CHECKING:
    from ._wrapper import MqttWrapper
else:
    MqttWrapper = None


class MqttNode:
    """A node that routes MQTT traffic between a wrapper and its modules.

    The node keeps a list of modules, a mapping from subscribed topic to the
    module that handles it, and a list of callbacks that are invoked with
    every ``MqttPayload`` a module returns from ``handle_payload``. All
    subscribe/publish calls are delegated to the wrapper passed to
    :meth:`on_connect`.
    """

    _modules: List[MqttModule]
    _module_subscriptions: dict[str, MqttModule]
    _node_id: str
    # Optional because __init__ defaults the secret to None.
    _node_secret: Optional[str]
    _payload_callbacks: list[callable]
    _logger: logging.Logger
    # None until on_connect() is called and again after on_disconnect().
    _wrapper: Optional[MqttWrapper]

    def __init__(self,
                 node_id: str,
                 node_secret: Optional[str] = None):
        """Create a detached node.

        :param node_id: identifier returned by :attr:`id` and passed to the
            wrapper on every subscribe/publish call.
        :param node_secret: optional secret exposed through :attr:`secret`.
        """
        self._modules = []
        self._module_subscriptions = {}
        self._node_id = node_id
        self._node_secret = node_secret
        self._payload_callbacks = []
        self._logger = logging.getLogger(self.__class__.__name__)
        self._wrapper = None

    def on_connect(self, wrapper: MqttWrapper):
        """Attach to ``wrapper`` and initialize modules that are not yet initialized.

        :param wrapper: the wrapper used for subsequent subscribe/publish calls.
        """
        self._wrapper = wrapper
        for module in self._modules:
            if module.is_initialized():
                continue
            module.on_connect(self)
            module.set_initialized()

    def on_disconnect(self):
        """Detach from the wrapper and mark every module as uninitialized."""
        self._wrapper = None
        for module in self._modules:
            module.unset_initialized()

    def on_message(self, topic, payload):
        """Dispatch an incoming message to the module subscribed to ``topic``.

        Topics are matched by exact key lookup; messages for topics without a
        registered module are ignored. When the module's ``handle_payload``
        returns an ``MqttPayload``, every registered payload callback is
        called with ``(node, payload)``.
        """
        handler = self._module_subscriptions.get(topic)
        if handler is None:
            return

        result = handler.handle_payload(self, topic, payload)
        if isinstance(result, MqttPayload):
            for callback in self._payload_callbacks:
                callback(self, result)

    def load_module(self, module_name: str, *args, **kwargs) -> MqttModule:
        """Import a module by name, instantiate its class and add it to the node.

        The Python module ``..module.<module_name>`` is imported relative to
        this file's module name. It must define ``MODULE_NAME``, the name of
        the attribute (class) to instantiate with ``*args`` and ``**kwargs``.

        :param module_name: last component of the module path to import.
        :returns: the newly created instance, already passed to :meth:`add_module`.
        :raises RuntimeError: if the imported module has no ``MODULE_NAME``.
        """
        module = importlib.import_module(f'..module.{module_name}', __name__)
        if not hasattr(module, 'MODULE_NAME'):
            raise RuntimeError(f'MODULE_NAME not found in module {module}')
        cl = getattr(module, getattr(module, 'MODULE_NAME'))
        instance = cl(*args, **kwargs)
        self.add_module(instance)
        return instance

    def add_module(self, module: MqttModule):
        """Register ``module`` with this node.

        If the node is currently attached to a connected wrapper, the module
        is initialized right away; otherwise that happens in :meth:`on_connect`.
        """
        self._modules.append(module)
        if self._wrapper and self._wrapper._connected:
            module.on_connect(self)
            module.set_initialized()

    def subscribe_module(self, topic: str, module: MqttModule, qos: int = 1):
        """Subscribe to ``topic`` and route its messages to ``module``.

        :param topic: topic to subscribe to; also the dispatch key used by
            :meth:`on_message`.
        :param module: module whose ``handle_payload`` receives the messages.
        :param qos: QoS level forwarded to the wrapper.
        :raises RuntimeError: if there is no wrapper or it is not connected.
        """
        if not self._wrapper or not self._wrapper._connected:
            raise RuntimeError('not connected')

        self._module_subscriptions[topic] = module
        self._wrapper.subscribe(self.id, topic, qos)

    def publish(self,
                topic: str,
                payload: bytes,
                qos: int = 1):
        """Publish ``payload`` to ``topic`` through the wrapper.

        :param topic: topic forwarded to the wrapper.
        :param payload: raw message body.
        :param qos: QoS level forwarded to the wrapper.
        :raises RuntimeError: if the node is not attached to a wrapper.
        """
        wrapper = self._wrapper
        # Fail with the same error subscribe_module() uses instead of an
        # AttributeError on None. Only the detached case is rejected here, so
        # an attached wrapper is still called exactly as before.
        if wrapper is None:
            raise RuntimeError('not connected')
        wrapper.publish(self.id, topic, payload, qos)

    def add_payload_callback(self, callback: callable):
        """Register ``callback`` to be called as ``callback(node, payload)``.

        Callbacks run from :meth:`on_message`, in registration order, for
        every ``MqttPayload`` returned by a module.
        """
        self._payload_callbacks.append(callback)

    @property
    def id(self) -> str:
        """The node identifier given to the constructor."""
        return self._node_id

    @property
    def secret(self) -> Optional[str]:
        """The node secret, or ``None`` if none has been set."""
        return self._node_secret

    @secret.setter
    def secret(self, secret: str) -> None:
        self._node_secret = secret