summaryrefslogtreecommitdiff
path: root/include/py/homekit/mqtt
diff options
context:
space:
mode:
authorEvgeny Zinoviev <me@ch1p.io>2024-02-17 03:08:25 +0300
committerEvgeny Zinoviev <me@ch1p.io>2024-02-17 03:08:25 +0300
commit0ce2e41a2bad790c5232fafb4b6ed631ca8cd957 (patch)
treefd401495b87cae8c95a4c4edf2c851c8177b6069 /include/py/homekit/mqtt
parente9fc2c1835f7ac8e072919df81a6661c6308dea9 (diff)
parentb7f1d55c9b4de4d21b11e5615a5dc8be0d4e883c (diff)
merge with master
Diffstat (limited to 'include/py/homekit/mqtt')
-rw-r--r--include/py/homekit/mqtt/__init__.py7
-rw-r--r--include/py/homekit/mqtt/_config.py183
-rw-r--r--include/py/homekit/mqtt/_module.py70
-rw-r--r--include/py/homekit/mqtt/_mqtt.py87
-rw-r--r--include/py/homekit/mqtt/_node.py92
-rw-r--r--include/py/homekit/mqtt/_payload.py145
-rw-r--r--include/py/homekit/mqtt/_util.py15
-rw-r--r--include/py/homekit/mqtt/_wrapper.py81
-rw-r--r--include/py/homekit/mqtt/module/diagnostics.py64
-rw-r--r--include/py/homekit/mqtt/module/inverter.py195
-rw-r--r--include/py/homekit/mqtt/module/ota.py77
-rw-r--r--include/py/homekit/mqtt/module/relay.py91
-rw-r--r--include/py/homekit/mqtt/module/temphum.py73
13 files changed, 1180 insertions, 0 deletions
diff --git a/include/py/homekit/mqtt/__init__.py b/include/py/homekit/mqtt/__init__.py
new file mode 100644
index 0000000..707d59c
--- /dev/null
+++ b/include/py/homekit/mqtt/__init__.py
@@ -0,0 +1,7 @@
+from ._mqtt import Mqtt
+from ._node import MqttNode
+from ._module import MqttModule
+from ._wrapper import MqttWrapper
+from ._config import MqttConfig, MqttCreds, MqttNodesConfig
+from ._payload import MqttPayload, MqttPayloadCustomField
+from ._util import get_modules as get_mqtt_modules \ No newline at end of file
diff --git a/include/py/homekit/mqtt/_config.py b/include/py/homekit/mqtt/_config.py
new file mode 100644
index 0000000..8aa3bfe
--- /dev/null
+++ b/include/py/homekit/mqtt/_config.py
@@ -0,0 +1,183 @@
+from ..config import ConfigUnit
+from typing import Optional, Union
+from ..util import Addr
+from collections import namedtuple
+
+MqttCreds = namedtuple('MqttCreds', 'username, password')
+
+
+class MqttConfig(ConfigUnit):
+ NAME = 'mqtt'
+
+ @classmethod
+ def schema(cls) -> Optional[dict]:
+ addr_schema = {
+ 'type': 'dict',
+ 'required': True,
+ 'schema': {
+ 'host': {'type': 'string', 'required': True},
+ 'port': {'type': 'integer', 'required': True}
+ }
+ }
+
+ schema = {}
+ for key in ('local', 'remote'):
+ schema[f'{key}_addr'] = addr_schema
+
+ schema['creds'] = {
+ 'type': 'dict',
+ 'required': True,
+ 'keysrules': {'type': 'string'},
+ 'valuesrules': {
+ 'type': 'dict',
+ 'schema': {
+ 'username': {'type': 'string', 'required': True},
+ 'password': {'type': 'string', 'required': True},
+ }
+ }
+ }
+
+ for key in ('client', 'server'):
+ schema[f'default_{key}_creds'] = {'type': 'string', 'required': True}
+
+ return schema
+
+ def remote_addr(self) -> Addr:
+ return Addr(host=self['remote_addr']['host'],
+ port=self['remote_addr']['port'])
+
+ def local_addr(self) -> Addr:
+ return Addr(host=self['local_addr']['host'],
+ port=self['local_addr']['port'])
+
+ def creds_by_name(self, name: str) -> MqttCreds:
+ return MqttCreds(username=self['creds'][name]['username'],
+ password=self['creds'][name]['password'])
+
+ def creds(self) -> MqttCreds:
+ return self.creds_by_name(self['default_client_creds'])
+
+ def server_creds(self) -> MqttCreds:
+ return self.creds_by_name(self['default_server_creds'])
+
+
+class MqttNodesConfig(ConfigUnit):
+ NAME = 'mqtt_nodes'
+
+ @classmethod
+ def schema(cls) -> Optional[dict]:
+ return {
+ 'common': {
+ 'type': 'dict',
+ 'schema': {
+ 'temphum': {
+ 'type': 'dict',
+ 'schema': {
+ 'interval': {'type': 'integer'}
+ }
+ },
+ 'password': {'type': 'string'}
+ }
+ },
+ 'nodes': {
+ 'type': 'dict',
+ 'required': True,
+ 'keysrules': {'type': 'string'},
+ 'valuesrules': {
+ 'type': 'dict',
+ 'schema': {
+ 'type': {'type': 'string', 'required': True, 'allowed': ['esp8266', 'linux', 'none'],},
+ 'board': {'type': 'string', 'allowed': ['nodemcu', 'd1_mini_lite', 'esp12e']},
+ 'temphum': {
+ 'type': 'dict',
+ 'schema': {
+ 'module': {'type': 'string', 'required': True, 'allowed': ['si7021', 'dht12']},
+ 'legacy_payload': {'type': 'boolean', 'required': False, 'default': False},
+ 'interval': {'type': 'integer'},
+ 'i2c_bus': {'type': 'integer'},
+ 'tcpserver': {
+ 'type': 'dict',
+ 'schema': {
+ 'port': {'type': 'integer', 'required': True}
+ }
+ }
+ }
+ },
+ 'relay': {
+ 'type': 'dict',
+ 'schema': {
+ 'device_type': {'type': 'string', 'allowed': ['lamp', 'pump', 'solenoid', 'cooler'], 'required': True},
+ 'legacy_topics': {'type': 'boolean'}
+ }
+ },
+ 'password': {'type': 'string'},
+ 'defines': {
+ 'type': 'dict',
+ 'keysrules': {'type': 'string'},
+ 'valuesrules': {'type': ['string', 'integer']}
+ }
+ }
+ }
+ }
+ }
+
+ @staticmethod
+ def custom_validator(data):
+ for name, node in data['nodes'].items():
+ if 'temphum' in node:
+ if node['type'] == 'linux':
+ if 'i2c_bus' not in node['temphum']:
+ raise KeyError(f'nodes.{name}.temphum: i2c_bus is missing but required for type=linux')
+ if node['type'] in ('esp8266',) and 'board' not in node:
+ raise KeyError(f'nodes.{name}: board is missing but required for type={node["type"]}')
+
+ def get_node(self, name: str) -> dict:
+ node = self['nodes'][name]
+ if node['type'] == 'none':
+ return node
+
+ try:
+ if 'password' not in node:
+ node['password'] = self['common']['password']
+ except KeyError:
+ pass
+
+ try:
+ if 'temphum' in node:
+ for ckey, cval in self['common']['temphum'].items():
+ if ckey not in node['temphum']:
+ node['temphum'][ckey] = cval
+ except KeyError:
+ pass
+
+ return node
+
+ def get_nodes(self,
+ filters: Optional[Union[list[str], tuple[str]]] = None,
+ only_names=False) -> Union[dict, list[str]]:
+ if filters:
+ for f in filters:
+ if f not in ('temphum', 'relay'):
+ raise ValueError(f'{self.__class__.__name__}::get_node(): invalid filter {f}')
+ reslist = []
+ resdict = {}
+ for name in self['nodes'].keys():
+ node = self.get_node(name)
+ if (not filters) or ('temphum' in filters and 'temphum' in node) or ('relay' in filters and 'relay' in node):
+ if only_names:
+ reslist.append(name)
+ else:
+ resdict[name] = node
+ return reslist if only_names else resdict
+
+ def node_uses_legacy_temphum_data_payload(self, node_id: str) -> bool:
+ try:
+ return self.get_node(node_id)['temphum']['legacy_payload']
+ except KeyError:
+ return False
+
+ def node_uses_legacy_relay_power_payload(self, node_id: str) -> bool:
+ try:
+ return self.get_node(node_id)['relay']['legacy_topics']
+ except KeyError:
+ return False
diff --git a/include/py/homekit/mqtt/_module.py b/include/py/homekit/mqtt/_module.py
new file mode 100644
index 0000000..80f27bb
--- /dev/null
+++ b/include/py/homekit/mqtt/_module.py
@@ -0,0 +1,70 @@
+from __future__ import annotations
+
+import abc
+import logging
+import threading
+
+from time import sleep
+from ..util import next_tick_gen
+
+from typing import TYPE_CHECKING, Optional
+if TYPE_CHECKING:
+ from ._node import MqttNode
+ from ._payload import MqttPayload
+
+
+class MqttModule(abc.ABC):
+ _tick_interval: int
+ _initialized: bool
+ _connected: bool
+ _ticker: Optional[threading.Thread]
+ _mqtt_node_ref: Optional[MqttNode]
+
+ def __init__(self, tick_interval=0):
+ self._tick_interval = tick_interval
+ self._initialized = False
+ self._ticker = None
+ self._logger = logging.getLogger(self.__class__.__name__)
+ self._connected = False
+ self._mqtt_node_ref = None
+
+ def on_connect(self, mqtt: MqttNode):
+ self._connected = True
+ self._mqtt_node_ref = mqtt
+ if self._tick_interval:
+ self._start_ticker()
+
+ def on_disconnect(self, mqtt: MqttNode):
+ self._connected = False
+ self._mqtt_node_ref = None
+
+ def is_initialized(self):
+ return self._initialized
+
+ def set_initialized(self):
+ self._initialized = True
+
+ def unset_initialized(self):
+ self._initialized = False
+
+ def tick(self):
+ pass
+
+ def _tick(self):
+ g = next_tick_gen(self._tick_interval)
+ while self._connected:
+ sleep(next(g))
+ if not self._connected:
+ break
+ self.tick()
+
+ def _start_ticker(self):
+ if not self._ticker or not self._ticker.is_alive():
+ name_part = f'{self._mqtt_node_ref.id}/' if self._mqtt_node_ref else ''
+ self._ticker = None
+ self._ticker = threading.Thread(target=self._tick,
+ name=f'mqtt:{self.__class__.__name__}/{name_part}ticker')
+ self._ticker.start()
+
+ def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]:
+ pass
diff --git a/include/py/homekit/mqtt/_mqtt.py b/include/py/homekit/mqtt/_mqtt.py
new file mode 100644
index 0000000..47ee9ae
--- /dev/null
+++ b/include/py/homekit/mqtt/_mqtt.py
@@ -0,0 +1,87 @@
+import os.path
+import paho.mqtt.client as mqtt
+import ssl
+import logging
+
+from ._config import MqttCreds, MqttConfig
+from typing import Optional
+
+
+class Mqtt:
+ _connected: bool
+ _is_server: bool
+ _mqtt_config: MqttConfig
+
+ def __init__(self,
+ clean_session=True,
+ client_id='',
+ creds: Optional[MqttCreds] = None,
+ is_server=False):
+ if not client_id:
+ raise ValueError('client_id must not be empty')
+
+ self._client = mqtt.Client(client_id=client_id,
+ protocol=mqtt.MQTTv311,
+ clean_session=clean_session)
+ self._client.on_connect = self.on_connect
+ self._client.on_disconnect = self.on_disconnect
+ self._client.on_message = self.on_message
+ self._client.on_log = self.on_log
+ self._client.on_publish = self.on_publish
+ self._loop_started = False
+ self._connected = False
+ self._is_server = is_server
+ self._mqtt_config = MqttConfig()
+ self._logger = logging.getLogger(self.__class__.__name__)
+
+ if not creds:
+ creds = self._mqtt_config.creds() if not is_server else self._mqtt_config.server_creds()
+
+ self._client.username_pw_set(creds.username, creds.password)
+
+ def _configure_tls(self):
+ ca_certs = os.path.realpath(os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ '..',
+ '..',
+ '..',
+ '..',
+ 'misc',
+ 'mqtt_ca.crt'
+ ))
+ self._client.tls_set(ca_certs=ca_certs,
+ cert_reqs=ssl.CERT_REQUIRED,
+ tls_version=ssl.PROTOCOL_TLSv1_2)
+
+ def connect_and_loop(self, loop_forever=True):
+ self._configure_tls()
+ addr = self._mqtt_config.local_addr() if self._is_server else self._mqtt_config.remote_addr()
+ self._client.connect(addr.host, addr.port, 60)
+ if loop_forever:
+ self._client.loop_forever()
+ else:
+ self._client.loop_start()
+ self._loop_started = True
+
+ def disconnect(self):
+ self._client.disconnect()
+ self._client.loop_write()
+ self._client.loop_stop()
+
+ def on_connect(self, client: mqtt.Client, userdata, flags, rc):
+ self._logger.info("Connected with result code " + str(rc))
+ self._connected = True
+
+ def on_disconnect(self, client: mqtt.Client, userdata, rc):
+ self._logger.info("Disconnected with result code " + str(rc))
+ self._connected = False
+
+ def on_log(self, client: mqtt.Client, userdata, level, buf):
+ level = mqtt.LOGGING_LEVEL[level] if level in mqtt.LOGGING_LEVEL else logging.INFO
+ self._logger.log(level, f'MQTT: {buf}')
+
+ def on_message(self, client: mqtt.Client, userdata, msg):
+ self._logger.debug(msg.topic + ": " + str(msg.payload))
+
+ def on_publish(self, client: mqtt.Client, userdata, mid):
+ self._logger.debug(f'publish done, mid={mid}')
diff --git a/include/py/homekit/mqtt/_node.py b/include/py/homekit/mqtt/_node.py
new file mode 100644
index 0000000..4e259a4
--- /dev/null
+++ b/include/py/homekit/mqtt/_node.py
@@ -0,0 +1,92 @@
+import logging
+import importlib
+
+from typing import List, TYPE_CHECKING, Optional
+from ._payload import MqttPayload
+from ._module import MqttModule
+if TYPE_CHECKING:
+ from ._wrapper import MqttWrapper
+else:
+ MqttWrapper = None
+
+
+class MqttNode:
+ _modules: List[MqttModule]
+ _module_subscriptions: dict[str, MqttModule]
+ _node_id: str
+ _node_secret: str
+ _payload_callbacks: list[callable]
+ _wrapper: Optional[MqttWrapper]
+
+ def __init__(self,
+ node_id: str,
+ node_secret: Optional[str] = None):
+ self._modules = []
+ self._module_subscriptions = {}
+ self._node_id = node_id
+ self._node_secret = node_secret
+ self._payload_callbacks = []
+ self._logger = logging.getLogger(self.__class__.__name__)
+ self._wrapper = None
+
+ def on_connect(self, wrapper: MqttWrapper):
+ self._wrapper = wrapper
+ for module in self._modules:
+ if not module.is_initialized():
+ module.on_connect(self)
+ module.set_initialized()
+
+ def on_disconnect(self):
+ self._wrapper = None
+ for module in self._modules:
+ module.unset_initialized()
+
+ def on_message(self, topic, payload):
+ if topic in self._module_subscriptions:
+ payload = self._module_subscriptions[topic].handle_payload(self, topic, payload)
+ if isinstance(payload, MqttPayload):
+ for f in self._payload_callbacks:
+ f(self, payload)
+
+ def load_module(self, module_name: str, *args, **kwargs) -> MqttModule:
+ module = importlib.import_module(f'..module.{module_name}', __name__)
+ if not hasattr(module, 'MODULE_NAME'):
+ raise RuntimeError(f'MODULE_NAME not found in module {module}')
+ cl = getattr(module, getattr(module, 'MODULE_NAME'))
+ instance = cl(*args, **kwargs)
+ self.add_module(instance)
+ return instance
+
+ def add_module(self, module: MqttModule):
+ self._modules.append(module)
+ if self._wrapper and self._wrapper._connected:
+ module.on_connect(self)
+ module.set_initialized()
+
+ def subscribe_module(self, topic: str, module: MqttModule, qos: int = 1):
+ if not self._wrapper or not self._wrapper._connected:
+ raise RuntimeError('not connected')
+
+ self._module_subscriptions[topic] = module
+ self._wrapper.subscribe(self.id, topic, qos)
+
+ def publish(self,
+ topic: str,
+ payload: bytes,
+ qos: int = 1):
+ self._wrapper.publish(self.id, topic, payload, qos)
+
+ def add_payload_callback(self, callback: callable):
+ self._payload_callbacks.append(callback)
+
+ @property
+ def id(self) -> str:
+ return self._node_id
+
+ @property
+ def secret(self) -> str:
+ return self._node_secret
+
+ @secret.setter
+ def secret(self, secret: str) -> None:
+ self._node_secret = secret
diff --git a/include/py/homekit/mqtt/_payload.py b/include/py/homekit/mqtt/_payload.py
new file mode 100644
index 0000000..58eeae3
--- /dev/null
+++ b/include/py/homekit/mqtt/_payload.py
@@ -0,0 +1,145 @@
+import struct
+import abc
+import re
+
+from typing import Optional, Tuple
+
+
+def pldstr(self) -> str:
+ attrs = []
+ for field in self.__class__.__annotations__:
+ if hasattr(self, field):
+ attr = getattr(self, field)
+ attrs.append(f'{field}={attr}')
+ if attrs:
+ attrs_s = ' '
+ attrs_s += ', '.join(attrs)
+ else:
+ attrs_s = ''
+ return f'<%s{attrs_s}>' % (self.__class__.__name__,)
+
+
+class MqttPayload(abc.ABC):
+ FORMAT = ''
+ PACKER = {}
+ UNPACKER = {}
+
+ def __init__(self, **kwargs):
+ for field in self.__class__.__annotations__:
+ setattr(self, field, kwargs[field])
+
+ def pack(self):
+ args = []
+ bf_number = -1
+ bf_arg = 0
+ bf_progress = 0
+
+ for field, field_type in self.__class__.__annotations__.items():
+ bfp = _bit_field_params(field_type)
+ if bfp:
+ n, s, b = bfp
+ if n != bf_number:
+ if bf_number != -1:
+ args.append(bf_arg)
+ bf_number = n
+ bf_progress = 0
+ bf_arg = 0
+ bf_arg |= (getattr(self, field) & (2 ** b - 1)) << bf_progress
+ bf_progress += b
+
+ else:
+ if bf_number != -1:
+ args.append(bf_arg)
+ bf_number = -1
+ bf_progress = 0
+ bf_arg = 0
+
+ args.append(self._pack_field(field))
+
+ if bf_number != -1:
+ args.append(bf_arg)
+
+ return struct.pack(self.FORMAT, *args)
+
+ @classmethod
+ def unpack(cls, buf: bytes):
+ data = struct.unpack(cls.FORMAT, buf)
+ kwargs = {}
+ i = 0
+ bf_number = -1
+ bf_progress = 0
+
+ for field, field_type in cls.__annotations__.items():
+ bfp = _bit_field_params(field_type)
+ if bfp:
+ n, s, b = bfp
+ if n != bf_number:
+ bf_number = n
+ bf_progress = 0
+ kwargs[field] = (data[i] >> bf_progress) & (2 ** b - 1)
+ bf_progress += b
+ continue # don't increment i
+
+ if bf_number != -1:
+ bf_number = -1
+ i += 1
+
+ if issubclass(field_type, MqttPayloadCustomField):
+ kwargs[field] = field_type.unpack(data[i])
+ else:
+ kwargs[field] = cls._unpack_field(field, data[i])
+ i += 1
+
+ return cls(**kwargs)
+
+ def _pack_field(self, name):
+ val = getattr(self, name)
+ if self.PACKER and name in self.PACKER:
+ return self.PACKER[name](val)
+ else:
+ return val
+
+ @classmethod
+ def _unpack_field(cls, name, val):
+ if isinstance(val, MqttPayloadCustomField):
+ return
+ if cls.UNPACKER and name in cls.UNPACKER:
+ return cls.UNPACKER[name](val)
+ else:
+ return val
+
+ def __str__(self):
+ return pldstr(self)
+
+
+class MqttPayloadCustomField(abc.ABC):
+ def __init__(self, **kwargs):
+ for field in self.__class__.__annotations__:
+ setattr(self, field, kwargs[field])
+
+ @abc.abstractmethod
+ def __index__(self):
+ pass
+
+ @classmethod
+ @abc.abstractmethod
+ def unpack(cls, *args, **kwargs):
+ pass
+
+ def __str__(self):
+ return pldstr(self)
+
+
+def bit_field(seq_no: int, total_bits: int, bits: int):
+ return type(f'MQTTPayloadBitField_{seq_no}_{total_bits}_{bits}', (object,), {
+ 'seq_no': seq_no,
+ 'total_bits': total_bits,
+ 'bits': bits
+ })
+
+
+def _bit_field_params(cl) -> Optional[Tuple[int, ...]]:
+ match = re.match(r'MQTTPayloadBitField_(\d+)_(\d+)_(\d)$', cl.__name__)
+ if match is not None:
+ return tuple([int(match.group(i)) for i in range(1, 4)])
+ return None \ No newline at end of file
diff --git a/include/py/homekit/mqtt/_util.py b/include/py/homekit/mqtt/_util.py
new file mode 100644
index 0000000..390d463
--- /dev/null
+++ b/include/py/homekit/mqtt/_util.py
@@ -0,0 +1,15 @@
+import os
+import re
+
+from typing import List
+
+
+def get_modules() -> List[str]:
+ modules = []
+ modules_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'module')
+ for name in os.listdir(modules_dir):
+ if os.path.isdir(os.path.join(modules_dir, name)):
+ continue
+ name = re.sub(r'\.py$', '', name)
+ modules.append(name)
+ return modules
diff --git a/include/py/homekit/mqtt/_wrapper.py b/include/py/homekit/mqtt/_wrapper.py
new file mode 100644
index 0000000..5fc33fe
--- /dev/null
+++ b/include/py/homekit/mqtt/_wrapper.py
@@ -0,0 +1,81 @@
+import paho.mqtt.client as mqtt
+
+from ._mqtt import Mqtt
+from ._node import MqttNode
+from ..util import strgen
+
+
+class MqttWrapper(Mqtt):
+ _nodes: list[MqttNode]
+ _connect_callbacks: list[callable]
+ _disconnect_callbacks: list[callable]
+
+ def __init__(self,
+ client_id: str,
+ topic_prefix='hk',
+ randomize_client_id=False,
+ clean_session=True):
+ if randomize_client_id:
+ client_id += '_'+strgen(6)
+ super().__init__(clean_session=clean_session,
+ client_id=client_id)
+ self._nodes = []
+ self._connect_callbacks = []
+ self._disconnect_callbacks = []
+ self._topic_prefix = topic_prefix
+
+ def on_connect(self, client: mqtt.Client, userdata, flags, rc):
+ super().on_connect(client, userdata, flags, rc)
+ for node in self._nodes:
+ node.on_connect(self)
+ for f in self._connect_callbacks:
+ try:
+ f()
+ except Exception as e:
+ self._logger.exception(e)
+
+ def on_disconnect(self, client: mqtt.Client, userdata, rc):
+ super().on_disconnect(client, userdata, rc)
+ for node in self._nodes:
+ node.on_disconnect()
+ for f in self._disconnect_callbacks:
+ try:
+ f()
+ except Exception as e:
+ self._logger.exception(e)
+
+
+ def on_message(self, client: mqtt.Client, userdata, msg):
+ try:
+ topic = msg.topic
+ topic_node = topic[len(self._topic_prefix)+1:topic.find('/', len(self._topic_prefix)+1)]
+ for node in self._nodes:
+ if node.id in ('+', topic_node):
+ node.on_message(topic[len(f'{self._topic_prefix}/{node.id}/'):], msg.payload)
+ except Exception as e:
+ self._logger.exception(str(e))
+
+ def add_connect_callback(self, f: callable):
+ self._connect_callbacks.append(f)
+
+ def add_disconnect_callback(self, f: callable):
+ self._disconnect_callbacks.append(f)
+
+ def add_node(self, node: MqttNode):
+ self._nodes.append(node)
+ if self._connected:
+ node.on_connect(self)
+
+ def subscribe(self,
+ node_id: str,
+ topic: str,
+ qos: int):
+ self._client.subscribe(f'{self._topic_prefix}/{node_id}/{topic}', qos)
+
+ def publish(self,
+ node_id: str,
+ topic: str,
+ payload: bytes,
+ qos: int):
+ self._client.publish(f'{self._topic_prefix}/{node_id}/{topic}', payload, qos)
+ self._client.loop_write()
diff --git a/include/py/homekit/mqtt/module/diagnostics.py b/include/py/homekit/mqtt/module/diagnostics.py
new file mode 100644
index 0000000..5db5e99
--- /dev/null
+++ b/include/py/homekit/mqtt/module/diagnostics.py
@@ -0,0 +1,64 @@
+from .._payload import MqttPayload, MqttPayloadCustomField
+from .._node import MqttNode, MqttModule
+from typing import Optional
+
+MODULE_NAME = 'MqttDiagnosticsModule'
+
+
+class DiagnosticsFlags(MqttPayloadCustomField):
+ state: bool
+ config_changed_value_present: bool
+ config_changed: bool
+
+ @staticmethod
+ def unpack(flags: int):
+ # _logger.debug(f'StatFlags.unpack: flags={flags}')
+ state = flags & 0x1
+ ccvp = (flags >> 1) & 0x1
+ cc = (flags >> 2) & 0x1
+ # _logger.debug(f'StatFlags.unpack: state={state}')
+ return DiagnosticsFlags(state=(state == 1),
+ config_changed_value_present=(ccvp == 1),
+ config_changed=(cc == 1))
+
+ def __index__(self):
+ bits = 0
+ bits |= (int(self.state) & 0x1)
+ bits |= (int(self.config_changed_value_present) & 0x1) << 1
+ bits |= (int(self.config_changed) & 0x1) << 2
+ return bits
+
+
+class InitialDiagnosticsPayload(MqttPayload):
+ FORMAT = '=IBbIB'
+
+ ip: int
+ fw_version: int
+ rssi: int
+ free_heap: int
+ flags: DiagnosticsFlags
+
+
+class DiagnosticsPayload(MqttPayload):
+ FORMAT = '=bIB'
+
+ rssi: int
+ free_heap: int
+ flags: DiagnosticsFlags
+
+
+class MqttDiagnosticsModule(MqttModule):
+ def on_connect(self, mqtt: MqttNode):
+ super().on_connect(mqtt)
+ for topic in ('diag', 'd1ag', 'stat', 'stat1'):
+ mqtt.subscribe_module(topic, self)
+
+ def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]:
+ message = None
+ if topic in ('stat', 'diag'):
+ message = DiagnosticsPayload.unpack(payload)
+ elif topic in ('stat1', 'd1ag'):
+ message = InitialDiagnosticsPayload.unpack(payload)
+ if message:
+ self._logger.debug(message)
+ return message
diff --git a/include/py/homekit/mqtt/module/inverter.py b/include/py/homekit/mqtt/module/inverter.py
new file mode 100644
index 0000000..29bde0a
--- /dev/null
+++ b/include/py/homekit/mqtt/module/inverter.py
@@ -0,0 +1,195 @@
+import time
+import json
+import datetime
+try:
+ import inverterd
+except:
+ pass
+
+from typing import Optional
+from .._module import MqttModule
+from .._node import MqttNode
+from .._payload import MqttPayload, bit_field
+try:
+ from homekit.database import InverterDatabase
+except:
+ pass
+
+_mult_10 = lambda n: int(n*10)
+_div_10 = lambda n: n/10
+
+
+MODULE_NAME = 'MqttInverterModule'
+
+STATUS_TOPIC = 'status'
+GENERATION_TOPIC = 'generation'
+
+
+class MqttInverterStatusPayload(MqttPayload):
+ # 46 bytes
+ FORMAT = 'IHHHHHHBHHHHHBHHHHHHHH'
+
+ PACKER = {
+ 'grid_voltage': _mult_10,
+ 'grid_freq': _mult_10,
+ 'ac_output_voltage': _mult_10,
+ 'ac_output_freq': _mult_10,
+ 'battery_voltage': _mult_10,
+ 'battery_voltage_scc': _mult_10,
+ 'battery_voltage_scc2': _mult_10,
+ 'pv1_input_voltage': _mult_10,
+ 'pv2_input_voltage': _mult_10
+ }
+ UNPACKER = {
+ 'grid_voltage': _div_10,
+ 'grid_freq': _div_10,
+ 'ac_output_voltage': _div_10,
+ 'ac_output_freq': _div_10,
+ 'battery_voltage': _div_10,
+ 'battery_voltage_scc': _div_10,
+ 'battery_voltage_scc2': _div_10,
+ 'pv1_input_voltage': _div_10,
+ 'pv2_input_voltage': _div_10
+ }
+
+ time: int
+ grid_voltage: float
+ grid_freq: float
+ ac_output_voltage: float
+ ac_output_freq: float
+ ac_output_apparent_power: int
+ ac_output_active_power: int
+ output_load_percent: int
+ battery_voltage: float
+ battery_voltage_scc: float
+ battery_voltage_scc2: float
+ battery_discharge_current: int
+ battery_charge_current: int
+ battery_capacity: int
+ inverter_heat_sink_temp: int
+ mppt1_charger_temp: int
+ mppt2_charger_temp: int
+ pv1_input_power: int
+ pv2_input_power: int
+ pv1_input_voltage: float
+ pv2_input_voltage: float
+
+ # H
+ mppt1_charger_status: bit_field(0, 16, 2)
+ mppt2_charger_status: bit_field(0, 16, 2)
+ battery_power_direction: bit_field(0, 16, 2)
+ dc_ac_power_direction: bit_field(0, 16, 2)
+ line_power_direction: bit_field(0, 16, 2)
+ load_connected: bit_field(0, 16, 1)
+
+
+class MqttInverterGenerationPayload(MqttPayload):
+ # 8 bytes
+ FORMAT = 'II'
+
+ time: int
+ wh: int
+
+
+class MqttInverterModule(MqttModule):
+ _status_poll_freq: int
+ _generation_poll_freq: int
+ _inverter: Optional[inverterd.Client]
+ _database: Optional[InverterDatabase]
+ _gen_prev: float
+
+ def __init__(self, status_poll_freq=0, generation_poll_freq=0):
+ super().__init__(tick_interval=status_poll_freq)
+ self._status_poll_freq = status_poll_freq
+ self._generation_poll_freq = generation_poll_freq
+
+ # this defines whether this is a publisher or a subscriber
+ if status_poll_freq > 0:
+ self._inverter = inverterd.Client()
+ self._inverter.connect()
+ self._inverter.format(inverterd.Format.SIMPLE_JSON)
+ self._database = None
+ else:
+ self._inverter = None
+ self._database = InverterDatabase()
+
+ self._gen_prev = 0
+
+ def on_connect(self, mqtt: MqttNode):
+ super().on_connect(mqtt)
+ if not self._inverter:
+ mqtt.subscribe_module(STATUS_TOPIC, self)
+ mqtt.subscribe_module(GENERATION_TOPIC, self)
+
+ def tick(self):
+ if not self._inverter:
+ return
+
+ # read status
+ now = time.time()
+ try:
+ raw = self._inverter.exec('get-status')
+ except inverterd.InverterError as e:
+ self._logger.error(f'inverter error: {str(e)}')
+ # TODO send to server
+ return
+
+ data = json.loads(raw)['data']
+ status = MqttInverterStatusPayload(time=round(now), **data)
+ self._mqtt_node_ref.publish(STATUS_TOPIC, status.pack())
+
+ # read today's generation stat
+ now = time.time()
+ if self._gen_prev == 0 or now - self._gen_prev >= self._generation_poll_freq:
+ self._gen_prev = now
+ today = datetime.date.today()
+ try:
+ raw = self._inverter.exec('get-day-generated', (today.year, today.month, today.day))
+ except inverterd.InverterError as e:
+ self._logger.error(f'inverter error: {str(e)}')
+ # TODO send to server
+ return
+
+ data = json.loads(raw)['data']
+ gen = MqttInverterGenerationPayload(time=round(now), wh=data['wh'])
+ self._mqtt_node_ref.publish(GENERATION_TOPIC, gen.pack())
+
+ def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]:
+ home_id = 1 # legacy compat
+
+ if topic == STATUS_TOPIC:
+ s = MqttInverterStatusPayload.unpack(payload)
+ self._database.add_status(home_id=home_id,
+ client_time=s.time,
+ grid_voltage=int(s.grid_voltage*10),
+ grid_freq=int(s.grid_freq * 10),
+ ac_output_voltage=int(s.ac_output_voltage * 10),
+ ac_output_freq=int(s.ac_output_freq * 10),
+ ac_output_apparent_power=s.ac_output_apparent_power,
+ ac_output_active_power=s.ac_output_active_power,
+ output_load_percent=s.output_load_percent,
+ battery_voltage=int(s.battery_voltage * 10),
+ battery_voltage_scc=int(s.battery_voltage_scc * 10),
+ battery_voltage_scc2=int(s.battery_voltage_scc2 * 10),
+ battery_discharge_current=s.battery_discharge_current,
+ battery_charge_current=s.battery_charge_current,
+ battery_capacity=s.battery_capacity,
+ inverter_heat_sink_temp=s.inverter_heat_sink_temp,
+ mppt1_charger_temp=s.mppt1_charger_temp,
+ mppt2_charger_temp=s.mppt2_charger_temp,
+ pv1_input_power=s.pv1_input_power,
+ pv2_input_power=s.pv2_input_power,
+ pv1_input_voltage=int(s.pv1_input_voltage * 10),
+ pv2_input_voltage=int(s.pv2_input_voltage * 10),
+ mppt1_charger_status=s.mppt1_charger_status,
+ mppt2_charger_status=s.mppt2_charger_status,
+ battery_power_direction=s.battery_power_direction,
+ dc_ac_power_direction=s.dc_ac_power_direction,
+ line_power_direction=s.line_power_direction,
+ load_connected=s.load_connected)
+ return s
+
+ elif topic == GENERATION_TOPIC:
+ gen = MqttInverterGenerationPayload.unpack(payload)
+ self._database.add_generation(home_id, gen.time, gen.wh)
+ return gen
diff --git a/include/py/homekit/mqtt/module/ota.py b/include/py/homekit/mqtt/module/ota.py
new file mode 100644
index 0000000..2f9b216
--- /dev/null
+++ b/include/py/homekit/mqtt/module/ota.py
@@ -0,0 +1,77 @@
+import hashlib
+
+from typing import Optional
+from .._payload import MqttPayload
+from .._node import MqttModule, MqttNode
+
+MODULE_NAME = 'MqttOtaModule'
+
+
+class OtaResultPayload(MqttPayload):
+ FORMAT = '=BB'
+ result: int
+ error_code: int
+
+
+class OtaPayload(MqttPayload):
+ secret: str
+ filename: str
+
+ # structure of returned data:
+ #
+ # uint8_t[len(secret)] secret;
+ # uint8_t[16] md5;
+ # *uint8_t data
+
+ def pack(self):
+ buf = bytearray(self.secret.encode())
+ m = hashlib.md5()
+ with open(self.filename, 'rb') as fd:
+ content = fd.read()
+ m.update(content)
+ buf.extend(m.digest())
+ buf.extend(content)
+ return buf
+
+ def unpack(cls, buf: bytes):
+ raise RuntimeError(f'{cls.__class__.__name__}.unpack: not implemented')
+ # secret = buf[:12].decode()
+ # filename = buf[12:].decode()
+ # return OTAPayload(secret=secret, filename=filename)
+
+
+class MqttOtaModule(MqttModule):
+ _ota_request: Optional[tuple[str, int]]
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._ota_request = None
+
+ def on_connect(self, mqtt: MqttNode):
+ super().on_connect(mqtt)
+ mqtt.subscribe_module("otares", self)
+
+ if self._ota_request is not None:
+ filename, qos = self._ota_request
+ self._ota_request = None
+ self.do_push_ota(self._mqtt_node_ref.secret, filename, qos)
+
+ def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]:
+ if topic == 'otares':
+ message = OtaResultPayload.unpack(payload)
+ self._logger.debug(message)
+ return message
+
+ def do_push_ota(self, secret: str, filename: str, qos: int):
+ payload = OtaPayload(secret=secret, filename=filename)
+ self._mqtt_node_ref.publish('ota',
+ payload=payload.pack(),
+ qos=qos)
+
+ def push_ota(self,
+ filename: str,
+ qos: int):
+ if not self._initialized:
+ self._ota_request = (filename, qos)
+ else:
+ self.do_push_ota(self._mqtt_node_ref.secret, filename, qos)
diff --git a/include/py/homekit/mqtt/module/relay.py b/include/py/homekit/mqtt/module/relay.py
new file mode 100644
index 0000000..5cbe09b
--- /dev/null
+++ b/include/py/homekit/mqtt/module/relay.py
@@ -0,0 +1,91 @@
+import datetime
+
+from typing import Optional
+from .. import MqttModule, MqttPayload, MqttNode
+
+MODULE_NAME = 'MqttRelayModule'
+
+
+class MqttPowerSwitchPayload(MqttPayload):
+ FORMAT = '=12sB'
+ PACKER = {
+ 'state': lambda n: int(n),
+ 'secret': lambda s: s.encode('utf-8')
+ }
+ UNPACKER = {
+ 'state': lambda n: bool(n),
+ 'secret': lambda s: s.decode('utf-8')
+ }
+
+ secret: str
+ state: bool
+
+
+class MqttPowerStatusPayload(MqttPayload):
+ FORMAT = '=B'
+ PACKER = {
+ 'opened': lambda n: int(n),
+ }
+ UNPACKER = {
+ 'opened': lambda n: bool(n),
+ }
+
+ opened: bool
+
+
+class MqttRelayState:
+ enabled: bool
+ update_time: datetime.datetime
+ rssi: int
+ fw_version: int
+ ever_updated: bool
+
+ def __init__(self):
+ self.ever_updated = False
+ self.enabled = False
+ self.rssi = 0
+
+ def update(self,
+ enabled: bool,
+ rssi: int,
+ fw_version=None):
+ self.ever_updated = True
+ self.enabled = enabled
+ self.rssi = rssi
+ self.update_time = datetime.datetime.now()
+ if fw_version:
+ self.fw_version = fw_version
+
+
+class MqttRelayModule(MqttModule):
+ _legacy_topics: bool
+
+ def __init__(self, legacy_topics=False, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._legacy_topics = legacy_topics
+
+ def on_connect(self, mqtt: MqttNode):
+ super().on_connect(mqtt)
+ mqtt.subscribe_module(self._get_switch_topic(), self)
+ mqtt.subscribe_module('relay/status', self)
+
+ def switchpower(self, enable: bool):
+ payload = MqttPowerSwitchPayload(secret=self._mqtt_node_ref.secret,
+ state=enable)
+ self._mqtt_node_ref.publish(self._get_switch_topic(),
+ payload=payload.pack())
+
+ def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]:
+ message = None
+
+ if topic == self._get_switch_topic():
+ message = MqttPowerSwitchPayload.unpack(payload)
+ elif topic == 'relay/status':
+ message = MqttPowerStatusPayload.unpack(payload)
+
+ if message is not None:
+ self._logger.debug(message)
+ return message
+
+ def _get_switch_topic(self) -> str:
+ return 'relay/power' if self._legacy_topics else 'relay/switch'
diff --git a/include/py/homekit/mqtt/module/temphum.py b/include/py/homekit/mqtt/module/temphum.py
new file mode 100644
index 0000000..6deccfe
--- /dev/null
+++ b/include/py/homekit/mqtt/module/temphum.py
@@ -0,0 +1,73 @@
+from .._node import MqttNode
+from .._module import MqttModule
+from .._payload import MqttPayload
+from typing import Optional
+from ...temphum import BaseSensor
+
+two_digits_precision = lambda x: round(x, 2)
+
+MODULE_NAME = 'MqttTempHumModule'
+DATA_TOPIC = 'temphum/data'
+
+
+class MqttTemphumLegacyDataPayload(MqttPayload):
+ FORMAT = '=dd'
+ UNPACKER = {
+ 'temp': two_digits_precision,
+ 'rh': two_digits_precision
+ }
+
+ temp: float
+ rh: float
+
+
+class MqttTemphumDataPayload(MqttTemphumLegacyDataPayload):
+ FORMAT = '=ddb'
+ error: int
+
+
+class MqttTempHumModule(MqttModule):
+ _legacy_payload: bool
+
+ def __init__(self,
+ sensor: Optional[BaseSensor] = None,
+ legacy_payload=False,
+ write_to_database=False,
+ *args, **kwargs):
+ if sensor is not None:
+ kwargs['tick_interval'] = 10
+ super().__init__(*args, **kwargs)
+ self._sensor = sensor
+ self._legacy_payload = legacy_payload
+
+ def on_connect(self, mqtt: MqttNode):
+ super().on_connect(mqtt)
+ mqtt.subscribe_module(DATA_TOPIC, self)
+
+ def tick(self):
+ if not self._sensor:
+ return
+
+ error = 0
+ temp = 0
+ rh = 0
+ try:
+ temp = self._sensor.temperature()
+ rh = self._sensor.humidity()
+ except:
+ error = 1
+ pld = self._get_data_payload_cls()(temp=temp, rh=rh, error=error)
+ self._mqtt_node_ref.publish(DATA_TOPIC, pld.pack())
+
+ def handle_payload(self,
+ mqtt: MqttNode,
+ topic: str,
+ payload: bytes) -> Optional[MqttPayload]:
+ if topic == DATA_TOPIC:
+ message = self._get_data_payload_cls().unpack(payload)
+ self._logger.debug(message)
+ return message
+
+ def _get_data_payload_cls(self):
+ return MqttTemphumLegacyDataPayload if self._legacy_payload else MqttTemphumDataPayload
+