diff options
author | Evgeny Zinoviev <me@ch1p.io> | 2023-05-31 09:22:00 +0300 |
---|---|---|
committer | Evgeny Zinoviev <me@ch1p.io> | 2023-05-31 09:22:00 +0300 |
commit | c976495222858c4921454c9294ff73794ae56277 (patch) | |
tree | ea605604f2e8fb2108e01074a1bfb9de93f1e93f | |
parent | b02a9c5473267da88a9182a5b06753f62b689042 (diff) |
wip
24 files changed, 443 insertions, 394 deletions
diff --git a/platformio/common/libs/mqtt/homekit/mqtt/mqtt.cpp b/platformio/common/libs/mqtt/homekit/mqtt/mqtt.cpp index cb2cea7..16f4675 100644 --- a/platformio/common/libs/mqtt/homekit/mqtt/mqtt.cpp +++ b/platformio/common/libs/mqtt/homekit/mqtt/mqtt.cpp @@ -55,13 +55,9 @@ Mqtt::Mqtt() { } } -// if (ota.readyToRestart) { -// restartTimer.once(1, restart); -// } else { - reconnectTimer.once(2, [&]() { - reconnect(); - }); -// } + reconnectTimer.once(2, [&]() { + reconnect(); + }); }); client.onSubscribe([&](uint16_t packetId, const SubscribeReturncode* returncodes, size_t len) { @@ -79,7 +75,7 @@ Mqtt::Mqtt() { PRINTF("mqtt: message received, topic=%s, qos=%d, dup=%d, retain=%d, len=%ul, index=%ul, total=%ul\n", topic, properties.qos, (int)properties.dup, (int)properties.retain, len, index, total); - const char *ptr = topic + nodeId.length() + 10; + const char *ptr = topic + nodeId.length() + 4; String relevantTopic(ptr); auto it = moduleSubscriptions.find(relevantTopic); @@ -87,7 +83,7 @@ Mqtt::Mqtt() { auto module = it->second; module->handlePayload(*this, relevantTopic, properties.packetId, payload, len, index, total); } else { - PRINTF("error: module subscription for topic %s not found\n", topic); + PRINTF("error: module subscription for topic %s not found\n", relevantTopic.c_str()); } }); diff --git a/src/esp_mqtt_util.py b/src/esp_mqtt_util.py deleted file mode 100755 index 263128c..0000000 --- a/src/esp_mqtt_util.py +++ /dev/null @@ -1,42 +0,0 @@ -#!/usr/bin/env python3 -from typing import Optional -from argparse import ArgumentParser -from enum import Enum - -from home.config import config -from home.mqtt import MqttRelay -from home.mqtt.esp import MqttEspBase -from home.mqtt.temphum import MqttTempHum -from home.mqtt.esp import MqttEspDevice - -mqtt_client: Optional[MqttEspBase] = None - - -class NodeType(Enum): - RELAY = 'relay' - TEMPHUM = 'temphum' - - -if __name__ == '__main__': - parser = ArgumentParser() - parser.add_argument('--device-id', type=str, required=True) - parser.add_argument('--type', type=str, required=True, - choices=[i.name.lower() for i in NodeType]) - - config.load('mqtt_util', parser=parser) - arg = parser.parse_args() - - mqtt_node_type = NodeType(arg.type) - devices = MqttEspDevice(id=arg.device_id) - - if mqtt_node_type == NodeType.RELAY: - mqtt_client = MqttRelay(devices=devices) - elif mqtt_node_type == NodeType.TEMPHUM: - mqtt_client = MqttTempHum(devices=devices) - - mqtt_client.set_message_callback(lambda device_id, payload: print(payload)) - mqtt_client.configure_tls() - try: - mqtt_client.connect_and_loop() - except KeyboardInterrupt: - mqtt_client.disconnect() diff --git a/src/home/media/__init__.py b/src/home/media/__init__.py index 976c990..6923105 100644 --- a/src/home/media/__init__.py +++ b/src/home/media/__init__.py @@ -12,6 +12,7 @@ __map__ = { __all__ = list(itertools.chain(*__map__.values())) + def __getattr__(name): if name in __all__: for file, names in __map__.items(): diff --git a/src/home/mqtt/__init__.py b/src/home/mqtt/__init__.py index 982e2b6..3fbd744 100644 --- a/src/home/mqtt/__init__.py +++ b/src/home/mqtt/__init__.py @@ -1,4 +1,8 @@ -from .mqtt import MqttBase -from .util import poll_tick -from .relay import MqttRelay, MqttRelayState -from .temphum import MqttTempHum
\ No newline at end of file +from .mqtt import MqttBase, MqttPayload, MqttPayloadCustomField +from ._node import MqttNode +from ._module import MqttModule +from .util import ( + poll_tick, + get_modules as get_mqtt_modules, + import_module as import_mqtt_module +)
\ No newline at end of file diff --git a/src/home/mqtt/_module.py b/src/home/mqtt/_module.py new file mode 100644 index 0000000..949c344 --- /dev/null +++ b/src/home/mqtt/_module.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import abc +import logging + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ._node import MqttNode + + +class MqttModule(abc.ABC): + tick_interval: int + _initialized: bool + + def __init__(self, tick_interval=0): + self.tick_interval = tick_interval + self._initialized = False + self._logger = logging.getLogger(self.__class__.__name__) + + def init(self, mqtt: MqttNode): + pass + + def is_initialized(self): + return self._initialized + + def set_initialized(self): + self._initialized = True + + def tick(self, mqtt: MqttNode): + pass + + def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes): + pass diff --git a/src/home/mqtt/_node.py b/src/home/mqtt/_node.py new file mode 100644 index 0000000..c76610f --- /dev/null +++ b/src/home/mqtt/_node.py @@ -0,0 +1,87 @@ +import paho.mqtt.client as mqtt + +from .mqtt import MqttBase +from typing import List +from ._module import MqttModule + + +class MqttNode(MqttBase): + _modules: List[MqttModule] + _module_subscriptions: dict[str, MqttModule] + _node_id: str + # _devices: list[MqttEspDevice] + # _message_callback: Optional[callable] + # _ota_publish_callback: Optional[callable] + + def __init__(self, + node_id: str, + # devices: Union[MqttEspDevice, list[MqttEspDevice]], + subscribe_to_updates=True): + super().__init__(clean_session=True) + self._modules = [] + self._module_subscriptions = {} + self._node_id = node_id + # if not isinstance(devices, list): + # devices = [devices] + # self._devices = devices + # self._message_callback = None + # self._ota_publish_callback = None + # self._subscribe_to_updates = subscribe_to_updates + # self._ota_mid = None + + def on_connect(self, client: mqtt.Client, userdata, flags, rc): + super().on_connect(client, userdata, flags, rc) + for module in self._modules: + if not module.is_initialized(): + module.init(self) + module.set_initialized() + + def on_publish(self, client: mqtt.Client, userdata, mid): + pass # FIXME + # if self._ota_mid is not None and mid == self._ota_mid and self._ota_publish_callback: + # self._ota_publish_callback() + + def on_message(self, client: mqtt.Client, userdata, msg): + try: + topic = msg.topic + actual_topic = topic[len(f'hk/{self._node_id}/'):] + + if actual_topic in self._module_subscriptions: + self._module_subscriptions[actual_topic].handle_payload(self, actual_topic, msg.payload) + + except Exception as e: + self._logger.exception(str(e)) + + # def push_ota(self, + # device_id, + # filename: str, + # publish_callback: callable, + # qos: int): + # device = next(d for d in self._devices if d.id == device_id) + # assert device.secret is not None, 'device secret not specified' + # + # self._ota_publish_callback = publish_callback + # payload = OtaPayload(secret=device.secret, filename=filename) + # publish_result = self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/admin/ota', + # payload=payload.pack(), + # qos=qos) + # self._ota_mid = publish_result.mid + # self._client.loop_write() + # + # @classmethod + # def get_mqtt_topics(cls, additional_topics: Optional[list[str]] = None): + # return rf'^hk/(.*?)/{cls.TOPIC_LEAF}/(stat|stat1|otares'+('|'+('|'.join(additional_topics)) if additional_topics else '')+')$' + + def add_module(self, module: MqttModule): + self._modules.append(module) + if self._connected: + module.init(self) + module.set_initialized() + + def subscribe_module(self, topic: str, module: MqttModule, qos: int = 1): + self._module_subscriptions[topic] = module + self._client.subscribe(f'hk/{self._node_id}/{topic}', qos) + + def publish(self, topic: str, payload: bytes, qos: int = 1): + self._client.publish(f'hk/{self._node_id}/{topic}', payload, qos) + self._client.loop_write() diff --git a/src/home/mqtt/payload/base_payload.py b/src/home/mqtt/_payload.py index 1abd898..58eeae3 100644 --- a/src/home/mqtt/payload/base_payload.py +++ b/src/home/mqtt/_payload.py @@ -1,5 +1,5 @@ -import abc import struct +import abc import re from typing import Optional, Tuple @@ -142,4 +142,4 @@ def _bit_field_params(cl) -> Optional[Tuple[int, ...]]: match = re.match(r'MQTTPayloadBitField_(\d+)_(\d+)_(\d)$', cl.__name__) if match is not None: return tuple([int(match.group(i)) for i in range(1, 4)]) - return None + return None
\ No newline at end of file diff --git a/src/home/mqtt/esp.py b/src/home/mqtt/esp.py deleted file mode 100644 index 56ced83..0000000 --- a/src/home/mqtt/esp.py +++ /dev/null @@ -1,106 +0,0 @@ -import re -import paho.mqtt.client as mqtt - -from .mqtt import MqttBase -from typing import Optional, Union -from .payload.esp import ( - OTAPayload, - OTAResultPayload, - DiagnosticsPayload, - InitialDiagnosticsPayload -) - - -class MqttEspDevice: - id: str - secret: Optional[str] - - def __init__(self, id: str, secret: Optional[str] = None): - self.id = id - self.secret = secret - - -class MqttEspBase(MqttBase): - _devices: list[MqttEspDevice] - _message_callback: Optional[callable] - _ota_publish_callback: Optional[callable] - - TOPIC_LEAF = 'esp' - - def __init__(self, - devices: Union[MqttEspDevice, list[MqttEspDevice]], - subscribe_to_updates=True): - super().__init__(clean_session=True) - if not isinstance(devices, list): - devices = [devices] - self._devices = devices - self._message_callback = None - self._ota_publish_callback = None - self._subscribe_to_updates = subscribe_to_updates - self._ota_mid = None - - def on_connect(self, client: mqtt.Client, userdata, flags, rc): - super().on_connect(client, userdata, flags, rc) - - if self._subscribe_to_updates: - for device in self._devices: - topic = f'hk/{device.id}/{self.TOPIC_LEAF}/#' - self._logger.debug(f"subscribing to {topic}") - client.subscribe(topic, qos=1) - - def on_publish(self, client: mqtt.Client, userdata, mid): - if self._ota_mid is not None and mid == self._ota_mid and self._ota_publish_callback: - self._ota_publish_callback() - - def set_message_callback(self, callback: callable): - self._message_callback = callback - - def on_message(self, client: mqtt.Client, userdata, msg): - try: - match = re.match(self.get_mqtt_topics(), msg.topic) - self._logger.debug(f'topic: {msg.topic}') - if not match: - return - - device_id = match.group(1) - subtopic = match.group(2) - - # try: - next(d for d in self._devices if d.id == device_id) - # except StopIteration:h - # return - - message = None - if subtopic == 'stat': - message = DiagnosticsPayload.unpack(msg.payload) - elif subtopic == 'stat1': - message = InitialDiagnosticsPayload.unpack(msg.payload) - elif subtopic == 'otares': - message = OTAResultPayload.unpack(msg.payload) - - if message and self._message_callback: - self._message_callback(device_id, message) - return True - - except Exception as e: - self._logger.exception(str(e)) - - def push_ota(self, - device_id, - filename: str, - publish_callback: callable, - qos: int): - device = next(d for d in self._devices if d.id == device_id) - assert device.secret is not None, 'device secret not specified' - - self._ota_publish_callback = publish_callback - payload = OTAPayload(secret=device.secret, filename=filename) - publish_result = self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/admin/ota', - payload=payload.pack(), - qos=qos) - self._ota_mid = publish_result.mid - self._client.loop_write() - - @classmethod - def get_mqtt_topics(cls, additional_topics: Optional[list[str]] = None): - return rf'^hk/(.*?)/{cls.TOPIC_LEAF}/(stat|stat1|otares'+('|'+('|'.join(additional_topics)) if additional_topics else '')+')$'
\ No newline at end of file diff --git a/src/home/mqtt/payload/esp.py b/src/home/mqtt/module/diagnostics.py index 171cdb9..8b5ea16 100644 --- a/src/home/mqtt/payload/esp.py +++ b/src/home/mqtt/module/diagnostics.py @@ -1,39 +1,7 @@ -import hashlib +from ..mqtt import MqttPayload, MqttPayloadCustomField +from .._node import MqttNode, MqttModule -from .base_payload import MqttPayload, MqttPayloadCustomField - - -class OTAResultPayload(MqttPayload): - FORMAT = '=BB' - result: int - error_code: int - - -class OTAPayload(MqttPayload): - secret: str - filename: str - - # structure of returned data: - # - # uint8_t[len(secret)] secret; - # uint8_t[16] md5; - # *uint8_t data - - def pack(self): - buf = bytearray(self.secret.encode()) - m = hashlib.md5() - with open(self.filename, 'rb') as fd: - content = fd.read() - m.update(content) - buf.extend(m.digest()) - buf.extend(content) - return buf - - def unpack(cls, buf: bytes): - raise RuntimeError(f'{cls.__class__.__name__}.unpack: not implemented') - # secret = buf[:12].decode() - # filename = buf[12:].decode() - # return OTAPayload(secret=secret, filename=filename) +MODULE_NAME = 'MqttDiagnosticsModule' class DiagnosticsFlags(MqttPayloadCustomField): @@ -76,3 +44,16 @@ class DiagnosticsPayload(MqttPayload): rssi: int free_heap: int flags: DiagnosticsFlags + + +class MqttDiagnosticsModule(MqttModule): + def init(self, mqtt: MqttNode): + for topic in ('diag', 'd1ag', 'stat', 'stat1'): + mqtt.subscribe_module(topic, self) + + def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes): + if topic in ('stat', 'diag'): + message = DiagnosticsPayload.unpack(payload) + elif topic in ('stat1', 'd1ag'): + message = InitialDiagnosticsPayload.unpack(payload) + self._logger.debug(message)
\ No newline at end of file diff --git a/src/home/mqtt/payload/inverter.py b/src/home/mqtt/module/inverter.py index 09388df..9cf2978 100644 --- a/src/home/mqtt/payload/inverter.py +++ b/src/home/mqtt/module/inverter.py @@ -1,7 +1,7 @@ import struct -from .base_payload import MqttPayload, bit_field -from typing import Tuple +from .._node import MqttNode +from .._payload import MqttPayload, bit_field _mult_10 = lambda n: int(n*10) _div_10 = lambda n: n/10 @@ -71,3 +71,7 @@ class Generation(MqttPayload): time: int wh: int + + +class MqttInverterModule(MqttNode): + pass diff --git a/src/home/mqtt/module/ota.py b/src/home/mqtt/module/ota.py new file mode 100644 index 0000000..1d472d1 --- /dev/null +++ b/src/home/mqtt/module/ota.py @@ -0,0 +1,65 @@ +import hashlib + +from ..mqtt import MqttPayload +from .._node import MqttModule, MqttNode + +MODULE_NAME = 'MqttOtaModule' + + +class OtaResultPayload(MqttPayload): + FORMAT = '=BB' + result: int + error_code: int + + +class OtaPayload(MqttPayload): + secret: str + filename: str + + # structure of returned data: + # + # uint8_t[len(secret)] secret; + # uint8_t[16] md5; + # *uint8_t data + + def pack(self): + buf = bytearray(self.secret.encode()) + m = hashlib.md5() + with open(self.filename, 'rb') as fd: + content = fd.read() + m.update(content) + buf.extend(m.digest()) + buf.extend(content) + return buf + + def unpack(cls, buf: bytes): + raise RuntimeError(f'{cls.__class__.__name__}.unpack: not implemented') + # secret = buf[:12].decode() + # filename = buf[12:].decode() + # return OTAPayload(secret=secret, filename=filename) + + +class MqttOtaModule(MqttModule): + def init(self, mqtt: MqttNode): + mqtt.subscribe_module("otares", self) + + def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes): + if topic == 'otares': + message = OtaResultPayload.unpack(payload) + self._logger.debug(message) + + # def push_ota(self, + # node_id, + # filename: str, + # publish_callback: callable, + # qos: int): + # device = next(d for d in self._devices if d.id == device_id) + # assert device.secret is not None, 'device secret not specified' + # + # self._ota_publish_callback = publish_callback + # payload = OtaPayload(secret=device.secret, filename=filename) + # publish_result = self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/admin/ota', + # payload=payload.pack(), + # qos=qos) + # self._ota_mid = publish_result.mid + # self._client.loop_write()
\ No newline at end of file diff --git a/src/home/mqtt/module/relay.py b/src/home/mqtt/module/relay.py new file mode 100644 index 0000000..16877f6 --- /dev/null +++ b/src/home/mqtt/module/relay.py @@ -0,0 +1,65 @@ +import paho.mqtt.client as mqtt +import re +import datetime + +from .. import MqttModule, MqttPayload, MqttNode + +MODULE_NAME = 'MqttRelayModule' + + +class MqttPowerSwitchPayload(MqttPayload): + FORMAT = '=12sB' + PACKER = { + 'state': lambda n: int(n), + 'secret': lambda s: s.encode('utf-8') + } + UNPACKER = { + 'state': lambda n: bool(n), + 'secret': lambda s: s.decode('utf-8') + } + + secret: str + state: bool + + +class MqttRelayState: + enabled: bool + update_time: datetime.datetime + rssi: int + fw_version: int + ever_updated: bool + + def __init__(self): + self.ever_updated = False + self.enabled = False + self.rssi = 0 + + def update(self, + enabled: bool, + rssi: int, + fw_version=None): + self.ever_updated = True + self.enabled = enabled + self.rssi = rssi + self.update_time = datetime.datetime.now() + if fw_version: + self.fw_version = fw_version + + +class MqttRelayModule(MqttModule): + def init(self, mqtt: MqttNode): + mqtt.subscribe_module('relay/switch', self) + + @staticmethod + def switchpower(mqtt: MqttNode, + enable: bool, + secret: str): + payload = MqttPowerSwitchPayload(secret=secret, state=enable) + mqtt.publish('relay/switch', payload=payload.pack()) + + def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes): + if topic != 'relay/switch': + return + + message = MqttPowerSwitchPayload.unpack(payload) + self._logger.debug(message)
\ No newline at end of file diff --git a/src/home/mqtt/module/temphum.py b/src/home/mqtt/module/temphum.py new file mode 100644 index 0000000..e1c4567 --- /dev/null +++ b/src/home/mqtt/module/temphum.py @@ -0,0 +1,55 @@ +from enum import auto +from .._node import MqttNode +from .._module import MqttModule +from .._payload import MqttPayload +from ...util import HashableEnum + +two_digits_precision = lambda x: round(x, 2) + +MODULE_NAME = 'MqttTempHumModule' + + +class TempHumDataPayload(MqttPayload): + FORMAT = '=ddb' + UNPACKER = { + 'temp': two_digits_precision, + 'rh': two_digits_precision + } + + temp: float + rh: float + error: int + + +class MqttTempHumNodes(HashableEnum): + KBN_SH_HALL = auto() + KBN_SH_BATHROOM = auto() + KBN_SH_LIVINGROOM = auto() + KBN_SH_BEDROOM = auto() + + KBN_BH_2FL = auto() + KBN_BH_2FL_STREET = auto() + KBN_BH_1FL_LIVINGROOM = auto() + KBN_BH_1FL_BEDROOM = auto() + KBN_BH_1FL_BATHROOM = auto() + + KBN_NH_1FL_INV = auto() + KBN_NH_1FL_CENTER = auto() + KBN_NH_1LF_KT = auto() + KBN_NH_1FL_DS = auto() + KBN_NH_1FS_EZ = auto() + + SPB_FLAT120_CABINET = auto() + + +class MqttTempHumModule(MqttModule): + def init(self, mqtt: MqttNode): + mqtt.subscribe_module('temphum/data', self) + + def handle_payload(self, + mqtt: MqttNode, + topic: str, + payload: bytes): + if topic == 'temphum/data': + message = TempHumDataPayload.unpack(payload) + self._logger.debug(message)
\ No newline at end of file diff --git a/src/home/mqtt/mqtt.py b/src/home/mqtt/mqtt.py index 4acd4f6..fad5d26 100644 --- a/src/home/mqtt/mqtt.py +++ b/src/home/mqtt/mqtt.py @@ -3,8 +3,8 @@ import paho.mqtt.client as mqtt import ssl import logging -from typing import Tuple from ..config import config +from ._payload import * def username_and_password() -> Tuple[str, str]: @@ -14,6 +14,8 @@ def username_and_password() -> Tuple[str, str]: class MqttBase: + _connected: bool + def __init__(self, clean_session=True): self._client = mqtt.Client(client_id=config['mqtt']['client_id'], protocol=mqtt.MQTTv311, @@ -24,6 +26,7 @@ class MqttBase: self._client.on_log = self.on_log self._client.on_publish = self.on_publish self._loop_started = False + self._connected = False self._logger = logging.getLogger(self.__class__.__name__) @@ -41,7 +44,9 @@ class MqttBase: 'assets', 'mqtt_ca.crt' )) - self._client.tls_set(ca_certs=ca_certs, cert_reqs=ssl.CERT_REQUIRED, tls_version=ssl.PROTOCOL_TLSv1_2) + self._client.tls_set(ca_certs=ca_certs, + cert_reqs=ssl.CERT_REQUIRED, + tls_version=ssl.PROTOCOL_TLSv1_2) def connect_and_loop(self, loop_forever=True): host = config['mqtt']['host'] @@ -61,9 +66,11 @@ class MqttBase: def on_connect(self, client: mqtt.Client, userdata, flags, rc): self._logger.info("Connected with result code " + str(rc)) + self._connected = True def on_disconnect(self, client: mqtt.Client, userdata, rc): self._logger.info("Disconnected with result code " + str(rc)) + self._connected = False def on_log(self, client: mqtt.Client, userdata, level, buf): level = mqtt.LOGGING_LEVEL[level] if level in mqtt.LOGGING_LEVEL else logging.INFO @@ -73,4 +80,15 @@ class MqttBase: self._logger.debug(msg.topic + ": " + str(msg.payload)) def on_publish(self, client: mqtt.Client, userdata, mid): - self._logger.debug(f'publish done, mid={mid}')
\ No newline at end of file + self._logger.debug(f'publish done, mid={mid}') + + +class MqttEspDevice: + id: str + secret: Optional[str] + + def __init__(self, + node_id: str, + secret: Optional[str] = None): + self.id = node_id + self.secret = secret diff --git a/src/home/mqtt/payload/__init__.py b/src/home/mqtt/payload/__init__.py deleted file mode 100644 index eee6709..0000000 --- a/src/home/mqtt/payload/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .base_payload import MqttPayload
\ No newline at end of file diff --git a/src/home/mqtt/payload/relay.py b/src/home/mqtt/payload/relay.py deleted file mode 100644 index 4902991..0000000 --- a/src/home/mqtt/payload/relay.py +++ /dev/null @@ -1,22 +0,0 @@ -from .base_payload import MqttPayload -from .esp import ( - OTAResultPayload, - OTAPayload, - InitialDiagnosticsPayload, - DiagnosticsPayload -) - - -class PowerPayload(MqttPayload): - FORMAT = '=12sB' - PACKER = { - 'state': lambda n: int(n), - 'secret': lambda s: s.encode('utf-8') - } - UNPACKER = { - 'state': lambda n: bool(n), - 'secret': lambda s: s.decode('utf-8') - } - - secret: str - state: bool diff --git a/src/home/mqtt/payload/sensors.py b/src/home/mqtt/payload/sensors.py deleted file mode 100644 index f99b307..0000000 --- a/src/home/mqtt/payload/sensors.py +++ /dev/null @@ -1,20 +0,0 @@ -from .base_payload import MqttPayload - -_mult_100 = lambda n: int(n*100) -_div_100 = lambda n: n/100 - - -class Temperature(MqttPayload): - FORMAT = 'IhH' - PACKER = { - 'temp': _mult_100, - 'rh': _mult_100, - } - UNPACKER = { - 'temp': _div_100, - 'rh': _div_100, - } - - time: int - temp: float - rh: float diff --git a/src/home/mqtt/payload/temphum.py b/src/home/mqtt/payload/temphum.py deleted file mode 100644 index c0b744e..0000000 --- a/src/home/mqtt/payload/temphum.py +++ /dev/null @@ -1,15 +0,0 @@ -from .base_payload import MqttPayload - -two_digits_precision = lambda x: round(x, 2) - - -class TempHumDataPayload(MqttPayload): - FORMAT = '=ddb' - UNPACKER = { - 'temp': two_digits_precision, - 'rh': two_digits_precision - } - - temp: float - rh: float - error: int diff --git a/src/home/mqtt/relay.py b/src/home/mqtt/relay.py deleted file mode 100644 index a90f19c..0000000 --- a/src/home/mqtt/relay.py +++ /dev/null @@ -1,71 +0,0 @@ -import paho.mqtt.client as mqtt -import re -import datetime - -from .payload.relay import ( - PowerPayload, -) -from .esp import MqttEspBase - - -class MqttRelay(MqttEspBase): - TOPIC_LEAF = 'relay' - - def set_power(self, device_id, enable: bool, secret=None): - device = next(d for d in self._devices if d.id == device_id) - secret = secret if secret else device.secret - - assert secret is not None, 'device secret not specified' - - payload = PowerPayload(secret=secret, - state=enable) - self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/power', - payload=payload.pack(), - qos=1) - self._client.loop_write() - - def on_message(self, client: mqtt.Client, userdata, msg): - if super().on_message(client, userdata, msg): - return - - try: - match = re.match(self.get_mqtt_topics(['power']), msg.topic) - if not match: - return - - device_id = match.group(1) - subtopic = match.group(2) - - message = None - if subtopic == 'power': - message = PowerPayload.unpack(msg.payload) - - if message and self._message_callback: - self._message_callback(device_id, message) - - except Exception as e: - self._logger.exception(str(e)) - - -class MqttRelayState: - enabled: bool - update_time: datetime.datetime - rssi: int - fw_version: int - ever_updated: bool - - def __init__(self): - self.ever_updated = False - self.enabled = False - self.rssi = 0 - - def update(self, - enabled: bool, - rssi: int, - fw_version=None): - self.ever_updated = True - self.enabled = enabled - self.rssi = rssi - self.update_time = datetime.datetime.now() - if fw_version: - self.fw_version = fw_version diff --git a/src/home/mqtt/temphum.py b/src/home/mqtt/temphum.py deleted file mode 100644 index 44810ef..0000000 --- a/src/home/mqtt/temphum.py +++ /dev/null @@ -1,54 +0,0 @@ -import paho.mqtt.client as mqtt -import re - -from enum import auto -from .payload.temphum import TempHumDataPayload -from .esp import MqttEspBase -from ..util import HashableEnum - - -class MqttTempHumNodes(HashableEnum): - KBN_SH_HALL = auto() - KBN_SH_BATHROOM = auto() - KBN_SH_LIVINGROOM = auto() - KBN_SH_BEDROOM = auto() - - KBN_BH_2FL = auto() - KBN_BH_2FL_STREET = auto() - KBN_BH_1FL_LIVINGROOM = auto() - KBN_BH_1FL_BEDROOM = auto() - KBN_BH_1FL_BATHROOM = auto() - - KBN_NH_1FL_INV = auto() - KBN_NH_1FL_CENTER = auto() - KBN_NH_1LF_KT = auto() - KBN_NH_1FL_DS = auto() - KBN_NH_1FS_EZ = auto() - - SPB_FLAT120_CABINET = auto() - - -class MqttTempHum(MqttEspBase): - TOPIC_LEAF = 'temphum' - - def on_message(self, client: mqtt.Client, userdata, msg): - if super().on_message(client, userdata, msg): - return - - try: - match = re.match(self.get_mqtt_topics(['data']), msg.topic) - if not match: - return - - device_id = match.group(1) - subtopic = match.group(2) - - message = None - if subtopic == 'data': - message = TempHumDataPayload.unpack(msg.payload) - - if message and self._message_callback: - self._message_callback(device_id, message) - - except Exception as e: - self._logger.exception(str(e)) diff --git a/src/home/mqtt/util.py b/src/home/mqtt/util.py index f71ffd8..78cbcaa 100644 --- a/src/home/mqtt/util.py +++ b/src/home/mqtt/util.py @@ -1,4 +1,9 @@ import time +import os +import re +import importlib + +from typing import List def poll_tick(freq): @@ -6,3 +11,16 @@ def poll_tick(freq): while True: t += freq yield max(t - time.time(), 0) + + +def get_modules() -> List[str]: + modules = [] + for name in os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'module')): + name = re.sub(r'\.py$', '', name) + modules.append(name) + return modules + + +def import_module(module: str): + return importlib.import_module( + f'..module.{module}', __name__)
\ No newline at end of file diff --git a/src/home/pio/products.py b/src/home/pio/products.py index 7649078..388da03 100644 --- a/src/home/pio/products.py +++ b/src/home/pio/products.py @@ -16,10 +16,6 @@ _products_dir = os.path.join( def get_products(): products = [] for f in os.listdir(_products_dir): - # temp hack - if f.endswith('-esp01'): - continue - # skip the common dir if f in ('common',): continue diff --git a/src/mqtt_node_util.py b/src/mqtt_node_util.py new file mode 100755 index 0000000..674b60c --- /dev/null +++ b/src/mqtt_node_util.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +from typing import Optional +from argparse import ArgumentParser, ArgumentError + +from home.config import config +from home.mqtt import MqttNode, get_mqtt_modules, import_mqtt_module, MqttModule + +mqtt: Optional[MqttNode] = None + + +def add_module(module: str) -> MqttModule: + module = import_mqtt_module(module) + if not hasattr(module, 'MODULE_NAME'): + raise RuntimeError(f'MODULE_NAME not found in module {m}') + cl = getattr(module, getattr(module, 'MODULE_NAME')) + instance = cl() + mqtt.add_module(instance) + return instance + + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument('--node-id', type=str, required=True) + parser.add_argument('--modules', type=str, choices=get_mqtt_modules(), nargs='*', + help='mqtt modules to include') + parser.add_argument('--switch-relay', choices=[0, 1], type=int, + help='send relay state') + parser.add_argument('--switch-relay-secret', type=str, + help='secret password to switch relay') + + config.load('mqtt_util', parser=parser) + arg = parser.parse_args() + + if (arg.switch_relay is not None or arg.switch_relay_secret is not None) and 'relay' not in arg.modules: + raise ArgumentError(None, '--relay is only allowed when \'relay\' module included in --modules') + + if (arg.switch_relay is not None and arg.switch_relay_secret is None) or (arg.switch_relay is None and arg.switch_relay_secret is not None): + raise ArgumentError(None, 'both --switch-relay and --switch-relay-secret are required') + + mqtt = MqttNode(node_id=arg.node_id) + + # must-have modules + add_module('ota') + add_module('diagnostics') + + if arg.modules: + for m in arg.modules: + module_instance = add_module(m) + if m == 'relay' and arg.switch_relay is not None: + module_instance.switchpower(mqtt, + arg.switch_relay == 1, + arg.switch_relay_secret) + + mqtt.configure_tls() + try: + mqtt.connect_and_loop() + except KeyboardInterrupt: + mqtt.disconnect() diff --git a/src/pump_mqtt_bot.py b/src/pump_mqtt_bot.py index d3b6de4..86d87d3 100755 --- a/src/pump_mqtt_bot.py +++ b/src/pump_mqtt_bot.py @@ -8,10 +8,9 @@ from telegram import ReplyKeyboardMarkup, User from home.config import config from home.telegram import bot from home.telegram._botutil import user_any_name -from home.mqtt.esp import MqttEspDevice -from home.mqtt import MqttRelay, MqttRelayState -from home.mqtt.payload import MqttPayload -from home.mqtt.payload.relay import InitialDiagnosticsPayload, DiagnosticsPayload +from home.mqtt import MqttEspDevice, MqttPayload +from home.mqtt.module.relay import MqttRelayState +from home.mqtt.module.diagnostics import InitialDiagnosticsPayload, DiagnosticsPayload config.load('pump_mqtt_bot') |