summaryrefslogtreecommitdiff
path: root/include/py/homekit/mqtt/_mqtt.py
blob: 47ee9ae36c89dd0034c75e4b5bd8957bf5ed44c7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os.path
import paho.mqtt.client as mqtt
import ssl
import logging

from ._config import MqttCreds, MqttConfig
from typing import Optional


class Mqtt:
    """Base MQTT client built on paho-mqtt, speaking MQTT v3.1.1.

    The client authenticates with a username/password, connects over TLS
    (see ``_configure_tls``), and logs connection, message, publish and
    paho-internal log events. The ``on_*`` methods are registered as paho
    callbacks in ``__init__``; subclasses can override them to add behaviour.

    Connection parameters and default credentials come from ``MqttConfig``.
    With ``is_server=True`` it uses ``server_creds()`` and ``local_addr()``.
    Otherwise it uses ``creds()`` and ``remote_addr()``.
    """

    _connected: bool      # True after a successful CONNACK (rc == 0), False after a disconnect
    _is_server: bool      # selects server creds / local address instead of client creds / remote address
    _mqtt_config: MqttConfig
    _loop_started: bool   # True while a loop_start() background thread started by connect_and_loop is running
    _tls_configured: bool  # paho's tls_set() may only be called once per client

    def __init__(self,
                 clean_session=True,
                 client_id='',
                 creds: Optional[MqttCreds] = None,
                 is_server=False):
        """Create the underlying paho client and register the callbacks.

        Args:
            clean_session: passed through to ``mqtt.Client``.
            client_id: MQTT client identifier; must be non-empty.
            creds: credentials exposing ``username`` and ``password``. When
                omitted (or falsy) they are read from ``MqttConfig``:
                ``server_creds()`` if ``is_server`` else ``creds()``.
            is_server: use server credentials and the local address instead
                of client credentials and the remote address.

        Raises:
            ValueError: if ``client_id`` is empty.
        """
        if not client_id:
            raise ValueError('client_id must not be empty')

        self._client = mqtt.Client(client_id=client_id,
                                   protocol=mqtt.MQTTv311,
                                   clean_session=clean_session)
        self._client.on_connect = self.on_connect
        self._client.on_disconnect = self.on_disconnect
        self._client.on_message = self.on_message
        self._client.on_log = self.on_log
        self._client.on_publish = self.on_publish
        self._loop_started = False
        self._tls_configured = False
        self._connected = False
        self._is_server = is_server
        self._mqtt_config = MqttConfig()
        self._logger = logging.getLogger(self.__class__.__name__)

        if not creds:
            creds = self._mqtt_config.server_creds() if is_server else self._mqtt_config.creds()

        self._client.username_pw_set(creds.username, creds.password)

    def _configure_tls(self):
        """Enable TLS 1.2 with mandatory server-certificate verification.

        The CA certificate is ``misc/mqtt_ca.crt``, resolved four directory
        levels above the directory containing this module.
        """
        ca_certs = os.path.realpath(os.path.join(
            os.path.dirname(os.path.realpath(__file__)),
            '..',
            '..',
            '..',
            '..',
            'misc',
            'mqtt_ca.crt'
        ))
        self._client.tls_set(ca_certs=ca_certs,
                             cert_reqs=ssl.CERT_REQUIRED,
                             tls_version=ssl.PROTOCOL_TLSv1_2)

    def connect_and_loop(self, loop_forever=True, keepalive=60):
        """Connect to the broker and run the paho network loop.

        Args:
            loop_forever: if True, run ``loop_forever()`` in the calling
                thread (blocking). Otherwise start a background network
                thread with ``loop_start()`` and return immediately.
            keepalive: MQTT keepalive interval in seconds passed to
                ``connect()``.
        """
        # paho raises ValueError if tls_set() is called twice on one client,
        # so configure TLS only on the first connect (allows reconnecting
        # the same instance after disconnect()).
        if not self._tls_configured:
            self._configure_tls()
            self._tls_configured = True

        addr = self._mqtt_config.local_addr() if self._is_server else self._mqtt_config.remote_addr()
        self._client.connect(addr.host, addr.port, keepalive)
        if loop_forever:
            self._client.loop_forever()
        else:
            self._client.loop_start()
            self._loop_started = True

    def disconnect(self):
        """Send DISCONNECT, flush pending output and stop the network thread."""
        self._client.disconnect()
        # Push the queued DISCONNECT packet out now rather than waiting for
        # the network loop to do it.
        self._client.loop_write()
        # Unconditional: loop_stop() is a no-op when no background thread is
        # running, and a subclass might have started one itself.
        self._client.loop_stop()
        self._loop_started = False

    def on_connect(self, client: mqtt.Client, userdata, flags, rc):
        """paho CONNACK callback; ``rc == 0`` means the broker accepted us."""
        if rc == 0:
            self._logger.info('Connected with result code %s', rc)
        else:
            # A non-zero result code means the connection was refused
            # (e.g. bad credentials), so the client is NOT connected.
            self._logger.warning('Connection refused with result code %s', rc)
        self._connected = rc == 0

    def on_disconnect(self, client: mqtt.Client, userdata, rc):
        """paho disconnect callback; marks the client as disconnected."""
        self._logger.info('Disconnected with result code %s', rc)
        self._connected = False

    def on_log(self, client: mqtt.Client, userdata, level, buf):
        """Forward paho's internal log lines to this instance's logger.

        paho ``MQTT_LOG_*`` levels are mapped via ``mqtt.LOGGING_LEVEL``;
        unknown levels fall back to ``logging.INFO``.
        """
        self._logger.log(mqtt.LOGGING_LEVEL.get(level, logging.INFO), 'MQTT: %s', buf)

    def on_message(self, client: mqtt.Client, userdata, msg):
        """Default message handler: log the topic and raw payload at DEBUG."""
        self._logger.debug('%s: %s', msg.topic, msg.payload)

    def on_publish(self, client: mqtt.Client, userdata, mid):
        """Log completion of a publish identified by its message id."""
        self._logger.debug('publish done, mid=%s', mid)