diff options
Diffstat (limited to 'include/pio/libs/mqtt_module_ota')
3 files changed, 246 insertions, 0 deletions
diff --git a/include/pio/libs/mqtt_module_ota/homekit/mqtt/module/ota.cpp b/include/pio/libs/mqtt_module_ota/homekit/mqtt/module/ota.cpp new file mode 100644 index 0000000..4e976cd --- /dev/null +++ b/include/pio/libs/mqtt_module_ota/homekit/mqtt/module/ota.cpp @@ -0,0 +1,160 @@ +#include "./ota.h" +#include <homekit/logging.h> +#include <homekit/util.h> +#include <homekit/led.h> + +namespace homekit::mqtt { + +using homekit::led::mcu_led; + +#define MD5_SIZE 16 + +static const char TOPIC_OTA[] = "ota"; +static const char TOPIC_OTA_RESPONSE[] = "otares"; + +void MqttOtaModule::onConnect(Mqtt& mqtt) { + String topic(TOPIC_OTA); + mqtt.subscribeModule(topic, this); +} + +void MqttOtaModule::tick(Mqtt& mqtt) { + if (!tickElapsed()) + return; +} + +void MqttOtaModule::handlePayload(Mqtt& mqtt, String& topic, uint16_t packetId, const uint8_t *payload, size_t length, size_t index, size_t total) { + char md5[33]; + char* md5Ptr = md5; + + if (index != 0 && ota.dataPacketId != packetId) { + PRINTLN("mqtt/ota: non-matching packet id"); + return; + } + + Update.runAsync(true); + + if (index == 0) { + if (length < CONFIG_NODE_SECRET_SIZE + MD5_SIZE) { + PRINTLN("mqtt/ota: failed to check secret, first packet size is too small"); + return; + } + + if (memcmp((const char*)payload, CONFIG_NODE_SECRET, CONFIG_NODE_SECRET_SIZE) != 0) { + PRINTLN("mqtt/ota: invalid secret"); + return; + } + + PRINTF("mqtt/ota: starting update, total=%ul\n", total-CONFIG_NODE_SECRET_SIZE); + for (int i = 0; i < MD5_SIZE; i++) { + md5Ptr += sprintf(md5Ptr, "%02x", *((unsigned char*)(payload+CONFIG_NODE_SECRET_SIZE+i))); + } + md5[32] = '\0'; + PRINTF("mqtt/ota: md5 is %s\n", md5); + PRINTF("mqtt/ota: first packet is %ul bytes length\n", length); + + md5[32] = '\0'; + + if (Update.isRunning()) { + Update.end(); + Update.clearError(); + } + + if (!Update.setMD5(md5)) { + PRINTLN("mqtt/ota: setMD5 failed"); + return; + } + + ota.dataPacketId = packetId; + + if (!Update.begin(total - CONFIG_NODE_SECRET_SIZE - MD5_SIZE)) { + ota.clean(); +#ifdef DEBUG + Update.printError(Serial); +#endif + sendResponse(mqtt, OtaResult::UPDATE_ERROR, Update.getError()); + } + + ota.written = Update.write(const_cast<uint8_t*>(payload)+CONFIG_NODE_SECRET_SIZE + MD5_SIZE, length-CONFIG_NODE_SECRET_SIZE - MD5_SIZE); + ota.written += CONFIG_NODE_SECRET_SIZE + MD5_SIZE; + + mcu_led->blink(1, 1); + PRINTF("mqtt/ota: updating %u/%u\n", ota.written, Update.size()); + + } else { + if (!Update.isRunning()) { + PRINTLN("mqtt/ota: update is not running"); + return; + } + + if (index == ota.written) { + size_t written; + if ((written = Update.write(const_cast<uint8_t*>(payload), length)) != length) { + PRINTF("mqtt/ota: error: tried to write %ul bytes, write() returned %ul\n", + length, written); + ota.clean(); + Update.end(); + Update.clearError(); + sendResponse(mqtt, OtaResult::WRITE_ERROR); + return; + } + ota.written += length; + + mcu_led->blink(1, 1); + PRINTF("mqtt/ota: updating %u/%u\n", + ota.written - CONFIG_NODE_SECRET_SIZE - MD5_SIZE, + Update.size()); + } else { + PRINTF("mqtt/ota: position is invalid, expected %ul, got %ul\n", ota.written, index); + ota.clean(); + Update.end(); + Update.clearError(); + } + } + + if (Update.isFinished()) { + ota.dataPacketId = 0; + + if (Update.end()) { + ota.finished = true; + ota.publishResultPacketId = sendResponse(mqtt, OtaResult::OK); + PRINTF("mqtt/ota: ok, otares packet_id=%d\n", ota.publishResultPacketId); + } else { + ota.clean(); + + PRINTF("mqtt/ota: error: %u\n", Update.getError()); +#ifdef DEBUG + Update.printError(Serial); +#endif + Update.clearError(); + + sendResponse(mqtt, OtaResult::UPDATE_ERROR, Update.getError()); + } + } +} + +uint16_t MqttOtaModule::sendResponse(Mqtt& mqtt, OtaResult status, uint8_t error_code) const { + MqttOtaResponsePayload resp{ + .status = status, + .error_code = error_code + }; + return mqtt.publish(TOPIC_OTA_RESPONSE, reinterpret_cast<uint8_t*>(&resp), sizeof(resp)); +} + +void MqttOtaModule::onDisconnect(Mqtt& mqtt, espMqttClientTypes::DisconnectReason reason) { + if (ota.readyToRestart) { + restartTimer.once(1, restart); + } else if (ota.started()) { + PRINTLN("mqtt: update was in progress, canceling.."); + ota.clean(); + Update.end(); + Update.clearError(); + } +} + +void MqttOtaModule::handleOnPublish(uint16_t packetId) { + if (ota.finished && packetId == ota.publishResultPacketId) { + ota.readyToRestart = true; + } +} + +} diff --git a/include/pio/libs/mqtt_module_ota/homekit/mqtt/module/ota.h b/include/pio/libs/mqtt_module_ota/homekit/mqtt/module/ota.h new file mode 100644 index 0000000..df4f7ce --- /dev/null +++ b/include/pio/libs/mqtt_module_ota/homekit/mqtt/module/ota.h @@ -0,0 +1,75 @@ +#ifndef HOMEKIT_LIB_MQTT_MODULE_OTA_H +#define HOMEKIT_LIB_MQTT_MODULE_OTA_H + +#include <stdint.h> +#include <Ticker.h> +#include <homekit/mqtt/module.h> + +namespace homekit::mqtt { + +enum class OtaResult: uint8_t { + OK = 0, + UPDATE_ERROR = 1, + WRITE_ERROR = 2, +}; + +struct OtaStatus { + uint16_t dataPacketId; + uint16_t publishResultPacketId; + bool finished; + bool readyToRestart; + size_t written; + + OtaStatus() + : dataPacketId(0) + , publishResultPacketId(0) + , finished(false) + , readyToRestart(false) + , written(0) + {} + + inline void clean() { + dataPacketId = 0; + publishResultPacketId = 0; + finished = false; + readyToRestart = false; + written = 0; + } + + inline bool started() const { + return dataPacketId != 0; + } +}; + +struct MqttOtaResponsePayload { + OtaResult status; + uint8_t error_code; +} __attribute__((packed)); + + +class MqttOtaModule: public MqttModule { +private: + OtaStatus ota; + Ticker restartTimer; + + uint16_t sendResponse(Mqtt& mqtt, OtaResult status, uint8_t error_code = 0) const; + +public: + MqttOtaModule() : MqttModule(0, true, true) {} + + void onConnect(Mqtt& mqtt) override; + void onDisconnect(Mqtt& mqtt, espMqttClientTypes::DisconnectReason reason) override; + + void tick(Mqtt& mqtt) override; + + void handlePayload(Mqtt& mqtt, String& topic, uint16_t packetId, const uint8_t *payload, size_t length, size_t index, size_t total) override; + void handleOnPublish(uint16_t packetId) override; + + inline bool isReadyToRestart() const { + return ota.readyToRestart; + } +}; + +} + +#endif //HOMEKIT_LIB_MQTT_MODULE_OTA_H diff --git a/include/pio/libs/mqtt_module_ota/library.json b/include/pio/libs/mqtt_module_ota/library.json new file mode 100644 index 0000000..1577fed --- /dev/null +++ b/include/pio/libs/mqtt_module_ota/library.json @@ -0,0 +1,11 @@ +{ + "name": "homekit_mqtt_module_ota", + "version": "1.0.6", + "build": { + "flags": "-I../../include" + }, + "dependencies": { + "homekit_led": "file://../../include/pio/libs/led", + "homekit_mqtt": "file://../../include/pio/libs/mqtt" + } +} |