summaryrefslogtreecommitdiff
path: root/include/pio/libs/mqtt_module_ota/homekit/mqtt/module/ota.cpp
blob: 4e976cde024773e768b0c5f6520c005c5d07a9bb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#include "./ota.h"
#include <homekit/logging.h>
#include <homekit/util.h>
#include <homekit/led.h>

namespace homekit::mqtt {

using homekit::led::mcu_led;

// Length in bytes of the raw (binary) MD5 digest that follows the node secret
// at the start of the first OTA packet; it is hex-encoded before Update.setMD5().
#define MD5_SIZE 16

// Topic this module subscribes to for incoming firmware chunks.
static const char TOPIC_OTA[] = "ota";
// Topic on which OTA results (MqttOtaResponsePayload) are published.
static const char TOPIC_OTA_RESPONSE[] = "otares";

/// Connection hook: subscribes this module to the OTA topic ("ota").
void MqttOtaModule::onConnect(Mqtt& mqtt) {
    String otaTopic = TOPIC_OTA;
    mqtt.subscribeModule(otaTopic, this);
}

/// Periodic hook. The OTA module has no periodic work. tickElapsed() is still
/// called in case it has side effects (NOTE(review): confirm whether it does).
void MqttOtaModule::tick(Mqtt& mqtt) {
    static_cast<void>(tickElapsed());
}

/**
 * Handles one chunk of an OTA firmware upload received on the OTA topic.
 *
 * Message layout, as parsed below:
 *  - The first chunk (index == 0) starts with CONFIG_NODE_SECRET_SIZE bytes of
 *    the node secret, then MD5_SIZE raw bytes of the firmware MD5 digest, then
 *    firmware data.
 *  - Later chunks carry only firmware data. They must belong to the same packet
 *    (packetId) and start at the offset where the previous chunk ended. Offsets
 *    count the header too.
 *
 * Once the image is complete and Update.end() succeeds, an OK result is
 * published on "otares". Its packet id is stored in ota.publishResultPacketId.
 * Any failure after the update has begun aborts it and publishes one error
 * result.
 *
 * @param mqtt     client used to publish the result
 * @param topic    topic the chunk arrived on (unused)
 * @param packetId MQTT packet id; all chunks of one upload share it
 * @param payload  chunk bytes
 * @param length   number of bytes in this chunk
 * @param index    offset of this chunk within the whole message
 * @param total    size of the whole message, header included
 */
void MqttOtaModule::handlePayload(Mqtt& mqtt, String& topic, uint16_t packetId, const uint8_t *payload, size_t length, size_t index, size_t total) {
    // Secret + MD5 digest that precede the firmware image in the first chunk.
    const size_t headerSize = CONFIG_NODE_SECRET_SIZE + MD5_SIZE;

    // Drops the module's OTA state and discards any partially written image.
    auto abortUpdate = [this]() {
        ota.clean();
        Update.end();
        Update.clearError();
    };

    if (index != 0 && ota.dataPacketId != packetId) {
        PRINTLN("mqtt/ota: non-matching packet id");
        return;
    }

    Update.runAsync(true);

    if (index == 0) {
        if (length < headerSize) {
            PRINTLN("mqtt/ota: failed to check secret, first packet size is too small");
            return;
        }

        if (memcmp(payload, CONFIG_NODE_SECRET, CONFIG_NODE_SECRET_SIZE) != 0) {
            PRINTLN("mqtt/ota: invalid secret");
            return;
        }

        const size_t firmwareSize = total - headerSize;
        PRINTF("mqtt/ota: starting update, total=%lu\n", static_cast<unsigned long>(firmwareSize));

        // Update.setMD5() expects the digest as a 32-character hex string.
        char md5[MD5_SIZE * 2 + 1];
        for (size_t i = 0; i < MD5_SIZE; i++)
            snprintf(md5 + i * 2, 3, "%02x", static_cast<unsigned>(payload[CONFIG_NODE_SECRET_SIZE + i]));
        md5[MD5_SIZE * 2] = '\0';
        PRINTF("mqtt/ota: md5 is %s\n", md5);
        PRINTF("mqtt/ota: first packet is %lu bytes length\n", static_cast<unsigned long>(length));

        // Discard an update left running by a previous, interrupted upload.
        if (Update.isRunning()) {
            Update.end();
            Update.clearError();
        }

        ota.dataPacketId = packetId;

        if (!Update.begin(firmwareSize)) {
            // Capture the error code up front so it is reported as-is.
            const uint8_t error = Update.getError();
            ota.clean();
#ifdef DEBUG
            Update.printError(Serial);
#endif
            sendResponse(mqtt, OtaResult::UPDATE_ERROR, error);
            // Stop here: writing into an updater that failed to begin is
            // pointless and could publish a second error result.
            return;
        }

        // Set the expected digest only after begin(). Core examples do the
        // same, and begin() may reset a previously stored target MD5.
        // NOTE(review): confirm for the core version in use.
        if (!Update.setMD5(md5)) {
            PRINTLN("mqtt/ota: setMD5 failed");
            abortUpdate();
            sendResponse(mqtt, OtaResult::UPDATE_ERROR);
            return;
        }

        const size_t chunkSize = length - headerSize;
        const size_t written = Update.write(const_cast<uint8_t*>(payload) + headerSize, chunkSize);
        if (written != chunkSize) {
            PRINTF("mqtt/ota: error: tried to write %lu bytes, write() returned %lu\n",
                   static_cast<unsigned long>(chunkSize), static_cast<unsigned long>(written));
            abortUpdate();
            sendResponse(mqtt, OtaResult::WRITE_ERROR);
            return;
        }
        // ota.written is an offset in the whole message (header included), so
        // it can be compared directly against `index` of the next chunk.
        ota.written = headerSize + written;
    } else {
        if (!Update.isRunning()) {
            PRINTLN("mqtt/ota: update is not running");
            return;
        }

        if (index != ota.written) {
            PRINTF("mqtt/ota: position is invalid, expected %lu, got %lu\n",
                   static_cast<unsigned long>(ota.written), static_cast<unsigned long>(index));
            abortUpdate();
            sendResponse(mqtt, OtaResult::WRITE_ERROR);
            return;
        }

        const size_t written = Update.write(const_cast<uint8_t*>(payload), length);
        if (written != length) {
            PRINTF("mqtt/ota: error: tried to write %lu bytes, write() returned %lu\n",
                   static_cast<unsigned long>(length), static_cast<unsigned long>(written));
            abortUpdate();
            sendResponse(mqtt, OtaResult::WRITE_ERROR);
            return;
        }
        ota.written += length;
    }

    // Progress is reported in firmware bytes (header excluded) for every chunk.
    mcu_led->blink(1, 1);
    PRINTF("mqtt/ota: updating %lu/%lu\n",
           static_cast<unsigned long>(ota.written - headerSize),
           static_cast<unsigned long>(Update.size()));

    if (Update.isFinished()) {
        ota.dataPacketId = 0;

        if (Update.end()) {
            ota.finished = true;
            ota.publishResultPacketId = sendResponse(mqtt, OtaResult::OK);
            PRINTF("mqtt/ota: ok, otares packet_id=%d\n", ota.publishResultPacketId);
        } else {
            // Read the error before clearError(); previously the code cleared it
            // first and always reported 0.
            const uint8_t error = Update.getError();
            ota.clean();

            PRINTF("mqtt/ota: error: %u\n", error);
#ifdef DEBUG
            Update.printError(Serial);
#endif
            Update.clearError();

            sendResponse(mqtt, OtaResult::UPDATE_ERROR, error);
        }
    }
}

/// Publishes an OTA result on the "otares" topic.
/// The payload is the raw bytes of an MqttOtaResponsePayload built from
/// `status` and `error_code`.
/// @return whatever mqtt.publish() returns; callers treat it as the packet id.
uint16_t MqttOtaModule::sendResponse(Mqtt& mqtt, OtaResult status, uint8_t error_code) const {
    MqttOtaResponsePayload response{status, error_code};
    auto* raw = reinterpret_cast<uint8_t*>(&response);
    return mqtt.publish(TOPIC_OTA_RESPONSE, raw, sizeof response);
}

/// Disconnect hook.
/// - If the OK result has been acknowledged (ota.readyToRestart), schedules
///   `restart` through restartTimer (presumably a Ticker, i.e. ~1 s; verify).
/// - Otherwise, if an update was in progress, cancels it and discards the
///   partially written image.
void MqttOtaModule::onDisconnect(Mqtt& mqtt, espMqttClientTypes::DisconnectReason reason) {
    if (!ota.readyToRestart) {
        if (!ota.started())
            return;

        PRINTLN("mqtt: update was in progress, canceling..");
        ota.clean();
        Update.end();
        Update.clearError();
        return;
    }

    restartTimer.once(1, restart);
}

/// Publish-completion hook. Arms the restart (ota.readyToRestart) once the OK
/// result published by handlePayload() is reported for its packet id. Any other
/// packet id is ignored.
void MqttOtaModule::handleOnPublish(uint16_t packetId) {
    const bool isOurResult = ota.finished && packetId == ota.publishResultPacketId;
    if (isOurResult)
        ota.readyToRestart = true;
}

}