diff options
Diffstat (limited to 'include/pio/libs/mqtt_module_ota/homekit/mqtt/module/ota.cpp')
-rw-r--r-- | include/pio/libs/mqtt_module_ota/homekit/mqtt/module/ota.cpp | 160 |
1 files changed, 160 insertions, 0 deletions
diff --git a/include/pio/libs/mqtt_module_ota/homekit/mqtt/module/ota.cpp b/include/pio/libs/mqtt_module_ota/homekit/mqtt/module/ota.cpp new file mode 100644 index 0000000..4e976cd --- /dev/null +++ b/include/pio/libs/mqtt_module_ota/homekit/mqtt/module/ota.cpp @@ -0,0 +1,160 @@ +#include "./ota.h" +#include <homekit/logging.h> +#include <homekit/util.h> +#include <homekit/led.h> + +namespace homekit::mqtt { + +using homekit::led::mcu_led; + +#define MD5_SIZE 16 + +static const char TOPIC_OTA[] = "ota"; +static const char TOPIC_OTA_RESPONSE[] = "otares"; + +void MqttOtaModule::onConnect(Mqtt& mqtt) { + String topic(TOPIC_OTA); + mqtt.subscribeModule(topic, this); +} + +void MqttOtaModule::tick(Mqtt& mqtt) { + if (!tickElapsed()) + return; +} + +void MqttOtaModule::handlePayload(Mqtt& mqtt, String& topic, uint16_t packetId, const uint8_t *payload, size_t length, size_t index, size_t total) { + char md5[33]; + char* md5Ptr = md5; + + if (index != 0 && ota.dataPacketId != packetId) { + PRINTLN("mqtt/ota: non-matching packet id"); + return; + } + + Update.runAsync(true); + + if (index == 0) { + if (length < CONFIG_NODE_SECRET_SIZE + MD5_SIZE) { + PRINTLN("mqtt/ota: failed to check secret, first packet size is too small"); + return; + } + + if (memcmp((const char*)payload, CONFIG_NODE_SECRET, CONFIG_NODE_SECRET_SIZE) != 0) { + PRINTLN("mqtt/ota: invalid secret"); + return; + } + + PRINTF("mqtt/ota: starting update, total=%ul\n", total-CONFIG_NODE_SECRET_SIZE); + for (int i = 0; i < MD5_SIZE; i++) { + md5Ptr += sprintf(md5Ptr, "%02x", *((unsigned char*)(payload+CONFIG_NODE_SECRET_SIZE+i))); + } + md5[32] = '\0'; + PRINTF("mqtt/ota: md5 is %s\n", md5); + PRINTF("mqtt/ota: first packet is %ul bytes length\n", length); + + md5[32] = '\0'; + + if (Update.isRunning()) { + Update.end(); + Update.clearError(); + } + + if (!Update.setMD5(md5)) { + PRINTLN("mqtt/ota: setMD5 failed"); + return; + } + + ota.dataPacketId = packetId; + + if (!Update.begin(total - CONFIG_NODE_SECRET_SIZE - MD5_SIZE)) { + ota.clean(); +#ifdef DEBUG + Update.printError(Serial); +#endif + sendResponse(mqtt, OtaResult::UPDATE_ERROR, Update.getError()); + } + + ota.written = Update.write(const_cast<uint8_t*>(payload)+CONFIG_NODE_SECRET_SIZE + MD5_SIZE, length-CONFIG_NODE_SECRET_SIZE - MD5_SIZE); + ota.written += CONFIG_NODE_SECRET_SIZE + MD5_SIZE; + + mcu_led->blink(1, 1); + PRINTF("mqtt/ota: updating %u/%u\n", ota.written, Update.size()); + + } else { + if (!Update.isRunning()) { + PRINTLN("mqtt/ota: update is not running"); + return; + } + + if (index == ota.written) { + size_t written; + if ((written = Update.write(const_cast<uint8_t*>(payload), length)) != length) { + PRINTF("mqtt/ota: error: tried to write %ul bytes, write() returned %ul\n", + length, written); + ota.clean(); + Update.end(); + Update.clearError(); + sendResponse(mqtt, OtaResult::WRITE_ERROR); + return; + } + ota.written += length; + + mcu_led->blink(1, 1); + PRINTF("mqtt/ota: updating %u/%u\n", + ota.written - CONFIG_NODE_SECRET_SIZE - MD5_SIZE, + Update.size()); + } else { + PRINTF("mqtt/ota: position is invalid, expected %ul, got %ul\n", ota.written, index); + ota.clean(); + Update.end(); + Update.clearError(); + } + } + + if (Update.isFinished()) { + ota.dataPacketId = 0; + + if (Update.end()) { + ota.finished = true; + ota.publishResultPacketId = sendResponse(mqtt, OtaResult::OK); + PRINTF("mqtt/ota: ok, otares packet_id=%d\n", ota.publishResultPacketId); + } else { + ota.clean(); + + PRINTF("mqtt/ota: error: %u\n", Update.getError()); +#ifdef DEBUG + Update.printError(Serial); +#endif + Update.clearError(); + + sendResponse(mqtt, OtaResult::UPDATE_ERROR, Update.getError()); + } + } +} + +uint16_t MqttOtaModule::sendResponse(Mqtt& mqtt, OtaResult status, uint8_t error_code) const { + MqttOtaResponsePayload resp{ + .status = status, + .error_code = error_code + }; + return mqtt.publish(TOPIC_OTA_RESPONSE, reinterpret_cast<uint8_t*>(&resp), sizeof(resp)); +} + +void MqttOtaModule::onDisconnect(Mqtt& mqtt, espMqttClientTypes::DisconnectReason reason) { + if (ota.readyToRestart) { + restartTimer.once(1, restart); + } else if (ota.started()) { + PRINTLN("mqtt: update was in progress, canceling.."); + ota.clean(); + Update.end(); + Update.clearError(); + } +} + +void MqttOtaModule::handleOnPublish(uint16_t packetId) { + if (ota.finished && packetId == ota.publishResultPacketId) { + ota.readyToRestart = true; + } +} + +} |