diff options
Diffstat (limited to 'src/home/mqtt')
-rw-r--r-- | src/home/mqtt/__init__.py | 5 | ||||
-rw-r--r-- | src/home/mqtt/esp.py | 106 | ||||
-rw-r--r-- | src/home/mqtt/mqtt.py | 2 | ||||
-rw-r--r-- | src/home/mqtt/payload/__init__.py | 2 | ||||
-rw-r--r-- | src/home/mqtt/payload/base_payload.py | 28 | ||||
-rw-r--r-- | src/home/mqtt/payload/esp.py | 78 | ||||
-rw-r--r-- | src/home/mqtt/payload/inverter.py | 6 | ||||
-rw-r--r-- | src/home/mqtt/payload/relay.py | 90 | ||||
-rw-r--r-- | src/home/mqtt/payload/sensors.py | 4 | ||||
-rw-r--r-- | src/home/mqtt/payload/temphum.py | 14 | ||||
-rw-r--r-- | src/home/mqtt/relay.py | 107 | ||||
-rw-r--r-- | src/home/mqtt/temphum.py | 33 |
12 files changed, 292 insertions, 183 deletions
diff --git a/src/home/mqtt/__init__.py b/src/home/mqtt/__init__.py index c9a6c6e..982e2b6 100644 --- a/src/home/mqtt/__init__.py +++ b/src/home/mqtt/__init__.py @@ -1,3 +1,4 @@ -from .mqtt import MQTTBase +from .mqtt import MqttBase from .util import poll_tick -from .relay import MQTTRelay, MQTTRelayState, MQTTRelayDevice
\ No newline at end of file +from .relay import MqttRelay, MqttRelayState +from .temphum import MqttTempHum
\ No newline at end of file diff --git a/src/home/mqtt/esp.py b/src/home/mqtt/esp.py new file mode 100644 index 0000000..56ced83 --- /dev/null +++ b/src/home/mqtt/esp.py @@ -0,0 +1,106 @@ +import re +import paho.mqtt.client as mqtt + +from .mqtt import MqttBase +from typing import Optional, Union +from .payload.esp import ( + OTAPayload, + OTAResultPayload, + DiagnosticsPayload, + InitialDiagnosticsPayload +) + + +class MqttEspDevice: + id: str + secret: Optional[str] + + def __init__(self, id: str, secret: Optional[str] = None): + self.id = id + self.secret = secret + + +class MqttEspBase(MqttBase): + _devices: list[MqttEspDevice] + _message_callback: Optional[callable] + _ota_publish_callback: Optional[callable] + + TOPIC_LEAF = 'esp' + + def __init__(self, + devices: Union[MqttEspDevice, list[MqttEspDevice]], + subscribe_to_updates=True): + super().__init__(clean_session=True) + if not isinstance(devices, list): + devices = [devices] + self._devices = devices + self._message_callback = None + self._ota_publish_callback = None + self._subscribe_to_updates = subscribe_to_updates + self._ota_mid = None + + def on_connect(self, client: mqtt.Client, userdata, flags, rc): + super().on_connect(client, userdata, flags, rc) + + if self._subscribe_to_updates: + for device in self._devices: + topic = f'hk/{device.id}/{self.TOPIC_LEAF}/#' + self._logger.debug(f"subscribing to {topic}") + client.subscribe(topic, qos=1) + + def on_publish(self, client: mqtt.Client, userdata, mid): + if self._ota_mid is not None and mid == self._ota_mid and self._ota_publish_callback: + self._ota_publish_callback() + + def set_message_callback(self, callback: callable): + self._message_callback = callback + + def on_message(self, client: mqtt.Client, userdata, msg): + try: + match = re.match(self.get_mqtt_topics(), msg.topic) + self._logger.debug(f'topic: {msg.topic}') + if not match: + return + + device_id = match.group(1) + subtopic = match.group(2) + + # try: + next(d for d in self._devices if d.id == device_id) + # except StopIteration:h + # return + + message = None + if subtopic == 'stat': + message = DiagnosticsPayload.unpack(msg.payload) + elif subtopic == 'stat1': + message = InitialDiagnosticsPayload.unpack(msg.payload) + elif subtopic == 'otares': + message = OTAResultPayload.unpack(msg.payload) + + if message and self._message_callback: + self._message_callback(device_id, message) + return True + + except Exception as e: + self._logger.exception(str(e)) + + def push_ota(self, + device_id, + filename: str, + publish_callback: callable, + qos: int): + device = next(d for d in self._devices if d.id == device_id) + assert device.secret is not None, 'device secret not specified' + + self._ota_publish_callback = publish_callback + payload = OTAPayload(secret=device.secret, filename=filename) + publish_result = self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/admin/ota', + payload=payload.pack(), + qos=qos) + self._ota_mid = publish_result.mid + self._client.loop_write() + + @classmethod + def get_mqtt_topics(cls, additional_topics: Optional[list[str]] = None): + return rf'^hk/(.*?)/{cls.TOPIC_LEAF}/(stat|stat1|otares'+('|'+('|'.join(additional_topics)) if additional_topics else '')+')$'
\ No newline at end of file diff --git a/src/home/mqtt/mqtt.py b/src/home/mqtt/mqtt.py index 9dd973b..4acd4f6 100644 --- a/src/home/mqtt/mqtt.py +++ b/src/home/mqtt/mqtt.py @@ -13,7 +13,7 @@ def username_and_password() -> Tuple[str, str]: return username, password -class MQTTBase: +class MqttBase: def __init__(self, clean_session=True): self._client = mqtt.Client(client_id=config['mqtt']['client_id'], protocol=mqtt.MQTTv311, diff --git a/src/home/mqtt/payload/__init__.py b/src/home/mqtt/payload/__init__.py index 9fcaf3e..eee6709 100644 --- a/src/home/mqtt/payload/__init__.py +++ b/src/home/mqtt/payload/__init__.py @@ -1 +1 @@ -from .base_payload import MQTTPayload
\ No newline at end of file +from .base_payload import MqttPayload
\ No newline at end of file diff --git a/src/home/mqtt/payload/base_payload.py b/src/home/mqtt/payload/base_payload.py index 108e0c0..1abd898 100644 --- a/src/home/mqtt/payload/base_payload.py +++ b/src/home/mqtt/payload/base_payload.py @@ -5,7 +5,21 @@ import re from typing import Optional, Tuple -class MQTTPayload(abc.ABC): +def pldstr(self) -> str: + attrs = [] + for field in self.__class__.__annotations__: + if hasattr(self, field): + attr = getattr(self, field) + attrs.append(f'{field}={attr}') + if attrs: + attrs_s = ' ' + attrs_s += ', '.join(attrs) + else: + attrs_s = '' + return f'<%s{attrs_s}>' % (self.__class__.__name__,) + + +class MqttPayload(abc.ABC): FORMAT = '' PACKER = {} UNPACKER = {} @@ -70,7 +84,7 @@ class MQTTPayload(abc.ABC): bf_number = -1 i += 1 - if issubclass(field_type, MQTTPayloadCustomField): + if issubclass(field_type, MqttPayloadCustomField): kwargs[field] = field_type.unpack(data[i]) else: kwargs[field] = cls._unpack_field(field, data[i]) @@ -87,15 +101,18 @@ class MQTTPayload(abc.ABC): @classmethod def _unpack_field(cls, name, val): - if isinstance(val, MQTTPayloadCustomField): + if isinstance(val, MqttPayloadCustomField): return if cls.UNPACKER and name in cls.UNPACKER: return cls.UNPACKER[name](val) else: return val + def __str__(self): + return pldstr(self) -class MQTTPayloadCustomField(abc.ABC): + +class MqttPayloadCustomField(abc.ABC): def __init__(self, **kwargs): for field in self.__class__.__annotations__: setattr(self, field, kwargs[field]) @@ -109,6 +126,9 @@ class MQTTPayloadCustomField(abc.ABC): def unpack(cls, *args, **kwargs): pass + def __str__(self): + return pldstr(self) + def bit_field(seq_no: int, total_bits: int, bits: int): return type(f'MQTTPayloadBitField_{seq_no}_{total_bits}_{bits}', (object,), { diff --git a/src/home/mqtt/payload/esp.py b/src/home/mqtt/payload/esp.py new file mode 100644 index 0000000..171cdb9 --- /dev/null +++ b/src/home/mqtt/payload/esp.py @@ -0,0 +1,78 @@ +import hashlib + +from .base_payload import MqttPayload, MqttPayloadCustomField + + +class OTAResultPayload(MqttPayload): + FORMAT = '=BB' + result: int + error_code: int + + +class OTAPayload(MqttPayload): + secret: str + filename: str + + # structure of returned data: + # + # uint8_t[len(secret)] secret; + # uint8_t[16] md5; + # *uint8_t data + + def pack(self): + buf = bytearray(self.secret.encode()) + m = hashlib.md5() + with open(self.filename, 'rb') as fd: + content = fd.read() + m.update(content) + buf.extend(m.digest()) + buf.extend(content) + return buf + + def unpack(cls, buf: bytes): + raise RuntimeError(f'{cls.__class__.__name__}.unpack: not implemented') + # secret = buf[:12].decode() + # filename = buf[12:].decode() + # return OTAPayload(secret=secret, filename=filename) + + +class DiagnosticsFlags(MqttPayloadCustomField): + state: bool + config_changed_value_present: bool + config_changed: bool + + @staticmethod + def unpack(flags: int): + # _logger.debug(f'StatFlags.unpack: flags={flags}') + state = flags & 0x1 + ccvp = (flags >> 1) & 0x1 + cc = (flags >> 2) & 0x1 + # _logger.debug(f'StatFlags.unpack: state={state}') + return DiagnosticsFlags(state=(state == 1), + config_changed_value_present=(ccvp == 1), + config_changed=(cc == 1)) + + def __index__(self): + bits = 0 + bits |= (int(self.state) & 0x1) + bits |= (int(self.config_changed_value_present) & 0x1) << 1 + bits |= (int(self.config_changed) & 0x1) << 2 + return bits + + +class InitialDiagnosticsPayload(MqttPayload): + FORMAT = '=IBbIB' + + ip: int + fw_version: int + rssi: int + free_heap: int + flags: DiagnosticsFlags + + +class DiagnosticsPayload(MqttPayload): + FORMAT = '=bIB' + + rssi: int + free_heap: int + flags: DiagnosticsFlags diff --git a/src/home/mqtt/payload/inverter.py b/src/home/mqtt/payload/inverter.py index 1d4099c..09388df 100644 --- a/src/home/mqtt/payload/inverter.py +++ b/src/home/mqtt/payload/inverter.py @@ -1,13 +1,13 @@ import struct -from .base_payload import MQTTPayload, bit_field +from .base_payload import MqttPayload, bit_field from typing import Tuple _mult_10 = lambda n: int(n*10) _div_10 = lambda n: n/10 -class Status(MQTTPayload): +class Status(MqttPayload): # 46 bytes FORMAT = 'IHHHHHHBHHHHHBHHHHHHHH' @@ -65,7 +65,7 @@ class Status(MQTTPayload): load_connected: bit_field(0, 16, 1) -class Generation(MQTTPayload): +class Generation(MqttPayload): # 8 bytes FORMAT = 'II' diff --git a/src/home/mqtt/payload/relay.py b/src/home/mqtt/payload/relay.py index 1a38201..4902991 100644 --- a/src/home/mqtt/payload/relay.py +++ b/src/home/mqtt/payload/relay.py @@ -1,53 +1,13 @@ -import hashlib +from .base_payload import MqttPayload +from .esp import ( + OTAResultPayload, + OTAPayload, + InitialDiagnosticsPayload, + DiagnosticsPayload +) -from .base_payload import MQTTPayload, MQTTPayloadCustomField - -# _logger = logging.getLogger(__name__) - -class StatFlags(MQTTPayloadCustomField): - state: bool - config_changed_value_present: bool - config_changed: bool - - @staticmethod - def unpack(flags: int): - # _logger.debug(f'StatFlags.unpack: flags={flags}') - state = flags & 0x1 - ccvp = (flags >> 1) & 0x1 - cc = (flags >> 2) & 0x1 - # _logger.debug(f'StatFlags.unpack: state={state}') - return StatFlags(state=(state == 1), - config_changed_value_present=(ccvp == 1), - config_changed=(cc == 1)) - - def __index__(self): - bits = 0 - bits |= (int(self.state) & 0x1) - bits |= (int(self.config_changed_value_present) & 0x1) << 1 - bits |= (int(self.config_changed) & 0x1) << 2 - return bits - - -class InitialStatPayload(MQTTPayload): - FORMAT = '=IBbIB' - - ip: int - fw_version: int - rssi: int - free_heap: int - flags: StatFlags - - -class StatPayload(MQTTPayload): - FORMAT = '=bIB' - - rssi: int - free_heap: int - flags: StatFlags - - -class PowerPayload(MQTTPayload): +class PowerPayload(MqttPayload): FORMAT = '=12sB' PACKER = { 'state': lambda n: int(n), @@ -60,37 +20,3 @@ class PowerPayload(MQTTPayload): secret: str state: bool - - -class OTAResultPayload(MQTTPayload): - FORMAT = '=BB' - result: int - error_code: int - - -class OTAPayload(MQTTPayload): - secret: str - filename: str - - # structure of returned data: - # - # uint8_t[len(secret)] secret; - # uint8_t[16] md5; - # *uint8_t data - - def pack(self): - buf = bytearray(self.secret.encode()) - m = hashlib.md5() - with open(self.filename, 'rb') as fd: - content = fd.read() - m.update(content) - buf.extend(m.digest()) - buf.extend(content) - return buf - - def unpack(cls, buf: bytes): - raise RuntimeError(f'{cls.__class__.__name__}.unpack: not implemented') - # secret = buf[:12].decode() - # filename = buf[12:].decode() - # return OTAPayload(secret=secret, filename=filename) - diff --git a/src/home/mqtt/payload/sensors.py b/src/home/mqtt/payload/sensors.py index 3ecc243..f99b307 100644 --- a/src/home/mqtt/payload/sensors.py +++ b/src/home/mqtt/payload/sensors.py @@ -1,10 +1,10 @@ -from .base_payload import MQTTPayload +from .base_payload import MqttPayload _mult_100 = lambda n: int(n*100) _div_100 = lambda n: n/100 -class Temperature(MQTTPayload): +class Temperature(MqttPayload): FORMAT = 'IhH' PACKER = { 'temp': _mult_100, diff --git a/src/home/mqtt/payload/temphum.py b/src/home/mqtt/payload/temphum.py new file mode 100644 index 0000000..5b45ecb --- /dev/null +++ b/src/home/mqtt/payload/temphum.py @@ -0,0 +1,14 @@ +from .base_payload import MqttPayload + +two_digits_precision = lambda x: round(x, 2) + + +class TempHumDataPayload(MqttPayload): + FORMAT = '=dd' + UNPACKER = { + 'temp': two_digits_precision, + 'rh': two_digits_precision + } + + temp: float + rh: float diff --git a/src/home/mqtt/relay.py b/src/home/mqtt/relay.py index 53d43e4..a90f19c 100644 --- a/src/home/mqtt/relay.py +++ b/src/home/mqtt/relay.py @@ -2,83 +2,43 @@ import paho.mqtt.client as mqtt import re import datetime -from .mqtt import MQTTBase -from typing import Optional, Union from .payload.relay import ( - InitialStatPayload, - StatPayload, PowerPayload, - OTAPayload, - OTAResultPayload ) +from .esp import MqttEspBase -class MQTTRelayDevice: - id: str - secret: Optional[str] +class MqttRelay(MqttEspBase): + TOPIC_LEAF = 'relay' - def __init__(self, id: str, secret: Optional[str] = None): - self.id = id - self.secret = secret - - -class MQTTRelay(MQTTBase): - _devices: list[MQTTRelayDevice] - _message_callback: Optional[callable] - _ota_publish_callback: Optional[callable] - - def __init__(self, - devices: Union[MQTTRelayDevice, list[MQTTRelayDevice]], - subscribe_to_updates=True): - super().__init__(clean_session=True) - if not isinstance(devices, list): - devices = [devices] - self._devices = devices - self._message_callback = None - self._ota_publish_callback = None - self._subscribe_to_updates = subscribe_to_updates - self._ota_mid = None - - def on_connect(self, client: mqtt.Client, userdata, flags, rc): - super().on_connect(client, userdata, flags, rc) - - if self._subscribe_to_updates: - for device in self._devices: - topic = f'hk/{device.id}/relay/#' - self._logger.debug(f"subscribing to {topic}") - client.subscribe(topic, qos=1) + def set_power(self, device_id, enable: bool, secret=None): + device = next(d for d in self._devices if d.id == device_id) + secret = secret if secret else device.secret - def on_publish(self, client: mqtt.Client, userdata, mid): - if self._ota_mid is not None and mid == self._ota_mid and self._ota_publish_callback: - self._ota_publish_callback() + assert secret is not None, 'device secret not specified' - def set_message_callback(self, callback: callable): - self._message_callback = callback + payload = PowerPayload(secret=secret, + state=enable) + self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/power', + payload=payload.pack(), + qos=1) + self._client.loop_write() def on_message(self, client: mqtt.Client, userdata, msg): + if super().on_message(client, userdata, msg): + return + try: - match = re.match(r'^hk/(.*?)/relay/(stat|stat1|power|otares)$', msg.topic) - self._logger.debug(f'topic: {msg.topic}') + match = re.match(self.get_mqtt_topics(['power']), msg.topic) if not match: return device_id = match.group(1) subtopic = match.group(2) - try: - next(d for d in self._devices if d.id == device_id) - except StopIteration: - return - message = None - if subtopic == 'stat': - message = StatPayload.unpack(msg.payload) - elif subtopic == 'stat1': - message = InitialStatPayload.unpack(msg.payload) - elif subtopic == 'power': + if subtopic == 'power': message = PowerPayload.unpack(msg.payload) - elif subtopic == 'otares': - message = OTAResultPayload.unpack(msg.payload) if message and self._message_callback: self._message_callback(device_id, message) @@ -86,37 +46,8 @@ class MQTTRelay(MQTTBase): except Exception as e: self._logger.exception(str(e)) - def set_power(self, device_id, enable: bool, secret=None): - device = next(d for d in self._devices if d.id == device_id) - secret = secret if secret else device.secret - - assert secret is not None, 'device secret not specified' - - payload = PowerPayload(secret=secret, - state=enable) - self._client.publish(f'hk/{device.id}/relay/power', - payload=payload.pack(), - qos=1) - self._client.loop_write() - - def push_ota(self, - device_id, - filename: str, - publish_callback: callable, - qos: int): - device = next(d for d in self._devices if d.id == device_id) - assert device.secret is not None, 'device secret not specified' - - self._ota_publish_callback = publish_callback - payload = OTAPayload(secret=device.secret, filename=filename) - publish_result = self._client.publish(f'hk/{device.id}/relay/admin/ota', - payload=payload.pack(), - qos=qos) - self._ota_mid = publish_result.mid - self._client.loop_write() - -class MQTTRelayState: +class MqttRelayState: enabled: bool update_time: datetime.datetime rssi: int diff --git a/src/home/mqtt/temphum.py b/src/home/mqtt/temphum.py new file mode 100644 index 0000000..b9b2eb9 --- /dev/null +++ b/src/home/mqtt/temphum.py @@ -0,0 +1,33 @@ +import paho.mqtt.client as mqtt +import re + +from .payload.temphum import ( + TempHumDataPayload +) +from .esp import MqttEspBase + + +class MqttTempHum(MqttEspBase): + TOPIC_LEAF = 'temphum' + + def on_message(self, client: mqtt.Client, userdata, msg): + if super().on_message(client, userdata, msg): + return + + try: + match = re.match(self.get_mqtt_topics(['data']), msg.topic) + if not match: + return + + device_id = match.group(1) + subtopic = match.group(2) + + message = None + if subtopic == 'data': + message = TempHumDataPayload.unpack(msg.payload) + + if message and self._message_callback: + self._message_callback(device_id, message) + + except Exception as e: + self._logger.exception(str(e)) |