diff options
Diffstat (limited to 'platformio/relayctl/src/mqtt.cpp')
-rw-r--r-- | platformio/relayctl/src/mqtt.cpp | 341 |
1 files changed, 261 insertions, 80 deletions
diff --git a/platformio/relayctl/src/mqtt.cpp b/platformio/relayctl/src/mqtt.cpp index cca215b..e1f70c3 100644 --- a/platformio/relayctl/src/mqtt.cpp +++ b/platformio/relayctl/src/mqtt.cpp @@ -1,9 +1,13 @@ +#include <ESP8266httpUpdate.h> #include "mqtt.h" #include "logging.h" #include "wifi.h" #include "config.def.h" #include "relay.h" #include "config.h" +#include "static.h" +#include "util.h" +#include "led.h" namespace homekit::mqtt { @@ -13,26 +17,124 @@ static const uint16_t MQTT_PORT = DEFAULT_MQTT_PORT; static const char MQTT_USERNAME[] = DEFAULT_MQTT_USERNAME; static const char MQTT_PASSWORD[] = DEFAULT_MQTT_PASSWORD; static const char MQTT_CLIENT_ID[] = DEFAULT_MQTT_CLIENT_ID; +static const char MQTT_SECRET[HOME_SECRET_SIZE+1] = HOME_SECRET; -static const char MQTT_SECRET[] = SECRET; -static const char TOPIC_RELAY_POWER[] = "relay/power"; static const char TOPIC_STAT[] = "stat"; -static const char TOPIC_STAT1[] = "stat1"; -static const char TOPIC_ADMIN[] = "admin"; -static const char TOPIC_RELAY[] = "relay"; +static const char TOPIC_INITIAL_STAT[] = "stat1"; +static const char TOPIC_OTA_RESPONSE[] = "otares"; +static const char TOPIC_RELAY_POWER[] = "power"; +static const char TOPIC_ADMIN_OTA[] = "admin/ota"; +static const uint16_t MQTT_KEEPALIVE = 30; +enum class IncomingMessage { + UNKNOWN, + RELAY_POWER, + OTA +}; -using namespace homekit; +using namespace espMqttClientTypes; + +#define MD5_SIZE 16 + +MQTT::MQTT() { + auto cfg = config::read(); + homeId = String(cfg.flags.node_configured ? cfg.home_id : wifi::HOME_ID); -MQTT::MQTT() : client(wifiClient) { randomSeed(micros()); - wifiClient.setFingerprint(MQTT_CA_FINGERPRINT); + client.onConnect([&](bool sessionPresent) { + PRINTLN("mqtt: connected"); - client.setServer(MQTT_SERVER, MQTT_PORT); - client.setCallback([&](char* topic, byte* payload, unsigned int length) { - this->callback(topic, payload, length); + sendInitialStat(); + + subscribe(TOPIC_RELAY_POWER, 1); + subscribe(TOPIC_ADMIN_OTA); + }); + + client.onDisconnect([&](DisconnectReason reason) { + PRINTF("mqtt: disconnected, reason=%d\n", static_cast<int>(reason)); +#ifdef DEBUG + if (reason == DisconnectReason::TLS_BAD_FINGERPRINT) + PRINTLN("reason: bad fingerprint"); +#endif + + if (ota.started()) { + PRINTLN("mqtt: update was in progress, canceling.."); + ota.clean(); + Update.end(); + Update.clearError(); + } + + if (ota.readyToRestart) { + restartTimer.once(1, restart); + } else { + reconnectTimer.once(2, [&]() { + reconnect(); + }); + } + }); + + client.onSubscribe([&](uint16_t packetId, const SubscribeReturncode* returncodes, size_t len) { + PRINTF("mqtt: subscribe ack, packet_id=%d\n", packetId); + for (size_t i = 0; i < len; i++) { + PRINTF(" return code: %u\n", static_cast<unsigned int>(*(returncodes+i))); + } + }); + + client.onUnsubscribe([&](uint16_t packetId) { + PRINTF("mqtt: unsubscribe ack, packet_id=%d\n", packetId); + }); + + client.onMessage([&](const MessageProperties& properties, const char* topic, const uint8_t* payload, size_t len, size_t index, size_t total) { + PRINTF("mqtt: message received, topic=%s, qos=%d, dup=%d, retain=%d, len=%ul, index=%ul, total=%ul\n", + topic, properties.qos, (int)properties.dup, (int)properties.retain, len, index, total); + + IncomingMessage msgType = IncomingMessage::UNKNOWN; + + const char *ptr = topic + homeId.length() + 10; + String relevantTopic(ptr); + + if (relevantTopic == TOPIC_RELAY_POWER) + msgType = IncomingMessage::RELAY_POWER; + else if (relevantTopic == TOPIC_ADMIN_OTA) + msgType = IncomingMessage::OTA; + + if (len != total && msgType != IncomingMessage::OTA) { + PRINTLN("mqtt: received partial message, not supported"); + return; + } + + switch (msgType) { + case IncomingMessage::RELAY_POWER: + handleRelayPowerPayload(payload, total); + break; + + case IncomingMessage::OTA: + if (ota.finished) + break; + handleAdminOtaPayload(properties.packetId, payload, len, index, total); + break; + + case IncomingMessage::UNKNOWN: + PRINTF("error: invalid topic %s\n", topic); + break; + } }); + + client.onPublish([&](uint16_t packetId) { + PRINTF("mqtt: publish ack, packet_id=%d\n", packetId); + + if (ota.finished && packetId == ota.publishResultPacketId) { + ota.readyToRestart = true; + } + }); + + client.setServer(MQTT_SERVER, MQTT_PORT); + client.setClientId(MQTT_CLIENT_ID); + client.setCredentials(MQTT_USERNAME, MQTT_PASSWORD); + client.setCleanSession(true); + client.setFingerprint(MQTT_CA_FINGERPRINT); + client.setKeepAlive(MQTT_KEEPALIVE); } void MQTT::connect() { @@ -40,127 +142,96 @@ void MQTT::connect() { } void MQTT::reconnect() { - char buf[128] {0}; - if (client.connected()) { PRINTLN("warning: already connected"); return; } - - // Attempt to connect - if (client.connect(MQTT_CLIENT_ID, MQTT_USERNAME, MQTT_PASSWORD)) { - PRINTLN("mqtt: connected"); - - sendInitialStat(); - - subscribe(TOPIC_RELAY); - subscribe(TOPIC_ADMIN); - } else { - PRINTF("mqtt: failed to connect, rc=%d\n", client.state()); - wifiClient.getLastSSLError(buf, sizeof(buf)); - PRINTF("SSL error: %s\n", buf); - - reconnectTimer.once(2, [&]() { - reconnect(); - }); - } + client.connect(); } void MQTT::disconnect() { // TODO test how this works??? reconnectTimer.detach(); client.disconnect(); - wifiClient.stop(); } -bool MQTT::loop() { - return client.loop(); +uint16_t MQTT::publish(const String &topic, uint8_t *payload, size_t length) { + String fullTopic = "hk/" + homeId + "/relay/" + topic; + return client.publish(fullTopic.c_str(), 1, false, payload, length); } -bool MQTT::publish(const char* topic, uint8_t *payload, size_t length) { - char full_topic[40] {0}; - strcpy(full_topic, "/hk/"); - strcat(full_topic, wifi::NODE_ID); - strcat(full_topic, "/"); - strcat(full_topic, topic); - return client.publish(full_topic, payload, length); +void MQTT::loop() { + client.loop(); } -bool MQTT::subscribe(const char *topic) { - char full_topic[40] {0}; - strcpy(full_topic, "/hk/"); - strcat(full_topic, wifi::NODE_ID); - strcat(full_topic, "/"); - strcat(full_topic, topic); - strcat(full_topic, "/#"); - bool res = client.subscribe(full_topic, 1); - if (!res) - PRINTF("error: failed to subscribe to %s\n", full_topic); - return res; +uint16_t MQTT::subscribe(const String &topic, uint8_t qos) { + String fullTopic = "hk/" + homeId + "/relay/" + topic; + PRINTF("mqtt: subscribing to %s...\n", fullTopic.c_str()); + + uint16_t packetId = client.subscribe(fullTopic.c_str(), qos); + if (!packetId) + PRINTF("error: failed to subscribe to %s\n", fullTopic.c_str()); + return packetId; } void MQTT::sendInitialStat() { auto cfg = config::read(); - InitialStatPayload stat { - .ip = wifi::getIPAsInteger(), - .fw_version = FW_VERSION, - .rssi = wifi::getRSSI(), - .free_heap = ESP.getFreeHeap(), - .flags = StatFlags { - .state = static_cast<uint8_t>(relay::getState() ? 1 : 0), - .config_changed_value_present = 1, - .config_changed = static_cast<uint8_t>(cfg.flags.node_configured || cfg.flags.wifi_configured ? 1 : 0) - } + InitialStatPayload stat{ + .ip = wifi::getIPAsInteger(), + .fw_version = FW_VERSION, + .rssi = wifi::getRSSI(), + .free_heap = ESP.getFreeHeap(), + .flags = StatFlags{ + .state = static_cast<uint8_t>(relay::getState() ? 1 : 0), + .config_changed_value_present = 1, + .config_changed = static_cast<uint8_t>(cfg.flags.node_configured || + cfg.flags.wifi_configured ? 1 : 0) + } }; - publish(TOPIC_STAT1, reinterpret_cast<uint8_t*>(&stat), sizeof(stat)); + publish(TOPIC_INITIAL_STAT, reinterpret_cast<uint8_t*>(&stat), sizeof(stat)); statStopWatch.save(); } void MQTT::sendStat() { - StatPayload stat { + StatPayload stat{ .rssi = wifi::getRSSI(), .free_heap = ESP.getFreeHeap(), - .flags = StatFlags { + .flags = StatFlags{ .state = static_cast<uint8_t>(relay::getState() ? 1 : 0), .config_changed_value_present = 0, .config_changed = 0 } }; - - PRINTF("free heap: %d\n", ESP.getFreeHeap()); - publish(TOPIC_STAT, reinterpret_cast<uint8_t*>(&stat), sizeof(stat)); statStopWatch.save(); } -void MQTT::callback(char* topic, uint8_t* payload, uint32_t length) { - const size_t bufsize = 16; - char relevant_topic[bufsize]; - strncpy(relevant_topic, topic+strlen(wifi::NODE_ID)+5, bufsize); - - if (strncmp(TOPIC_RELAY_POWER, relevant_topic, bufsize) == 0) { - handleRelayPowerPayload(payload, length); - } else { - PRINTF("error: invalid topic %s\n", topic); - } +uint16_t MQTT::sendOtaResponse(OTAResult status, uint8_t error_code) { + OTAResponse resp{ + .status = status, + .error_code = error_code + }; + return publish(TOPIC_OTA_RESPONSE, reinterpret_cast<uint8_t*>(&resp), sizeof(resp)); } -void MQTT::handleRelayPowerPayload(uint8_t *payload, uint32_t length) { +void MQTT::handleRelayPowerPayload(const uint8_t *payload, uint32_t length) { if (length != sizeof(PowerPayload)) { PRINTF("error: size of payload (%ul) does not match expected (%ul)\n", length, sizeof(PowerPayload)); return; } - auto pd = reinterpret_cast<struct PowerPayload*>(payload); + auto pd = reinterpret_cast<const struct PowerPayload*>(payload); if (strncmp(pd->secret, MQTT_SECRET, sizeof(pd->secret)) != 0) { PRINTLN("error: invalid secret"); return; } if (pd->state == 1) { + PRINTLN("mqtt: turning relay on"); relay::setOn(); } else if (pd->state == 0) { + PRINTLN("mqtt: turning relay off"); relay::setOff(); } else { PRINTLN("error: unexpected state value"); @@ -169,4 +240,114 @@ void MQTT::handleRelayPowerPayload(uint8_t *payload, uint32_t length) { sendStat(); } +void MQTT::handleAdminOtaPayload(uint16_t packetId, const uint8_t *payload, size_t length, size_t index, size_t total) { + char md5[33]; + char* md5Ptr = md5; + + if (index != 0 && ota.dataPacketId != packetId) { + PRINTLN("mqtt/ota: non-matching packet id"); + return; + } + + Update.runAsync(true); + + if (index == 0) { + if (length < HOME_SECRET_SIZE + MD5_SIZE) { + PRINTLN("mqtt/ota: failed to check secret, first packet size is too small"); + return; + } + + if (memcmp((const char*)payload, HOME_SECRET, HOME_SECRET_SIZE) != 0) { + PRINTLN("mqtt/ota: invalid secret"); + return; + } + + PRINTF("mqtt/ota: starting update, total=%ul\n", total-HOME_SECRET_SIZE); + for (int i = 0; i < MD5_SIZE; i++) { + md5Ptr += sprintf(md5Ptr, "%02x", *((unsigned char*)(payload+HOME_SECRET_SIZE+i))); + } + md5[32] = '\0'; + PRINTF("mqtt/ota: md5 is %s\n", md5); + PRINTF("mqtt/ota: first packet is %ul bytes length\n", length); + + md5[32] = '\0'; + + if (Update.isRunning()) { + Update.end(); + Update.clearError(); + } + + if (!Update.setMD5(md5)) { + PRINTLN("mqtt/ota: setMD5 failed"); + return; + } + + ota.dataPacketId = packetId; + + if (!Update.begin(total - HOME_SECRET_SIZE - MD5_SIZE)) { + ota.clean(); +#ifdef DEBUG + Update.printError(Serial); +#endif + sendOtaResponse(OTAResult::UPDATE_ERROR, Update.getError()); + } + + ota.written = Update.write(const_cast<uint8_t*>(payload)+HOME_SECRET_SIZE + MD5_SIZE, length-HOME_SECRET_SIZE - MD5_SIZE); + ota.written += HOME_SECRET_SIZE + MD5_SIZE; + + esp_led.blink(1, 1); + PRINTF("mqtt/ota: updating %u/%u\n", ota.written, Update.size()); + + } else { + if (!Update.isRunning()) { + PRINTLN("mqtt/ota: update is not running"); + return; + } + + if (index == ota.written) { + size_t written; + if ((written = Update.write(const_cast<uint8_t*>(payload), length)) != length) { + PRINTF("mqtt/ota: error: tried to write %ul bytes, write() returned %ul\n", + length, written); + ota.clean(); + Update.end(); + Update.clearError(); + sendOtaResponse(OTAResult::WRITE_ERROR); + return; + } + ota.written += length; + + esp_led.blink(1, 1); + PRINTF("mqtt/ota: updating %u/%u\n", + ota.written - HOME_SECRET_SIZE - MD5_SIZE, + Update.size()); + } else { + PRINTF("mqtt/ota: position is invalid, expected %ul, got %ul\n", ota.written, index); + ota.clean(); + Update.end(); + Update.clearError(); + } + } + + if (Update.isFinished()) { + ota.dataPacketId = 0; + + if (Update.end()) { + ota.finished = true; + ota.publishResultPacketId = sendOtaResponse(OTAResult::OK); + PRINTF("mqtt/ota: ok, otares packet_id=%d\n", ota.publishResultPacketId); + } else { + ota.clean(); + + PRINTF("mqtt/ota: error: %u\n", Update.getError()); +#ifdef DEBUG + Update.printError(Serial); +#endif + Update.clearError(); + + sendOtaResponse(OTAResult::UPDATE_ERROR, Update.getError()); + } + } +} + }
\ No newline at end of file |