summaryrefslogtreecommitdiff
path: root/include/pio/libs/mqtt_module_ota
diff options
context:
space:
mode:
Diffstat (limited to 'include/pio/libs/mqtt_module_ota')
-rw-r--r--include/pio/libs/mqtt_module_ota/homekit/mqtt/module/ota.cpp160
-rw-r--r--include/pio/libs/mqtt_module_ota/homekit/mqtt/module/ota.h75
-rw-r--r--include/pio/libs/mqtt_module_ota/library.json11
3 files changed, 246 insertions, 0 deletions
diff --git a/include/pio/libs/mqtt_module_ota/homekit/mqtt/module/ota.cpp b/include/pio/libs/mqtt_module_ota/homekit/mqtt/module/ota.cpp
new file mode 100644
index 0000000..4e976cd
--- /dev/null
+++ b/include/pio/libs/mqtt_module_ota/homekit/mqtt/module/ota.cpp
@@ -0,0 +1,160 @@
+#include "./ota.h"
+#include <homekit/logging.h>
+#include <homekit/util.h>
+#include <homekit/led.h>
+
+namespace homekit::mqtt {
+
+using homekit::led::mcu_led;
+
+#define MD5_SIZE 16
+
+static const char TOPIC_OTA[] = "ota";
+static const char TOPIC_OTA_RESPONSE[] = "otares";
+
+void MqttOtaModule::onConnect(Mqtt& mqtt) {
+ String topic(TOPIC_OTA);
+ mqtt.subscribeModule(topic, this);
+}
+
+void MqttOtaModule::tick(Mqtt& mqtt) {
+ if (!tickElapsed())
+ return;
+}
+
+void MqttOtaModule::handlePayload(Mqtt& mqtt, String& topic, uint16_t packetId, const uint8_t *payload, size_t length, size_t index, size_t total) {
+ char md5[33];
+ char* md5Ptr = md5;
+
+ if (index != 0 && ota.dataPacketId != packetId) {
+ PRINTLN("mqtt/ota: non-matching packet id");
+ return;
+ }
+
+ Update.runAsync(true);
+
+ if (index == 0) {
+ if (length < CONFIG_NODE_SECRET_SIZE + MD5_SIZE) {
+ PRINTLN("mqtt/ota: failed to check secret, first packet size is too small");
+ return;
+ }
+
+ if (memcmp((const char*)payload, CONFIG_NODE_SECRET, CONFIG_NODE_SECRET_SIZE) != 0) {
+ PRINTLN("mqtt/ota: invalid secret");
+ return;
+ }
+
+ PRINTF("mqtt/ota: starting update, total=%ul\n", total-CONFIG_NODE_SECRET_SIZE);
+ for (int i = 0; i < MD5_SIZE; i++) {
+ md5Ptr += sprintf(md5Ptr, "%02x", *((unsigned char*)(payload+CONFIG_NODE_SECRET_SIZE+i)));
+ }
+ md5[32] = '\0';
+ PRINTF("mqtt/ota: md5 is %s\n", md5);
+ PRINTF("mqtt/ota: first packet is %ul bytes length\n", length);
+
+ md5[32] = '\0';
+
+ if (Update.isRunning()) {
+ Update.end();
+ Update.clearError();
+ }
+
+ if (!Update.setMD5(md5)) {
+ PRINTLN("mqtt/ota: setMD5 failed");
+ return;
+ }
+
+ ota.dataPacketId = packetId;
+
+ if (!Update.begin(total - CONFIG_NODE_SECRET_SIZE - MD5_SIZE)) {
+ ota.clean();
+#ifdef DEBUG
+ Update.printError(Serial);
+#endif
+ sendResponse(mqtt, OtaResult::UPDATE_ERROR, Update.getError());
+ }
+
+ ota.written = Update.write(const_cast<uint8_t*>(payload)+CONFIG_NODE_SECRET_SIZE + MD5_SIZE, length-CONFIG_NODE_SECRET_SIZE - MD5_SIZE);
+ ota.written += CONFIG_NODE_SECRET_SIZE + MD5_SIZE;
+
+ mcu_led->blink(1, 1);
+ PRINTF("mqtt/ota: updating %u/%u\n", ota.written, Update.size());
+
+ } else {
+ if (!Update.isRunning()) {
+ PRINTLN("mqtt/ota: update is not running");
+ return;
+ }
+
+ if (index == ota.written) {
+ size_t written;
+ if ((written = Update.write(const_cast<uint8_t*>(payload), length)) != length) {
+ PRINTF("mqtt/ota: error: tried to write %ul bytes, write() returned %ul\n",
+ length, written);
+ ota.clean();
+ Update.end();
+ Update.clearError();
+ sendResponse(mqtt, OtaResult::WRITE_ERROR);
+ return;
+ }
+ ota.written += length;
+
+ mcu_led->blink(1, 1);
+ PRINTF("mqtt/ota: updating %u/%u\n",
+ ota.written - CONFIG_NODE_SECRET_SIZE - MD5_SIZE,
+ Update.size());
+ } else {
+ PRINTF("mqtt/ota: position is invalid, expected %ul, got %ul\n", ota.written, index);
+ ota.clean();
+ Update.end();
+ Update.clearError();
+ }
+ }
+
+ if (Update.isFinished()) {
+ ota.dataPacketId = 0;
+
+ if (Update.end()) {
+ ota.finished = true;
+ ota.publishResultPacketId = sendResponse(mqtt, OtaResult::OK);
+ PRINTF("mqtt/ota: ok, otares packet_id=%d\n", ota.publishResultPacketId);
+ } else {
+ ota.clean();
+
+ PRINTF("mqtt/ota: error: %u\n", Update.getError());
+#ifdef DEBUG
+ Update.printError(Serial);
+#endif
+ Update.clearError();
+
+ sendResponse(mqtt, OtaResult::UPDATE_ERROR, Update.getError());
+ }
+ }
+}
+
+uint16_t MqttOtaModule::sendResponse(Mqtt& mqtt, OtaResult status, uint8_t error_code) const {
+ MqttOtaResponsePayload resp{
+ .status = status,
+ .error_code = error_code
+ };
+ return mqtt.publish(TOPIC_OTA_RESPONSE, reinterpret_cast<uint8_t*>(&resp), sizeof(resp));
+}
+
+void MqttOtaModule::onDisconnect(Mqtt& mqtt, espMqttClientTypes::DisconnectReason reason) {
+ if (ota.readyToRestart) {
+ restartTimer.once(1, restart);
+ } else if (ota.started()) {
+ PRINTLN("mqtt: update was in progress, canceling..");
+ ota.clean();
+ Update.end();
+ Update.clearError();
+ }
+}
+
+void MqttOtaModule::handleOnPublish(uint16_t packetId) {
+ if (ota.finished && packetId == ota.publishResultPacketId) {
+ ota.readyToRestart = true;
+ }
+}
+
+}
diff --git a/include/pio/libs/mqtt_module_ota/homekit/mqtt/module/ota.h b/include/pio/libs/mqtt_module_ota/homekit/mqtt/module/ota.h
new file mode 100644
index 0000000..df4f7ce
--- /dev/null
+++ b/include/pio/libs/mqtt_module_ota/homekit/mqtt/module/ota.h
@@ -0,0 +1,75 @@
+#ifndef HOMEKIT_LIB_MQTT_MODULE_OTA_H
+#define HOMEKIT_LIB_MQTT_MODULE_OTA_H
+
+#include <stdint.h>
+#include <Ticker.h>
+#include <homekit/mqtt/module.h>
+
+namespace homekit::mqtt {
+
+enum class OtaResult: uint8_t {
+ OK = 0,
+ UPDATE_ERROR = 1,
+ WRITE_ERROR = 2,
+};
+
+struct OtaStatus {
+ uint16_t dataPacketId;
+ uint16_t publishResultPacketId;
+ bool finished;
+ bool readyToRestart;
+ size_t written;
+
+ OtaStatus()
+ : dataPacketId(0)
+ , publishResultPacketId(0)
+ , finished(false)
+ , readyToRestart(false)
+ , written(0)
+ {}
+
+ inline void clean() {
+ dataPacketId = 0;
+ publishResultPacketId = 0;
+ finished = false;
+ readyToRestart = false;
+ written = 0;
+ }
+
+ inline bool started() const {
+ return dataPacketId != 0;
+ }
+};
+
+struct MqttOtaResponsePayload {
+ OtaResult status;
+ uint8_t error_code;
+} __attribute__((packed));
+
+
+class MqttOtaModule: public MqttModule {
+private:
+ OtaStatus ota;
+ Ticker restartTimer;
+
+ uint16_t sendResponse(Mqtt& mqtt, OtaResult status, uint8_t error_code = 0) const;
+
+public:
+ MqttOtaModule() : MqttModule(0, true, true) {}
+
+ void onConnect(Mqtt& mqtt) override;
+ void onDisconnect(Mqtt& mqtt, espMqttClientTypes::DisconnectReason reason) override;
+
+ void tick(Mqtt& mqtt) override;
+
+ void handlePayload(Mqtt& mqtt, String& topic, uint16_t packetId, const uint8_t *payload, size_t length, size_t index, size_t total) override;
+ void handleOnPublish(uint16_t packetId) override;
+
+ inline bool isReadyToRestart() const {
+ return ota.readyToRestart;
+ }
+};
+
+}
+
+#endif //HOMEKIT_LIB_MQTT_MODULE_OTA_H
diff --git a/include/pio/libs/mqtt_module_ota/library.json b/include/pio/libs/mqtt_module_ota/library.json
new file mode 100644
index 0000000..1577fed
--- /dev/null
+++ b/include/pio/libs/mqtt_module_ota/library.json
@@ -0,0 +1,11 @@
+{
+ "name": "homekit_mqtt_module_ota",
+ "version": "1.0.6",
+ "build": {
+ "flags": "-I../../include"
+ },
+ "dependencies": {
+ "homekit_led": "file://../../include/pio/libs/led",
+ "homekit_mqtt": "file://../../include/pio/libs/mqtt"
+ }
+}