summaryrefslogtreecommitdiff
path: root/src/home/mqtt/module/ota.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/home/mqtt/module/ota.py')
-rw-r--r--src/home/mqtt/module/ota.py77
1 files changed, 77 insertions, 0 deletions
diff --git a/src/home/mqtt/module/ota.py b/src/home/mqtt/module/ota.py
new file mode 100644
index 0000000..cd34332
--- /dev/null
+++ b/src/home/mqtt/module/ota.py
@@ -0,0 +1,77 @@
+import hashlib
+
+from typing import Optional
+from .._payload import MqttPayload
+from .._node import MqttModule, MqttNode
+
+MODULE_NAME = 'MqttOtaModule'
+
+
+class OtaResultPayload(MqttPayload):
+ FORMAT = '=BB'
+ result: int
+ error_code: int
+
+
+class OtaPayload(MqttPayload):
+ secret: str
+ filename: str
+
+ # structure of returned data:
+ #
+ # uint8_t[len(secret)] secret;
+ # uint8_t[16] md5;
+ # *uint8_t data
+
+ def pack(self):
+ buf = bytearray(self.secret.encode())
+ m = hashlib.md5()
+ with open(self.filename, 'rb') as fd:
+ content = fd.read()
+ m.update(content)
+ buf.extend(m.digest())
+ buf.extend(content)
+ return buf
+
+ def unpack(cls, buf: bytes):
+ raise RuntimeError(f'{cls.__class__.__name__}.unpack: not implemented')
+ # secret = buf[:12].decode()
+ # filename = buf[12:].decode()
+ # return OTAPayload(secret=secret, filename=filename)
+
+
+class MqttOtaModule(MqttModule):
+ _ota_request: Optional[tuple[str, int]]
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._ota_request = None
+
+ def on_connect(self, mqtt: MqttNode):
+ super().on_connect(mqtt)
+ mqtt.subscribe_module("otares", self)
+
+ if self._ota_request is not None:
+ filename, qos = self._ota_request
+ self._ota_request = None
+ self.do_push_ota(self._mqtt_node_ref.secret, filename, qos)
+
+ def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]:
+ if topic == 'otares':
+ message = OtaResultPayload.unpack(payload)
+ self._logger.debug(message)
+ return message
+
+ def do_push_ota(self, secret: str, filename: str, qos: int):
+ payload = OtaPayload(secret=secret, filename=filename)
+ self._mqtt_node_ref.publish('ota',
+ payload=payload.pack(),
+ qos=qos)
+
+ def push_ota(self,
+ filename: str,
+ qos: int):
+ if not self._initialized:
+ self._ota_request = (filename, qos)
+ else:
+ self.do_push_ota(filename, qos)