#include "./ota.h"
// NOTE(review): the targets of the three includes below were lost in SOURCE
// (their <...> contents were stripped). They are reconstructed from the names
// this file uses (PRINTF/PRINTLN, restart, homekit::led::mcu_led) -- confirm
// against the project tree.
#include <homekit/logging.h>
#include <homekit/util.h>
#include <homekit/led.h>

namespace homekit::mqtt {

using homekit::led::mcu_led;

// Length of a raw (binary) MD5 digest, in bytes.
#define MD5_SIZE 16

static const char TOPIC_OTA[] = "ota";
static const char TOPIC_OTA_RESPONSE[] = "otares";

// Every OTA upload starts with this header: the node secret followed by the
// raw MD5 digest of the firmware image. The rest of the message is the image.
static const size_t OTA_HEADER_SIZE = CONFIG_NODE_SECRET_SIZE + MD5_SIZE;

// Abort whatever the global Updater is doing and drop its error state.
static void resetUpdater() {
    Update.end();
    Update.clearError();
}

/// Subscribe to the OTA topic whenever the MQTT connection is (re)established.
void MqttOtaModule::onConnect(Mqtt& mqtt) {
    String topic(TOPIC_OTA);
    mqtt.subscribeModule(topic, this);
}

/// Periodic hook; the body only evaluates tickElapsed() and does no other work.
void MqttOtaModule::tick(Mqtt& mqtt) {
    if (!tickElapsed())
        return;
}

/**
 * Receive one chunk of an OTA firmware upload delivered as a single MQTT
 * message split into chunks.
 *
 * @param mqtt      client used to publish the result on TOPIC_OTA_RESPONSE
 * @param topic     topic the chunk arrived on (unused here)
 * @param packetId  MQTT packet id; all chunks of one upload must share it
 * @param payload   bytes of this chunk
 * @param length    number of bytes in `payload`
 * @param index     offset of this chunk within the whole message
 * @param total     length of the whole message
 *
 * Message layout: [CONFIG_NODE_SECRET][16-byte raw MD5][firmware image].
 * When the last byte has been written, the update is finalized and an
 * OtaResult::OK or OtaResult::UPDATE_ERROR response is published.
 */
void MqttOtaModule::handlePayload(Mqtt& mqtt, String& topic, uint16_t packetId, const uint8_t *payload, size_t length, size_t index, size_t total) {
    if (index != 0 && ota.dataPacketId != packetId) {
        PRINTLN("mqtt/ota: non-matching packet id");
        return;
    }

    // NOTE(review): presumably needed because this handler runs inside the
    // MQTT client's async callback -- confirm.
    Update.runAsync(true);

    if (index == 0) {
        if (length < OTA_HEADER_SIZE) {
            PRINTLN("mqtt/ota: failed to check secret, first packet size is too small");
            return;
        }

        if (memcmp(payload, CONFIG_NODE_SECRET, CONFIG_NODE_SECRET_SIZE) != 0) {
            PRINTLN("mqtt/ota: invalid secret");
            return;
        }

        // Size of the firmware image itself (header excluded).
        const size_t firmwareSize = total - OTA_HEADER_SIZE;
        PRINTF("mqtt/ota: starting update, total=%lu\n", static_cast<unsigned long>(firmwareSize));

        // The digest arrives as 16 raw bytes; render it as 32 lowercase hex
        // characters for Update.setMD5().
        char md5[MD5_SIZE * 2 + 1];
        for (size_t i = 0; i < MD5_SIZE; i++)
            snprintf(md5 + i * 2, 3, "%02x", static_cast<unsigned>(payload[CONFIG_NODE_SECRET_SIZE + i]));
        md5[MD5_SIZE * 2] = '\0';

        PRINTF("mqtt/ota: md5 is %s\n", md5);
        PRINTF("mqtt/ota: first packet is %lu bytes length\n", static_cast<unsigned long>(length));

        // Drop any previous, unfinished update before starting a new one.
        if (Update.isRunning())
            resetUpdater();

        ota.dataPacketId = packetId;

        if (!Update.begin(firmwareSize)) {
            ota.clean();
#ifdef DEBUG
            Update.printError(Serial);
#endif
            sendResponse(mqtt, OtaResult::UPDATE_ERROR, Update.getError());
            // Without this return the code used to keep writing into a
            // non-started Updater and fall through to the isFinished() check.
            return;
        }

        // Must come AFTER a successful begin(): the ESP8266/ESP32 core
        // Updater::begin() resets any previously set target MD5, which used to
        // disable verification silently.
        if (!Update.setMD5(md5)) {
            PRINTLN("mqtt/ota: setMD5 failed");
            ota.clean();
            resetUpdater();
            return;
        }

        const size_t chunkSize = length - OTA_HEADER_SIZE;
        const size_t written = Update.write(const_cast<uint8_t*>(payload) + OTA_HEADER_SIZE, chunkSize);
        if (written != chunkSize) {
            PRINTF("mqtt/ota: error: tried to write %lu bytes, write() returned %lu\n",
                   static_cast<unsigned long>(chunkSize), static_cast<unsigned long>(written));
            ota.clean();
            resetUpdater();
            sendResponse(mqtt, OtaResult::WRITE_ERROR);
            return;
        }

        // ota.written is an offset within the whole MQTT message (header
        // included) so it can be compared with the next chunk's `index`.
        ota.written = length;
        mcu_led->blink(1, 1);
        PRINTF("mqtt/ota: updating %lu/%lu\n",
               static_cast<unsigned long>(ota.written - OTA_HEADER_SIZE),
               static_cast<unsigned long>(Update.size()));
    } else {
        if (!Update.isRunning()) {
            PRINTLN("mqtt/ota: update is not running");
            return;
        }

        if (index != ota.written) {
            PRINTF("mqtt/ota: position is invalid, expected %lu, got %lu\n",
                   static_cast<unsigned long>(ota.written), static_cast<unsigned long>(index));
            ota.clean();
            resetUpdater();
            // The original fell through to the isFinished() check, which after
            // a reset happened to emit an UPDATE_ERROR response; make it explicit.
            sendResponse(mqtt, OtaResult::UPDATE_ERROR);
            return;
        }

        const size_t written = Update.write(const_cast<uint8_t*>(payload), length);
        if (written != length) {
            PRINTF("mqtt/ota: error: tried to write %lu bytes, write() returned %lu\n",
                   static_cast<unsigned long>(length), static_cast<unsigned long>(written));
            ota.clean();
            resetUpdater();
            sendResponse(mqtt, OtaResult::WRITE_ERROR);
            return;
        }

        ota.written += length;
        mcu_led->blink(1, 1);
        PRINTF("mqtt/ota: updating %lu/%lu\n",
               static_cast<unsigned long>(ota.written - OTA_HEADER_SIZE),
               static_cast<unsigned long>(Update.size()));
    }

    // Only reached after a successful write into a running update, so
    // isFinished() here really means "the whole image has been written".
    if (Update.isFinished()) {
        ota.dataPacketId = 0;
        if (Update.end()) {
            ota.finished = true;
            ota.publishResultPacketId = sendResponse(mqtt, OtaResult::OK);
            PRINTF("mqtt/ota: ok, otares packet_id=%d\n", ota.publishResultPacketId);
        } else {
            // Capture the error before clearError() wipes it; the original
            // cleared first and therefore always reported 0.
            const auto error = Update.getError();
            ota.clean();
            PRINTF("mqtt/ota: error: %u\n", static_cast<unsigned>(error));
#ifdef DEBUG
            Update.printError(Serial);
#endif
            Update.clearError();
            sendResponse(mqtt, OtaResult::UPDATE_ERROR, error);
        }
    }
}

/// Publish a MqttOtaResponsePayload on TOPIC_OTA_RESPONSE.
/// @return the packet id returned by mqtt.publish().
uint16_t MqttOtaModule::sendResponse(Mqtt& mqtt, OtaResult status, uint8_t error_code) const {
    MqttOtaResponsePayload resp{
        .status = status,
        .error_code = error_code
    };
    return mqtt.publish(TOPIC_OTA_RESPONSE, reinterpret_cast<uint8_t*>(&resp), sizeof(resp));
}

/// On disconnect, either schedule the restart (once handleOnPublish() has
/// marked the update ready) or cancel an upload that was still in progress.
void MqttOtaModule::onDisconnect(Mqtt& mqtt, espMqttClientTypes::DisconnectReason reason) {
    if (ota.readyToRestart) {
        restartTimer.once(1, restart);
    } else if (ota.started()) {
        PRINTLN("mqtt: update was in progress, canceling..");
        ota.clean();
        resetUpdater();
    }
}

/// Mark the update ready to restart once the OK response packet is reported published.
void MqttOtaModule::handleOnPublish(uint16_t packetId) {
    if (ota.finished && packetId == ota.publishResultPacketId)
        ota.readyToRestart = true;
}

}