#include #include #include #include #include #include "relay.h" #include "mqtt.h" #include "leds.h" namespace homekit::mqtt { static const char TOPIC_DIAGNOSTICS[] = "stat"; static const char TOPIC_INITIAL_DIAGNOSTICS[] = "stat1"; static const char TOPIC_OTA_RESPONSE[] = "otares"; static const char TOPIC_RELAY_POWER[] = "power"; static const char TOPIC_ADMIN_OTA[] = "admin/ota"; static const uint16_t MQTT_KEEPALIVE = 30; enum class IncomingMessage { UNKNOWN, RELAY_POWER, OTA }; using namespace espMqttClientTypes; #define MD5_SIZE 16 MQTT::MQTT() { auto cfg = config::read(); homeId = String(cfg.flags.node_configured ? cfg.node_id : wifi::NODE_ID); randomSeed(micros()); client.onConnect([&](bool sessionPresent) { PRINTLN("mqtt: connected"); sendInitialDiagnostics(); subscribe(TOPIC_RELAY_POWER, 1); subscribe(TOPIC_ADMIN_OTA); }); client.onDisconnect([&](DisconnectReason reason) { PRINTF("mqtt: disconnected, reason=%d\n", static_cast(reason)); #ifdef DEBUG if (reason == DisconnectReason::TLS_BAD_FINGERPRINT) PRINTLN("reason: bad fingerprint"); #endif if (ota.started()) { PRINTLN("mqtt: update was in progress, canceling.."); ota.clean(); Update.end(); Update.clearError(); } if (ota.readyToRestart) { restartTimer.once(1, restart); } else { reconnectTimer.once(2, [&]() { reconnect(); }); } }); client.onSubscribe([&](uint16_t packetId, const SubscribeReturncode* returncodes, size_t len) { PRINTF("mqtt: subscribe ack, packet_id=%d\n", packetId); for (size_t i = 0; i < len; i++) { PRINTF(" return code: %u\n", static_cast(*(returncodes+i))); } }); client.onUnsubscribe([&](uint16_t packetId) { PRINTF("mqtt: unsubscribe ack, packet_id=%d\n", packetId); }); client.onMessage([&](const MessageProperties& properties, const char* topic, const uint8_t* payload, size_t len, size_t index, size_t total) { PRINTF("mqtt: message received, topic=%s, qos=%d, dup=%d, retain=%d, len=%ul, index=%ul, total=%ul\n", topic, properties.qos, (int)properties.dup, (int)properties.retain, len, index, total); IncomingMessage msgType = IncomingMessage::UNKNOWN; const char *ptr = topic + homeId.length() + 10; String relevantTopic(ptr); if (relevantTopic == TOPIC_RELAY_POWER) msgType = IncomingMessage::RELAY_POWER; else if (relevantTopic == TOPIC_ADMIN_OTA) msgType = IncomingMessage::OTA; if (len != total && msgType != IncomingMessage::OTA) { PRINTLN("mqtt: received partial message, not supported"); return; } switch (msgType) { case IncomingMessage::RELAY_POWER: handleRelayPowerPayload(payload, total); break; case IncomingMessage::OTA: if (ota.finished) break; handleAdminOtaPayload(properties.packetId, payload, len, index, total); break; case IncomingMessage::UNKNOWN: PRINTF("error: invalid topic %s\n", topic); break; } }); client.onPublish([&](uint16_t packetId) { PRINTF("mqtt: publish ack, packet_id=%d\n", packetId); if (ota.finished && packetId == ota.publishResultPacketId) { ota.readyToRestart = true; } }); client.setServer(MQTT_SERVER, MQTT_PORT); client.setClientId(MQTT_CLIENT_ID); client.setCredentials(MQTT_USERNAME, MQTT_PASSWORD); client.setCleanSession(true); client.setFingerprint(MQTT_CA_FINGERPRINT); client.setKeepAlive(MQTT_KEEPALIVE); } void MQTT::connect() { reconnect(); } void MQTT::reconnect() { if (client.connected()) { PRINTLN("warning: already connected"); return; } client.connect(); } void MQTT::disconnect() { // TODO test how this works??? reconnectTimer.detach(); client.disconnect(); } uint16_t MQTT::publish(const String &topic, uint8_t *payload, size_t length) { String fullTopic = "hk/" + homeId + "/relay/" + topic; return client.publish(fullTopic.c_str(), 1, false, payload, length); } void MQTT::loop() { client.loop(); } uint16_t MQTT::subscribe(const String &topic, uint8_t qos) { String fullTopic = "hk/" + homeId + "/relay/" + topic; PRINTF("mqtt: subscribing to %s...\n", fullTopic.c_str()); uint16_t packetId = client.subscribe(fullTopic.c_str(), qos); if (!packetId) PRINTF("error: failed to subscribe to %s\n", fullTopic.c_str()); return packetId; } void MQTT::sendInitialDiagnostics() { auto cfg = config::read(); InitialDiagnosticsPayload stat{ .ip = wifi::getIPAsInteger(), .fw_version = CONFIG_FW_VERSION, .rssi = wifi::getRSSI(), .free_heap = ESP.getFreeHeap(), .flags = DiagnosticsFlags{ .state = static_cast(relay::getState() ? 1 : 0), .config_changed_value_present = 1, .config_changed = static_cast(cfg.flags.node_configured || cfg.flags.wifi_configured ? 1 : 0) } }; publish(TOPIC_INITIAL_DIAGNOSTICS, reinterpret_cast(&stat), sizeof(stat)); diagnosticsStopWatch.save(); } void MQTT::sendDiagnostics() { DiagnosticsPayload stat{ .rssi = wifi::getRSSI(), .free_heap = ESP.getFreeHeap(), .flags = DiagnosticsFlags{ .state = static_cast(relay::getState() ? 1 : 0), .config_changed_value_present = 0, .config_changed = 0 } }; publish(TOPIC_DIAGNOSTICS, reinterpret_cast(&stat), sizeof(stat)); diagnosticsStopWatch.save(); } uint16_t MQTT::sendOtaResponse(OTAResult status, uint8_t error_code) { OTAResponse resp{ .status = status, .error_code = error_code }; return publish(TOPIC_OTA_RESPONSE, reinterpret_cast(&resp), sizeof(resp)); } void MQTT::handleRelayPowerPayload(const uint8_t *payload, uint32_t length) { if (length != sizeof(PowerPayload)) { PRINTF("error: size of payload (%ul) does not match expected (%ul)\n", length, sizeof(PowerPayload)); return; } auto pd = reinterpret_cast(payload); if (strncmp(pd->secret, MQTT_SECRET, sizeof(pd->secret)) != 0) { PRINTLN("error: invalid secret"); return; } if (pd->state == 1) { PRINTLN("mqtt: turning relay on"); relay::setOn(); } else if (pd->state == 0) { PRINTLN("mqtt: turning relay off"); relay::setOff(); } else { PRINTLN("error: unexpected state value"); } sendDiagnostics(); } void MQTT::handleAdminOtaPayload(uint16_t packetId, const uint8_t *payload, size_t length, size_t index, size_t total) { char md5[33]; char* md5Ptr = md5; if (index != 0 && ota.dataPacketId != packetId) { PRINTLN("mqtt/ota: non-matching packet id"); return; } Update.runAsync(true); if (index == 0) { if (length < CONFIG_NODE_SECRET_SIZE + MD5_SIZE) { PRINTLN("mqtt/ota: failed to check secret, first packet size is too small"); return; } if (memcmp((const char*)payload, CONFIG_NODE_SECRET, CONFIG_NODE_SECRET_SIZE) != 0) { PRINTLN("mqtt/ota: invalid secret"); return; } PRINTF("mqtt/ota: starting update, total=%ul\n", total-NODE_SECRET_SIZE); for (int i = 0; i < MD5_SIZE; i++) { md5Ptr += sprintf(md5Ptr, "%02x", *((unsigned char*)(payload+CONFIG_NODE_SECRET_SIZE+i))); } md5[32] = '\0'; PRINTF("mqtt/ota: md5 is %s\n", md5); PRINTF("mqtt/ota: first packet is %ul bytes length\n", length); md5[32] = '\0'; if (Update.isRunning()) { Update.end(); Update.clearError(); } if (!Update.setMD5(md5)) { PRINTLN("mqtt/ota: setMD5 failed"); return; } ota.dataPacketId = packetId; if (!Update.begin(total - CONFIG_NODE_SECRET_SIZE - MD5_SIZE)) { ota.clean(); #ifdef DEBUG Update.printError(Serial); #endif sendOtaResponse(OTAResult::UPDATE_ERROR, Update.getError()); } ota.written = Update.write(const_cast(payload)+CONFIG_NODE_SECRET_SIZE + MD5_SIZE, length-CONFIG_NODE_SECRET_SIZE - MD5_SIZE); ota.written += CONFIG_NODE_SECRET_SIZE + MD5_SIZE; mcu_led->blink(1, 1); PRINTF("mqtt/ota: updating %u/%u\n", ota.written, Update.size()); } else { if (!Update.isRunning()) { PRINTLN("mqtt/ota: update is not running"); return; } if (index == ota.written) { size_t written; if ((written = Update.write(const_cast(payload), length)) != length) { PRINTF("mqtt/ota: error: tried to write %ul bytes, write() returned %ul\n", length, written); ota.clean(); Update.end(); Update.clearError(); sendOtaResponse(OTAResult::WRITE_ERROR); return; } ota.written += length; mcu_led->blink(1, 1); PRINTF("mqtt/ota: updating %u/%u\n", ota.written - CONFIG_NODE_SECRET_SIZE - MD5_SIZE, Update.size()); } else { PRINTF("mqtt/ota: position is invalid, expected %ul, got %ul\n", ota.written, index); ota.clean(); Update.end(); Update.clearError(); } } if (Update.isFinished()) { ota.dataPacketId = 0; if (Update.end()) { ota.finished = true; ota.publishResultPacketId = sendOtaResponse(OTAResult::OK); PRINTF("mqtt/ota: ok, otares packet_id=%d\n", ota.publishResultPacketId); } else { ota.clean(); PRINTF("mqtt/ota: error: %u\n", Update.getError()); #ifdef DEBUG Update.printError(Serial); #endif Update.clearError(); sendOtaResponse(OTAResult::UPDATE_ERROR, Update.getError()); } } } }