import hashlib

from typing import Optional

from .._payload import MqttPayload
from .._node import MqttModule, MqttNode

MODULE_NAME = 'MqttOtaModule'


class OtaResultPayload(MqttPayload):
    """Result of an OTA push, received on the ``otares`` topic."""

    # Two unsigned bytes (result, error_code); presumably consumed as a
    # ``struct`` format string by MqttPayload -- TODO confirm in _payload.
    FORMAT = '=BB'

    result: int
    error_code: int


class OtaPayload(MqttPayload):
    """Outgoing OTA request: the node secret followed by a firmware image.

    Layout of the buffer produced by :meth:`pack`::

        uint8_t[len(secret)] secret;  // secret, str.encode() (UTF-8)
        uint8_t[16]          md5;     // MD5 digest of the file contents
        uint8_t[]            data;    // raw file contents
    """

    secret: str
    filename: str

    def pack(self):
        """Read ``self.filename`` and return the serialized message as a bytearray."""
        with open(self.filename, 'rb') as fd:
            content = fd.read()

        # MD5 serves only as an integrity checksum for the transferred image,
        # not as a security primitive; declaring that keeps it usable on
        # FIPS-restricted OpenSSL builds.
        digest = hashlib.md5(content, usedforsecurity=False).digest()

        buf = bytearray(self.secret.encode())
        buf.extend(digest)
        buf.extend(content)
        return buf

    @classmethod
    def unpack(cls, buf: bytes):
        """Decoding an OtaPayload is not implemented.

        Raises:
            NotImplementedError: always (a subclass of RuntimeError, so
                existing ``except RuntimeError`` handlers still catch it).
        """
        raise NotImplementedError(f'{cls.__name__}.unpack: not implemented')


class MqttOtaModule(MqttModule):
    """MQTT module that publishes OTA images on ``ota`` and handles ``otares``."""

    # (filename, qos) of an OTA push requested before the module was
    # initialized; sent and cleared from on_connect(). Only the most recent
    # request is kept.
    _ota_request: Optional[tuple[str, int]]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._ota_request = None

    def on_connect(self, mqtt: MqttNode):
        """Subscribe to ``otares`` and send any OTA request queued earlier."""
        super().on_connect(mqtt)
        mqtt.subscribe_module("otares", self)

        if self._ota_request is not None:
            filename, qos = self._ota_request
            # Clear the pending request so it is sent only once.
            self._ota_request = None
            self.do_push_ota(self._mqtt_node_ref.secret, filename, qos)

    def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]:
        """Decode an ``otares`` message; return None for any other topic."""
        if topic == 'otares':
            message = OtaResultPayload.unpack(payload)
            self._logger.debug(message)
            return message
        return None

    def do_push_ota(self, secret: str, filename: str, qos: int):
        """Pack ``filename`` with ``secret`` and publish it on the ``ota`` topic."""
        payload = OtaPayload(secret=secret, filename=filename)
        self._mqtt_node_ref.publish('ota', payload=payload.pack(), qos=qos)

    def push_ota(self, filename: str, qos: int):
        """Push an OTA image now, or defer it until on_connect if not initialized."""
        if not self._initialized:
            self._ota_request = (filename, qos)
        else:
            self.do_push_ota(self._mqtt_node_ref.secret, filename, qos)