diff options
author | Evgeny Zinoviev <me@ch1p.io> | 2023-05-31 09:22:00 +0300 |
---|---|---|
committer | Evgeny Zinoviev <me@ch1p.io> | 2023-06-10 02:07:23 +0300 |
commit | f29e139cbb7e4a4d539cba6e894ef4a6acd312d6 (patch) | |
tree | 6246f126325c5c36fb573134a05f2771cd747966 | |
parent | 3e3753d726f8a02d98368f20f77dd9fa739e3d80 (diff) |
WIP: big refactoring
114 files changed, 2335 insertions, 1330 deletions
@@ -17,6 +17,7 @@ platformio.ini CMakeListsPrivate.txt /platformio/*/CMakeLists.txt /platformio/*/CMakeListsPrivate.txt +/platformio/*/.gitignore *.swp /localwebsite/vendor diff --git a/misc/openwrt/etc/rc.local b/misc/openwrt/etc/rc.local index 407d1eb..32b1227 100644 --- a/misc/openwrt/etc/rc.local +++ b/misc/openwrt/etc/rc.local @@ -17,7 +17,7 @@ done sleep 0.1 # block internet access for untrusted cameras -iptables -I FORWARD 1 -m set --match-set ipcam src ! -d 192.168.5.0 -j REJECT +iptables -I FORWARD 1 -m set --match-set ipcam src ! -d 192.168.5.0/24 -j REJECT # add some default routing rules ipset add mts-azov 192.168.5.0/24 # everybody diff --git a/platformio/common/libs/main/homekit/main.cpp b/platformio/common/libs/main/homekit/main.cpp index fd08925..816c764 100644 --- a/platformio/common/libs/main/homekit/main.cpp +++ b/platformio/common/libs/main/homekit/main.cpp @@ -6,7 +6,12 @@ namespace homekit::main { +#ifndef CONFIG_TARGET_ESP01 +#ifndef CONFIG_NO_RECOVERY enum WorkingMode working_mode = WorkingMode::NORMAL; +#endif +#endif + static const uint16_t recovery_boot_detection_ms = 2000; static const uint8_t recovery_boot_delay_ms = 100; @@ -22,8 +27,10 @@ static StopWatch blinkStopWatch; #endif #ifndef CONFIG_TARGET_ESP01 +#ifndef CONFIG_NO_RECOVERY static DNSServer* dnsServer = nullptr; #endif +#endif static void onWifiConnected(const WiFiEventStationModeGotIP& event); static void onWifiDisconnected(const WiFiEventStationModeDisconnected& event); @@ -45,6 +52,7 @@ static void wifiConnect() { } #ifndef CONFIG_TARGET_ESP01 +#ifndef CONFIG_NO_RECOVERY static void wifiHotspot() { led::mcu_led->on(); @@ -71,13 +79,16 @@ static void waitForRecoveryPress() { } } #endif +#endif void setup() { WiFi.disconnect(); +#ifndef CONFIG_NO_RECOVERY #ifndef CONFIG_TARGET_ESP01 homekit::main::waitForRecoveryPress(); #endif +#endif #ifdef DEBUG Serial.begin(115200); @@ -95,6 +106,7 @@ void setup() { } #ifndef CONFIG_TARGET_ESP01 +#ifndef CONFIG_NO_RECOVERY switch (working_mode) { case WorkingMode::RECOVERY: wifiHotspot(); @@ -102,19 +114,24 @@ void setup() { case WorkingMode::NORMAL: #endif +#endif wifiConnectHandler = WiFi.onStationModeGotIP(onWifiConnected); wifiDisconnectHandler = WiFi.onStationModeDisconnected(onWifiDisconnected); wifiConnect(); +#ifndef CONFIG_NO_RECOVERY #ifndef CONFIG_TARGET_ESP01 break; } #endif +#endif } void loop(LoopConfig* config) { +#ifndef CONFIG_NO_RECOVERY #ifndef CONFIG_TARGET_ESP01 if (working_mode == WorkingMode::NORMAL) { #endif +#endif if (wifi_state == WiFiConnectionState::WAITING) { PRINT("."); led::mcu_led->blink(2, 50); @@ -166,6 +183,7 @@ void loop(LoopConfig* config) { } #endif } +#ifndef CONFIG_NO_RECOVERY #ifndef CONFIG_TARGET_ESP01 } else { if (dnsServer != nullptr) @@ -176,6 +194,7 @@ void loop(LoopConfig* config) { httpServer->loop(); } #endif +#endif } static void onWifiConnected(const WiFiEventStationModeGotIP& event) { @@ -191,4 +210,4 @@ static void onWifiDisconnected(const WiFiEventStationModeDisconnected& event) { wifiTimer.once(2, wifiConnect); } -}
\ No newline at end of file +} diff --git a/platformio/common/libs/main/homekit/main.h b/platformio/common/libs/main/homekit/main.h index a503dd0..78a0695 100644 --- a/platformio/common/libs/main/homekit/main.h +++ b/platformio/common/libs/main/homekit/main.h @@ -10,8 +10,10 @@ #include <homekit/config.h> #include <homekit/logging.h> #ifndef CONFIG_TARGET_ESP01 +#ifndef CONFIG_NO_RECOVERY #include <homekit/http_server.h> #endif +#endif #include <homekit/wifi.h> #include <homekit/mqtt/mqtt.h> @@ -20,6 +22,7 @@ namespace homekit::main { #ifndef CONFIG_TARGET_ESP01 +#ifndef CONFIG_NO_RECOVERY enum class WorkingMode { RECOVERY, // AP mode, http server with configuration NORMAL, // MQTT client @@ -27,6 +30,7 @@ enum class WorkingMode { extern enum WorkingMode working_mode; #endif +#endif enum class WiFiConnectionState { WAITING = 0, diff --git a/platformio/common/libs/main/library.json b/platformio/common/libs/main/library.json index 04eedab..728d4f8 100644 --- a/platformio/common/libs/main/library.json +++ b/platformio/common/libs/main/library.json @@ -1,6 +1,6 @@ { "name": "homekit_main", - "version": "1.0.8", + "version": "1.0.10", "build": { "flags": "-I../../include" }, diff --git a/platformio/common/libs/mqtt/homekit/mqtt/module.cpp b/platformio/common/libs/mqtt/homekit/mqtt/module.cpp index e78ff12..0ac7637 100644 --- a/platformio/common/libs/mqtt/homekit/mqtt/module.cpp +++ b/platformio/common/libs/mqtt/homekit/mqtt/module.cpp @@ -21,6 +21,6 @@ void MqttModule::handlePayload(Mqtt& mqtt, String& topic, uint16_t packetId, con void MqttModule::handleOnPublish(uint16_t packetId) {} -void MqttModule::handleOnDisconnect(espMqttClientTypes::DisconnectReason reason) {} +void MqttModule::onDisconnect(Mqtt& mqtt, espMqttClientTypes::DisconnectReason reason) {} } diff --git a/platformio/common/libs/mqtt/homekit/mqtt/module.h b/platformio/common/libs/mqtt/homekit/mqtt/module.h index e4a01f8..0a328f3 100644 --- a/platformio/common/libs/mqtt/homekit/mqtt/module.h +++ b/platformio/common/libs/mqtt/homekit/mqtt/module.h @@ -28,20 +28,25 @@ public: , receiveOnPublish(_receiveOnPublish) , receiveOnDisconnect(_receiveOnDisconnect) {} - virtual void init(Mqtt& mqtt) = 0; virtual void tick(Mqtt& mqtt) = 0; + virtual void onConnect(Mqtt& mqtt) = 0; + virtual void onDisconnect(Mqtt& mqtt, espMqttClientTypes::DisconnectReason reason); + virtual void handlePayload(Mqtt& mqtt, String& topic, uint16_t packetId, const uint8_t *payload, size_t length, size_t index, size_t total); virtual void handleOnPublish(uint16_t packetId); - virtual void handleOnDisconnect(espMqttClientTypes::DisconnectReason reason); inline void setInitialized() { initialized = true; } - inline short getTickInterval() { - return tickInterval; - } + inline void unsetInitialized() { + initialized = false; + } + + inline short getTickInterval() const { + return tickInterval; + } friend class Mqtt; }; diff --git a/platformio/common/libs/mqtt/homekit/mqtt/mqtt.cpp b/platformio/common/libs/mqtt/homekit/mqtt/mqtt.cpp index cb2cea7..aa769a5 100644 --- a/platformio/common/libs/mqtt/homekit/mqtt/mqtt.cpp +++ b/platformio/common/libs/mqtt/homekit/mqtt/mqtt.cpp @@ -34,7 +34,7 @@ Mqtt::Mqtt() { for (auto* module: modules) { if (!module->initialized) { - module->init(*this); + module->onConnect(*this); module->setInitialized(); } } @@ -50,18 +50,13 @@ Mqtt::Mqtt() { #endif for (auto* module: modules) { - if (module->receiveOnDisconnect) { - module->handleOnDisconnect(reason); - } + module->onDisconnect(*this, reason); + module->unsetInitialized(); } -// if (ota.readyToRestart) { -// restartTimer.once(1, restart); -// } else { - reconnectTimer.once(2, [&]() { - reconnect(); - }); -// } + reconnectTimer.once(2, [&]() { + reconnect(); + }); }); client.onSubscribe([&](uint16_t packetId, const SubscribeReturncode* returncodes, size_t len) { @@ -79,7 +74,7 @@ Mqtt::Mqtt() { PRINTF("mqtt: message received, topic=%s, qos=%d, dup=%d, retain=%d, len=%ul, index=%ul, total=%ul\n", topic, properties.qos, (int)properties.dup, (int)properties.retain, len, index, total); - const char *ptr = topic + nodeId.length() + 10; + const char *ptr = topic + nodeId.length() + 4; String relevantTopic(ptr); auto it = moduleSubscriptions.find(relevantTopic); @@ -87,7 +82,7 @@ Mqtt::Mqtt() { auto module = it->second; module->handlePayload(*this, relevantTopic, properties.packetId, payload, len, index, total); } else { - PRINTF("error: module subscription for topic %s not found\n", topic); + PRINTF("error: module subscription for topic %s not found\n", relevantTopic.c_str()); } }); @@ -130,8 +125,8 @@ void Mqtt::disconnect() { void Mqtt::loop() { client.loop(); for (auto& module: modules) { - if (module->getTickInterval() != 0) - module->tick(*this); + if (module->getTickInterval() != 0) + module->tick(*this); } } @@ -154,14 +149,14 @@ uint16_t Mqtt::subscribe(const String& topic, uint8_t qos) { void Mqtt::addModule(MqttModule* module) { modules.emplace_back(module); if (connected) { - module->init(*this); + module->onConnect(*this); module->setInitialized(); } } void Mqtt::subscribeModule(String& topic, MqttModule* module, uint8_t qos) { moduleSubscriptions[topic] = module; - subscribe(topic, qos); + subscribe(topic, qos); } } diff --git a/platformio/common/libs/mqtt/library.json b/platformio/common/libs/mqtt/library.json index d1ad420..f3f2504 100644 --- a/platformio/common/libs/mqtt/library.json +++ b/platformio/common/libs/mqtt/library.json @@ -1,6 +1,6 @@ { "name": "homekit_mqtt", - "version": "1.0.9", + "version": "1.0.11", "build": { "flags": "-I../../include" } diff --git a/platformio/common/libs/mqtt_module_diagnostics/homekit/mqtt/module/diagnostics.cpp b/platformio/common/libs/mqtt_module_diagnostics/homekit/mqtt/module/diagnostics.cpp index d36a7e9..e0f797e 100644 --- a/platformio/common/libs/mqtt_module_diagnostics/homekit/mqtt/module/diagnostics.cpp +++ b/platformio/common/libs/mqtt_module_diagnostics/homekit/mqtt/module/diagnostics.cpp @@ -7,12 +7,21 @@ namespace homekit::mqtt { static const char TOPIC_DIAGNOSTICS[] = "diag"; static const char TOPIC_INITIAL_DIAGNOSTICS[] = "d1ag"; -void MqttDiagnosticsModule::init(Mqtt& mqtt) {} +void MqttDiagnosticsModule::onConnect(Mqtt &mqtt) { + sendDiagnostics(mqtt); +} + +void MqttDiagnosticsModule::onDisconnect(Mqtt &mqtt, espMqttClientTypes::DisconnectReason reason) { + initialSent = false; +} void MqttDiagnosticsModule::tick(Mqtt& mqtt) { if (!tickElapsed()) return; + sendDiagnostics(mqtt); +} +void MqttDiagnosticsModule::sendDiagnostics(Mqtt& mqtt) { auto cfg = config::read(); if (!initialSent) { diff --git a/platformio/common/libs/mqtt_module_diagnostics/homekit/mqtt/module/diagnostics.h b/platformio/common/libs/mqtt_module_diagnostics/homekit/mqtt/module/diagnostics.h index 055c179..bb7a81a 100644 --- a/platformio/common/libs/mqtt_module_diagnostics/homekit/mqtt/module/diagnostics.h +++ b/platformio/common/libs/mqtt_module_diagnostics/homekit/mqtt/module/diagnostics.h @@ -32,12 +32,15 @@ class MqttDiagnosticsModule: public MqttModule { private: bool initialSent; + void sendDiagnostics(Mqtt& mqtt); + public: MqttDiagnosticsModule() : MqttModule(30) , initialSent(false) {} - void init(Mqtt& mqtt) override; + void onConnect(Mqtt& mqtt) override; + void onDisconnect(Mqtt& mqtt, espMqttClientTypes::DisconnectReason reason) override; void tick(Mqtt& mqtt) override; }; diff --git a/platformio/common/libs/mqtt_module_diagnostics/library.json b/platformio/common/libs/mqtt_module_diagnostics/library.json index 8df306d..a3d3244 100644 --- a/platformio/common/libs/mqtt_module_diagnostics/library.json +++ b/platformio/common/libs/mqtt_module_diagnostics/library.json @@ -1,10 +1,10 @@ { "name": "homekit_mqtt_module_diagnostics", - "version": "1.0.1", + "version": "1.0.2", "build": { "flags": "-I../../include" }, "dependencies": { - "homekit_mqtt": "file://../common/libs/mqtt" + "homekit_mqtt": "file://../common/libs/mqtt" } } diff --git a/platformio/common/libs/mqtt_module_ota/homekit/mqtt/module/ota.cpp b/platformio/common/libs/mqtt_module_ota/homekit/mqtt/module/ota.cpp index 2f5f814..4e976cd 100644 --- a/platformio/common/libs/mqtt_module_ota/homekit/mqtt/module/ota.cpp +++ b/platformio/common/libs/mqtt_module_ota/homekit/mqtt/module/ota.cpp @@ -12,7 +12,7 @@ using homekit::led::mcu_led; static const char TOPIC_OTA[] = "ota"; static const char TOPIC_OTA_RESPONSE[] = "otares"; -void MqttOtaModule::init(Mqtt& mqtt) { +void MqttOtaModule::onConnect(Mqtt& mqtt) { String topic(TOPIC_OTA); mqtt.subscribeModule(topic, this); } @@ -140,17 +140,15 @@ uint16_t MqttOtaModule::sendResponse(Mqtt& mqtt, OtaResult status, uint8_t error return mqtt.publish(TOPIC_OTA_RESPONSE, reinterpret_cast<uint8_t*>(&resp), sizeof(resp)); } -void MqttOtaModule::handleOnDisconnect(espMqttClientTypes::DisconnectReason reason) { - if (ota.started()) { +void MqttOtaModule::onDisconnect(Mqtt& mqtt, espMqttClientTypes::DisconnectReason reason) { + if (ota.readyToRestart) { + restartTimer.once(1, restart); + } else if (ota.started()) { PRINTLN("mqtt: update was in progress, canceling.."); ota.clean(); Update.end(); Update.clearError(); } - - if (ota.readyToRestart) { - restartTimer.once(1, restart); - } } void MqttOtaModule::handleOnPublish(uint16_t packetId) { diff --git a/platformio/common/libs/mqtt_module_ota/homekit/mqtt/module/ota.h b/platformio/common/libs/mqtt_module_ota/homekit/mqtt/module/ota.h index 53613c3..df4f7ce 100644 --- a/platformio/common/libs/mqtt_module_ota/homekit/mqtt/module/ota.h +++ b/platformio/common/libs/mqtt_module_ota/homekit/mqtt/module/ota.h @@ -57,11 +57,14 @@ private: public: MqttOtaModule() : MqttModule(0, true, true) {} - void init(Mqtt& mqtt) override; + void onConnect(Mqtt& mqtt) override; + void onDisconnect(Mqtt& mqtt, espMqttClientTypes::DisconnectReason reason) override; + void tick(Mqtt& mqtt) override; + void handlePayload(Mqtt& mqtt, String& topic, uint16_t packetId, const uint8_t *payload, size_t length, size_t index, size_t total) override; void handleOnPublish(uint16_t packetId) override; - void handleOnDisconnect(espMqttClientTypes::DisconnectReason reason) override; + inline bool isReadyToRestart() const { return ota.readyToRestart; } diff --git a/platformio/common/libs/mqtt_module_ota/library.json b/platformio/common/libs/mqtt_module_ota/library.json index 30db7d2..4f40a47 100644 --- a/platformio/common/libs/mqtt_module_ota/library.json +++ b/platformio/common/libs/mqtt_module_ota/library.json @@ -1,11 +1,11 @@ { "name": "homekit_mqtt_module_ota", - "version": "1.0.2", + "version": "1.0.5", "build": { "flags": "-I../../include" }, "dependencies": { "homekit_led": "file://../common/libs/led", - "homekit_mqtt": "file://../common/libs/mqtt" + "homekit_mqtt": "file://../common/libs/mqtt" } } diff --git a/platformio/common/libs/mqtt_module_relay/homekit/mqtt/module/relay.cpp b/platformio/common/libs/mqtt_module_relay/homekit/mqtt/module/relay.cpp index ab40727..90c57f9 100644 --- a/platformio/common/libs/mqtt_module_relay/homekit/mqtt/module/relay.cpp +++ b/platformio/common/libs/mqtt_module_relay/homekit/mqtt/module/relay.cpp @@ -5,19 +5,28 @@ namespace homekit::mqtt { static const char TOPIC_RELAY_SWITCH[] = "relay/switch"; +static const char TOPIC_RELAY_STATUS[] = "relay/status"; -void MqttRelayModule::init(Mqtt &mqtt) { - String topic(TOPIC_RELAY_SWITCH); - mqtt.subscribeModule(topic, this, 1); +void MqttRelayModule::onConnect(Mqtt &mqtt) { + String topic(TOPIC_RELAY_SWITCH); + mqtt.subscribeModule(topic, this, 1); +} + +void MqttRelayModule::onDisconnect(Mqtt &mqtt, espMqttClientTypes::DisconnectReason reason) { +#ifdef CONFIG_RELAY_OFF_ON_DISCONNECT + if (relay::state()) { + relay::off(); + } +#endif } void MqttRelayModule::tick(homekit::mqtt::Mqtt& mqtt) {} void MqttRelayModule::handlePayload(Mqtt& mqtt, String& topic, uint16_t packetId, const uint8_t *payload, size_t length, size_t index, size_t total) { - if (topic != TOPIC_RELAY_SWITCH) - return; + if (topic != TOPIC_RELAY_SWITCH) + return; - if (length != sizeof(MqttRelaySwitchPayload)) { + if (length != sizeof(MqttRelaySwitchPayload)) { PRINTF("error: size of payload (%ul) does not match expected (%ul)\n", length, sizeof(MqttRelaySwitchPayload)); return; @@ -29,6 +38,8 @@ void MqttRelayModule::handlePayload(Mqtt& mqtt, String& topic, uint16_t packetId return; } + MqttRelayStatusPayload resp{}; + if (pd->state == 1) { PRINTLN("mqtt: turning relay on"); relay::on(); @@ -38,6 +49,10 @@ void MqttRelayModule::handlePayload(Mqtt& mqtt, String& topic, uint16_t packetId } else { PRINTLN("error: unexpected state value"); } + + resp.opened = relay::state(); + mqtt.publish(TOPIC_RELAY_STATUS, reinterpret_cast<uint8_t*>(&resp), sizeof(resp)); } } + diff --git a/platformio/common/libs/mqtt_module_relay/homekit/mqtt/module/relay.h b/platformio/common/libs/mqtt_module_relay/homekit/mqtt/module/relay.h index 6420de1..e245527 100644 --- a/platformio/common/libs/mqtt_module_relay/homekit/mqtt/module/relay.h +++ b/platformio/common/libs/mqtt_module_relay/homekit/mqtt/module/relay.h @@ -10,14 +10,20 @@ struct MqttRelaySwitchPayload { uint8_t state; } __attribute__((packed)); +struct MqttRelayStatusPayload { + uint8_t opened; +} __attribute__((packed)); + class MqttRelayModule : public MqttModule { public: MqttRelayModule() : MqttModule(0) {} - void init(Mqtt& mqtt) override; + void onConnect(Mqtt& mqtt) override; + void onDisconnect(Mqtt& mqtt, espMqttClientTypes::DisconnectReason reason) override; void tick(Mqtt& mqtt) override; - void handlePayload(Mqtt& mqtt, String& topic, uint16_t packetId, const uint8_t *payload, size_t length, size_t index, size_t total) override; + void handlePayload(Mqtt& mqtt, String& topic, uint16_t packetId, const uint8_t *payload, size_t length, size_t index, size_t total) override; }; } #endif //HOMEKIT_LIB_MQTT_MODULE_RELAY_H + diff --git a/platformio/common/libs/mqtt_module_relay/library.json b/platformio/common/libs/mqtt_module_relay/library.json index e71cf95..6cbbfb0 100644 --- a/platformio/common/libs/mqtt_module_relay/library.json +++ b/platformio/common/libs/mqtt_module_relay/library.json @@ -1,6 +1,6 @@ { "name": "homekit_mqtt_module_relay", - "version": "1.0.3", + "version": "1.0.5", "build": { "flags": "-I../../include" }, diff --git a/platformio/common/libs/mqtt_module_temphum/homekit/mqtt/module/temphum.cpp b/platformio/common/libs/mqtt_module_temphum/homekit/mqtt/module/temphum.cpp index 82f1d74..409f38f 100644 --- a/platformio/common/libs/mqtt_module_temphum/homekit/mqtt/module/temphum.cpp +++ b/platformio/common/libs/mqtt_module_temphum/homekit/mqtt/module/temphum.cpp @@ -4,7 +4,7 @@ namespace homekit::mqtt { static const char TOPIC_TEMPHUM_DATA[] = "temphum/data"; -void MqttTemphumModule::init(Mqtt &mqtt) {} +void MqttTemphumModule::onConnect(Mqtt &mqtt) {} void MqttTemphumModule::tick(homekit::mqtt::Mqtt& mqtt) { if (!tickElapsed()) diff --git a/platformio/common/libs/mqtt_module_temphum/homekit/mqtt/module/temphum.h b/platformio/common/libs/mqtt_module_temphum/homekit/mqtt/module/temphum.h index 5c41cef..7b28afc 100644 --- a/platformio/common/libs/mqtt_module_temphum/homekit/mqtt/module/temphum.h +++ b/platformio/common/libs/mqtt_module_temphum/homekit/mqtt/module/temphum.h @@ -19,7 +19,7 @@ private: public: MqttTemphumModule(temphum::Sensor* _sensor) : MqttModule(10), sensor(_sensor) {} - void init(Mqtt& mqtt) override; + void onConnect(Mqtt& mqtt) override; void tick(Mqtt& mqtt) override; }; diff --git a/platformio/common/libs/mqtt_module_temphum/library.json b/platformio/common/libs/mqtt_module_temphum/library.json index 9bb8cf1..068debd 100644 --- a/platformio/common/libs/mqtt_module_temphum/library.json +++ b/platformio/common/libs/mqtt_module_temphum/library.json @@ -1,11 +1,11 @@ { "name": "homekit_mqtt_module_temphum", - "version": "1.0.9", + "version": "1.0.10", "build": { "flags": "-I../../include" }, "dependencies": { - "homekit_mqtt": "file://../common/libs/mqtt", + "homekit_mqtt": "file://../common/libs/mqtt", "homekit_temphum": "file://../common/libs/temphum" } } diff --git a/platformio/temphum_relayctl/src/main.cpp b/platformio/temphum_relayctl/src/main.cpp index 0b05316..7f0945e 100644 --- a/platformio/temphum_relayctl/src/main.cpp +++ b/platformio/temphum_relayctl/src/main.cpp @@ -27,6 +27,7 @@ void setup() { main::setup(); relay::init(); + relay::off(); #if CONFIG_MODULE == HOMEKIT_SI7021 sensor = new temphum::Si7021(); diff --git a/requirements.txt b/requirements.txt index 46f9b8c..4595dea 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,25 +1,24 @@ paho-mqtt==1.6.1 inverterd~=1.0.3 clickhouse-driver~=0.2.0 -toml~=0.10.2 mysql-connector-python~=8.0.27 -Werkzeug==2.2.2 +Werkzeug==2.3.6 uwsgi~=2.0.20 -python-telegram-bot==13.15 -requests==2.28.1 +python-telegram-bot==20.3 +requests==2.31.0 aiohttp~=3.8.1 -pytz==2022.6 +pytz==2023.3 PyYAML~=6.0 -apscheduler~=3.9.1 +apscheduler==3.10.1 psutil~=5.9.1 aioshutil~=1.1 -scikit-image~=0.19.3 - +scikit-image==0.21.0 +cerberus~=1.3.4 # following can be installed from debian repositories # matplotlib~=3.5.0 -Pillow~=9.1.1 +Pillow==9.5.0 # for polaris kettle protocol implementation -cryptography==38.0.4 -zeroconf==0.39.4
\ No newline at end of file +cryptography==41.0.1 +zeroconf==0.64.1
\ No newline at end of file diff --git a/src/camera_node.py b/src/camera_node.py index d175e17..3f2c5a4 100755 --- a/src/camera_node.py +++ b/src/camera_node.py @@ -65,7 +65,7 @@ class ESP32CameraNodeServer(MediaNodeServer): if __name__ == '__main__': - config.load('camera_node') + config.load_app('camera_node') recorder_kwargs = {} camera_type = CameraType(config['camera']['type']) diff --git a/src/esp32_capture.py b/src/esp32_capture.py index 4a9ce10..0441565 100755 --- a/src/esp32_capture.py +++ b/src/esp32_capture.py @@ -5,7 +5,7 @@ import os.path from argparse import ArgumentParser from home.camera.esp32 import WebClient -from home.util import parse_addr, Addr +from home.util import Addr from apscheduler.schedulers.asyncio import AsyncIOScheduler from datetime import datetime from typing import Optional @@ -50,7 +50,7 @@ if __name__ == '__main__': loop = asyncio.get_event_loop() - ESP32Capture(parse_addr(arg.addr), arg.interval, arg.output_directory) + ESP32Capture(Addr.fromstring(arg.addr), arg.interval, arg.output_directory) try: loop.run_forever() except KeyboardInterrupt: diff --git a/src/esp32cam_capture_diff_node.py b/src/esp32cam_capture_diff_node.py index 4363e9e..59482f7 100755 --- a/src/esp32cam_capture_diff_node.py +++ b/src/esp32cam_capture_diff_node.py @@ -7,7 +7,7 @@ import home.telegram.aio as telegram from home.config import config from home.camera.esp32 import WebClient -from home.util import parse_addr, send_datagram, stringify +from home.util import Addr, send_datagram, stringify from apscheduler.schedulers.asyncio import AsyncIOScheduler from typing import Optional @@ -34,11 +34,11 @@ async def pyssim(fn1: str, fn2: str) -> float: class ESP32CamCaptureDiffNode: def __init__(self): - self.client = WebClient(parse_addr(config['esp32cam_web_addr'])) + self.client = WebClient(Addr.fromstring(config['esp32cam_web_addr'])) self.directory = tempfile.gettempdir() self.nextpic = 1 self.first = True - self.server_addr = parse_addr(config['node']['server_addr']) + self.server_addr = Addr.fromstring(config['node']['server_addr']) self.scheduler = AsyncIOScheduler() self.scheduler.add_job(self.capture, 'interval', seconds=config['node']['interval']) @@ -76,7 +76,7 @@ class ESP32CamCaptureDiffNode: if __name__ == '__main__': - config.load('esp32cam_capture_diff_node') + config.load_app('esp32cam_capture_diff_node') loop = asyncio.get_event_loop() ESP32CamCaptureDiffNode() diff --git a/src/esp_mqtt_util.py b/src/esp_mqtt_util.py deleted file mode 100755 index 263128c..0000000 --- a/src/esp_mqtt_util.py +++ /dev/null @@ -1,42 +0,0 @@ -#!/usr/bin/env python3 -from typing import Optional -from argparse import ArgumentParser -from enum import Enum - -from home.config import config -from home.mqtt import MqttRelay -from home.mqtt.esp import MqttEspBase -from home.mqtt.temphum import MqttTempHum -from home.mqtt.esp import MqttEspDevice - -mqtt_client: Optional[MqttEspBase] = None - - -class NodeType(Enum): - RELAY = 'relay' - TEMPHUM = 'temphum' - - -if __name__ == '__main__': - parser = ArgumentParser() - parser.add_argument('--device-id', type=str, required=True) - parser.add_argument('--type', type=str, required=True, - choices=[i.name.lower() for i in NodeType]) - - config.load('mqtt_util', parser=parser) - arg = parser.parse_args() - - mqtt_node_type = NodeType(arg.type) - devices = MqttEspDevice(id=arg.device_id) - - if mqtt_node_type == NodeType.RELAY: - mqtt_client = MqttRelay(devices=devices) - elif mqtt_node_type == NodeType.TEMPHUM: - mqtt_client = MqttTempHum(devices=devices) - - mqtt_client.set_message_callback(lambda device_id, payload: print(payload)) - mqtt_client.configure_tls() - try: - mqtt_client.connect_and_loop() - except KeyboardInterrupt: - mqtt_client.disconnect() diff --git a/src/gpiorelayd.py b/src/gpiorelayd.py index 85015a7..f1a9e57 100755 --- a/src/gpiorelayd.py +++ b/src/gpiorelayd.py @@ -13,7 +13,7 @@ if __name__ == '__main__': if not os.getegid() == 0: sys.exit('Must be run as root.') - config.load() + config.load_app() try: s = RelayServer(pinname=config.get('relayd.pin'), diff --git a/src/home/audio/amixer.py b/src/home/audio/amixer.py index 53e6bce..5133c97 100644 --- a/src/home/audio/amixer.py +++ b/src/home/audio/amixer.py @@ -1,6 +1,6 @@ import subprocess -from ..config import config +from ..config import app_config as config from threading import Lock from typing import Union, List diff --git a/src/home/config/__init__.py b/src/home/config/__init__.py index cc9c091..2fa5214 100644 --- a/src/home/config/__init__.py +++ b/src/home/config/__init__.py @@ -1 +1,13 @@ -from .config import ConfigStore, config, is_development_mode, setup_logging +from .config import ( + Config, + ConfigUnit, + AppConfigUnit, + Translation, + config, + is_development_mode, + setup_logging +) +from ._configs import ( + LinuxBoardsConfig, + ServicesListConfig +)
\ No newline at end of file diff --git a/src/home/config/_configs.py b/src/home/config/_configs.py new file mode 100644 index 0000000..3a1aae5 --- /dev/null +++ b/src/home/config/_configs.py @@ -0,0 +1,55 @@ +from .config import ConfigUnit +from typing import Optional + + +class ServicesListConfig(ConfigUnit): + NAME = 'services_list' + + @staticmethod + def schema() -> Optional[dict]: + return { + 'type': 'list', + 'empty': False, + 'schema': { + 'type': 'string' + } + } + + +class LinuxBoardsConfig(ConfigUnit): + NAME = 'linux_boards' + + @staticmethod + def schema() -> Optional[dict]: + return { + 'type': 'dict', + 'schema': { + 'mdns': {'type': 'string', 'required': True}, + 'board': {'type': 'string', 'required': True}, + 'network': { + 'type': 'list', + 'required': True, + 'empty': False, + 'allowed': ['wifi', 'ethernet'] + }, + 'ram': {'type': 'integer', 'required': True}, + 'online': {'type': 'boolean', 'required': True}, + + # optional + 'services': { + 'type': 'list', + 'empty': False, + 'allowed': ServicesListConfig().get() + }, + 'ext_hdd': { + 'type': 'list', + 'schema': { + 'type': 'dict', + 'schema': { + 'mountpoint': {'type': 'string', 'required': True}, + 'size': {'type': 'integer', 'required': True} + } + }, + }, + } + } diff --git a/src/home/config/config.py b/src/home/config/config.py index 4681685..aef9ee7 100644 --- a/src/home/config/config.py +++ b/src/home/config/config.py @@ -1,58 +1,256 @@ -import toml import yaml import logging import os +import pprint -from os.path import join, isdir, isfile -from typing import Optional, Any, MutableMapping +from abc import ABC +from cerberus import Validator, DocumentError +from typing import Optional, Any, MutableMapping, Union from argparse import ArgumentParser -from ..util import parse_addr +from enum import Enum, auto +from os.path import join, isdir, isfile +from ..util import Addr + + +CONFIG_DIRECTORIES = ( + join(os.environ['HOME'], '.config', 'homekit'), + '/etc/homekit' +) + +class RootSchemaType(Enum): + DEFAULT = auto() + DICT = auto() + LIST = auto() + + +class BaseConfigUnit(ABC): + _data: MutableMapping[str, Any] + _logger: logging.Logger + def __init__(self): + self._data = {} + self._logger = logging.getLogger(self.__class__.__name__) + + def __getitem__(self, key): + return self._data[key] + + def __setitem__(self, key, value): + raise NotImplementedError('overwriting config values is prohibited') -def _get_config_path(name: str) -> str: - formats = ['toml', 'yaml'] + def __contains__(self, key): + return key in self._data - dirname = join(os.environ['HOME'], '.config', name) + def load_from(self, path: str): + with open(path, 'r') as fd: + self._data = yaml.safe_load(fd) - if isdir(dirname): - for fmt in formats: - filename = join(dirname, f'config.{fmt}') - if isfile(filename): - return filename + def get(self, + key: Optional[str] = None, + default=None): + if key is None: + return self._data - raise IOError(f'config not found in {dirname}') + cur = self._data + pts = key.split('.') + for i in range(len(pts)): + k = pts[i] + if i < len(pts)-1: + if k not in cur: + raise KeyError(f'key {k} not found') + else: + return cur[k] if k in cur else default + cur = self._data[k] - else: - filenames = [join(os.environ['HOME'], '.config', f'{name}.{format}') for format in formats] - for file in filenames: - if isfile(file): - return file + raise KeyError(f'option {key} not found') - raise IOError(f'config not found') +class ConfigUnit(BaseConfigUnit): + NAME = 'dumb' + + def __init__(self, name=None, load=True): + super().__init__() + + self._data = {} + self._logger = logging.getLogger(self.__class__.__name__) + + if self.NAME != 'dumb' and load: + self.load_from(self.get_config_path()) + self.validate() + + elif name is not None: + self.NAME = name + + @classmethod + def get_config_path(cls, name=None) -> str: + if name is None: + name = cls.NAME + if name is None: + raise ValueError('get_config_path: name is none') + + for dirname in CONFIG_DIRECTORIES: + if isdir(dirname): + filename = join(dirname, f'{name}.yaml') + if isfile(filename): + return filename + + raise IOError(f'\'{name}.yaml\' not found') + + @staticmethod + def schema() -> Optional[dict]: + return None + + def validate(self): + schema = self.schema() + if not schema: + self._logger.warning('validate: no schema') + return + + if isinstance(self, AppConfigUnit): + schema['logging'] = { + 'type': 'dict', + 'schema': { + 'logging': {'type': 'bool'} + } + } + + rst = RootSchemaType.DEFAULT + try: + if schema['type'] == 'dict': + rst = RootSchemaType.DICT + elif schema['type'] == 'list': + rst = RootSchemaType.LIST + elif schema['roottype'] == 'dict': + del schema['roottype'] + rst = RootSchemaType.DICT + except KeyError: + pass + + if rst == RootSchemaType.DICT: + v = Validator({'document': { + 'type': 'dict', + 'keysrules': {'type': 'string'}, + 'valuesrules': schema + }}) + result = v.validate({'document': self._data}) + elif rst == RootSchemaType.LIST: + v = Validator({'document': schema}) + result = v.validate({'document': self._data}) + else: + v = Validator(schema) + result = v.validate(self._data) + # pprint.pprint(self._data) + if not result: + # pprint.pprint(v.errors) + raise DocumentError(f'{self.__class__.__name__}: failed to validate data:\n{pprint.pformat(v.errors)}') + try: + self.custom_validator(self._data) + except Exception as e: + raise DocumentError(f'{self.__class__.__name__}: {str(e)}') + + @staticmethod + def custom_validator(data): + pass -class ConfigStore: - data: MutableMapping[str, Any] + def get_addr(self, key: str): + return Addr.fromstring(self.get(key)) + + +class AppConfigUnit(ConfigUnit): + _logging_verbose: bool + _logging_fmt: Optional[str] + _logging_file: Optional[str] + + def __init__(self, *args, **kwargs): + super().__init__(load=False, *args, **kwargs) + self._logging_verbose = False + self._logging_fmt = None + self._logging_file = None + + def logging_set_fmt(self, fmt: str) -> None: + self._logging_fmt = fmt + + def logging_get_fmt(self) -> Optional[str]: + try: + return self['logging']['default_fmt'] + except KeyError: + return self._logging_fmt + + def logging_set_file(self, file: str) -> None: + self._logging_file = file + + def logging_get_file(self) -> Optional[str]: + try: + return self['logging']['file'] + except KeyError: + return self._logging_file + + def logging_set_verbose(self): + self._logging_verbose = True + + def logging_is_verbose(self) -> bool: + try: + return bool(self['logging']['verbose']) + except KeyError: + return self._logging_verbose + + +class TranslationUnit(BaseConfigUnit): + pass + + +class Translation: + LANGUAGES = ('en', 'ru') + _langs: dict[str, TranslationUnit] + + def __init__(self, name: str): + super().__init__() + self._langs = {} + for lang in self.LANGUAGES: + for dirname in CONFIG_DIRECTORIES: + if isdir(dirname): + filename = join(dirname, f'i18n-{lang}', f'{name}.yaml') + if lang in self._langs: + raise RuntimeError(f'{name}: translation unit for lang \'{lang}\' already loaded') + self._langs[lang] = TranslationUnit() + self._langs[lang].load_from(filename) + diff = set() + for data in self._langs.values(): + diff ^= data.get().keys() + if len(diff) > 0: + raise RuntimeError(f'{name}: translation units have difference in keys: ' + ', '.join(diff)) + + def get(self, lang: str) -> TranslationUnit: + return self._langs[lang] + + +class Config: app_name: Optional[str] + app_config: AppConfigUnit def __init__(self): - self.data = {} self.app_name = None + self.app_config = AppConfigUnit() + + def load_app(self, + name: Optional[Union[str, AppConfigUnit, bool]] = None, + use_cli=True, + parser: ArgumentParser = None, + no_config=False): + global app_config + + if issubclass(name, AppConfigUnit) or name == AppConfigUnit: + self.app_name = name.NAME + self.app_config = name() + app_config = self.app_config + else: + self.app_name = name if isinstance(name, str) else None - def load(self, name: Optional[str] = None, - use_cli=True, - parser: ArgumentParser = None): - self.app_name = name - - if (name is None) and (not use_cli): + if self.app_name is None and not use_cli: raise RuntimeError('either config name must be none or use_cli must be True') - log_default_fmt = False - log_file = None - log_verbose = False - no_config = name is False - + no_config = name is False or no_config path = None + if use_cli: if parser is None: parser = ArgumentParser() @@ -68,75 +266,38 @@ class ConfigStore: path = args.config if args.verbose: - log_verbose = True + self.app_config.logging_set_verbose() if args.log_file: - log_file = args.log_file + self.app_config.logging_set_file(args.log_file) if args.log_default_fmt: - log_default_fmt = args.log_default_fmt + self.app_config.logging_set_fmt(args.log_default_fmt) - if not no_config and path is None: - path = _get_config_path(name) + if not isinstance(name, ConfigUnit): + if not no_config and path is None: + path = ConfigUnit.get_config_path(name=self.app_name) - if no_config: - self.data = {} - else: - if path.endswith('.toml'): - self.data = toml.load(path) - elif path.endswith('.yaml'): - with open(path, 'r') as fd: - self.data = yaml.safe_load(fd) - - if 'logging' in self: - if not log_file and 'file' in self['logging']: - log_file = self['logging']['file'] - if log_default_fmt and 'default_fmt' in self['logging']: - log_default_fmt = self['logging']['default_fmt'] + if not no_config: + self.app_config.load_from(path) - setup_logging(log_verbose, log_file, log_default_fmt) + setup_logging(self.app_config.logging_is_verbose(), + self.app_config.logging_get_file(), + self.app_config.logging_get_fmt()) if use_cli: return args - def __getitem__(self, key): - return self.data[key] - - def __setitem__(self, key, value): - raise NotImplementedError('overwriting config values is prohibited') - - def __contains__(self, key): - return key in self.data - - def get(self, key: str, default=None): - cur = self.data - pts = key.split('.') - for i in range(len(pts)): - k = pts[i] - if i < len(pts)-1: - if k not in cur: - raise KeyError(f'key {k} not found') - else: - return cur[k] if k in cur else default - cur = self.data[k] - raise KeyError(f'option {key} not found') - - def get_addr(self, key: str): - return parse_addr(self.get(key)) - - def items(self): - return self.data.items() - -config = ConfigStore() +config = Config() def is_development_mode() -> bool: if 'HK_MODE' in os.environ and os.environ['HK_MODE'] == 'dev': return True - return ('logging' in config) and ('verbose' in config['logging']) and (config['logging']['verbose'] is True) + return ('logging' in config.app_config) and ('verbose' in config.app_config['logging']) and (config.app_config['logging']['verbose'] is True) -def setup_logging(verbose=False, log_file=None, default_fmt=False): +def setup_logging(verbose=False, log_file=None, default_fmt=None): logging_level = logging.INFO if is_development_mode() or verbose: logging_level = logging.DEBUG diff --git a/src/home/database/clickhouse.py b/src/home/database/clickhouse.py index ca81628..d0ec283 100644 --- a/src/home/database/clickhouse.py +++ b/src/home/database/clickhouse.py @@ -1,7 +1,7 @@ import logging from zoneinfo import ZoneInfo -from datetime import datetime, timedelta +from datetime import datetime from clickhouse_driver import Client as ClickhouseClient from ..config import is_development_mode diff --git a/src/home/database/sqlite.py b/src/home/database/sqlite.py index bfba929..8c6145c 100644 --- a/src/home/database/sqlite.py +++ b/src/home/database/sqlite.py @@ -5,24 +5,27 @@ import logging from ..config import config, is_development_mode -def _get_database_path(name: str, dbname: str) -> str: - return os.path.join(os.environ['HOME'], '.config', name, f'{dbname}.db') +def _get_database_path(name: str) -> str: + return os.path.join( + os.environ['HOME'], + '.config', + 'homekit', + 'data', + f'{name}.db') class SQLiteBase: SCHEMA = 1 - def __init__(self, name=None, dbname='bot', check_same_thread=False): - db_path = config.get('db_path', default=None) - if db_path is None: - if not name: - name = config.app_name - if not dbname: - dbname = name - db_path = _get_database_path(name, dbname) + def __init__(self, name=None, check_same_thread=False): + if name is None: + name = config.app_config['database_name'] + database_path = _get_database_path(name) + if not os.path.exists(os.path.dirname(database_path)): + os.makedirs(os.path.dirname(database_path)) self.logger = logging.getLogger(self.__class__.__name__) - self.sqlite = sqlite3.connect(db_path, check_same_thread=check_same_thread) + self.sqlite = sqlite3.connect(database_path, check_same_thread=check_same_thread) if is_development_mode(): self.sql_logger = logging.getLogger(self.__class__.__name__) diff --git a/src/home/inverter/config.py b/src/home/inverter/config.py new file mode 100644 index 0000000..62b8859 --- /dev/null +++ b/src/home/inverter/config.py @@ -0,0 +1,13 @@ +from ..config import ConfigUnit +from typing import Optional + + +class InverterdConfig(ConfigUnit): + NAME = 'inverterd' + + @staticmethod + def schema() -> Optional[dict]: + return { + 'remote_addr': {'type': 'string'}, + 'local_addr': {'type': 'string'}, + }
\ No newline at end of file diff --git a/src/home/media/__init__.py b/src/home/media/__init__.py index 976c990..6923105 100644 --- a/src/home/media/__init__.py +++ b/src/home/media/__init__.py @@ -12,6 +12,7 @@ __map__ = { __all__ = list(itertools.chain(*__map__.values())) + def __getattr__(name): if name in __all__: for file, names in __map__.items(): diff --git a/src/home/mqtt/__init__.py b/src/home/mqtt/__init__.py index 982e2b6..707d59c 100644 --- a/src/home/mqtt/__init__.py +++ b/src/home/mqtt/__init__.py @@ -1,4 +1,7 @@ -from .mqtt import MqttBase -from .util import poll_tick -from .relay import MqttRelay, MqttRelayState -from .temphum import MqttTempHum
\ No newline at end of file +from ._mqtt import Mqtt +from ._node import MqttNode +from ._module import MqttModule +from ._wrapper import MqttWrapper +from ._config import MqttConfig, MqttCreds, MqttNodesConfig +from ._payload import MqttPayload, MqttPayloadCustomField +from ._util import get_modules as get_mqtt_modules
\ No newline at end of file diff --git a/src/home/mqtt/_config.py b/src/home/mqtt/_config.py new file mode 100644 index 0000000..f9047b4 --- /dev/null +++ b/src/home/mqtt/_config.py @@ -0,0 +1,165 @@ +from ..config import ConfigUnit +from typing import Optional, Union +from ..util import Addr +from collections import namedtuple + +MqttCreds = namedtuple('MqttCreds', 'username, password') + + +class MqttConfig(ConfigUnit): + NAME = 'mqtt' + + @staticmethod + def schema() -> Optional[dict]: + addr_schema = { + 'type': 'dict', + 'required': True, + 'schema': { + 'host': {'type': 'string', 'required': True}, + 'port': {'type': 'integer', 'required': True} + } + } + + schema = {} + for key in ('local', 'remote'): + schema[f'{key}_addr'] = addr_schema + + schema['creds'] = { + 'type': 'dict', + 'required': True, + 'keysrules': {'type': 'string'}, + 'valuesrules': { + 'type': 'dict', + 'schema': { + 'username': {'type': 'string', 'required': True}, + 'password': {'type': 'string', 'required': True}, + } + } + } + + for key in ('client', 'server'): + schema[f'default_{key}_creds'] = {'type': 'string', 'required': True} + + return schema + + def remote_addr(self) -> Addr: + return Addr(host=self['remote_addr']['host'], + port=self['remote_addr']['port']) + + def local_addr(self) -> Addr: + return Addr(host=self['local_addr']['host'], + port=self['local_addr']['port']) + + def creds_by_name(self, name: str) -> MqttCreds: + return MqttCreds(username=self['creds'][name]['username'], + password=self['creds'][name]['password']) + + def creds(self) -> MqttCreds: + return self.creds_by_name(self['default_client_creds']) + + def server_creds(self) -> MqttCreds: + return self.creds_by_name(self['default_server_creds']) + + +class MqttNodesConfig(ConfigUnit): + NAME = 'mqtt_nodes' + + @staticmethod + def schema() -> Optional[dict]: + return { + 'common': { + 'type': 'dict', + 'schema': { + 'temphum': { + 'type': 'dict', + 'schema': { + 'interval': {'type': 'integer'} + } + }, + 'password': {'type': 'string'} + } + }, + 'nodes': { + 'type': 'dict', + 'required': True, + 'keysrules': {'type': 'string'}, + 'valuesrules': { + 'type': 'dict', + 'schema': { + 'type': {'type': 'string', 'required': True, 'allowed': ['esp8266', 'linux', 'none'],}, + 'board': {'type': 'string', 'allowed': ['nodemcu', 'd1_mini_lite', 'esp12e']}, + 'temphum': { + 'type': 'dict', + 'schema': { + 'module': {'type': 'string', 'required': True, 'allowed': ['si7021', 'dht12']}, + 'interval': {'type': 'integer'}, + 'i2c_bus': {'type': 'integer'}, + 'tcpserver': { + 'type': 'dict', + 'schema': { + 'port': {'type': 'integer', 'required': True} + } + } + } + }, + 'relay': { + 'type': 'dict', + 'schema': { + 'device_type': {'type': 'string', 'allowed': ['lamp', 'pump', 'solenoid'], 'required': True}, + 'legacy_topics': {'type': 'boolean'} + } + }, + 'password': {'type': 'string'} + } + } + } + } + + @staticmethod + def custom_validator(data): + for name, node in data['nodes'].items(): + if 'temphum' in node: + if node['type'] == 'linux': + if 'i2c_bus' not in node['temphum']: + raise KeyError(f'nodes.{name}.temphum: i2c_bus is missing but required for type=linux') + if node['type'] in ('esp8266',) and 'board' not in node: + raise KeyError(f'nodes.{name}: board is missing but required for type={node["type"]}') + + def get_node(self, name: str) -> dict: + node = self['nodes'][name] + if node['type'] == 'none': + return node + + try: + if 'password' not in node: + node['password'] = self['common']['password'] + except KeyError: + pass + + try: + if 'temphum' in node: + for ckey, cval in self['common']['temphum'].items(): + if ckey not in node['temphum']: + node['temphum'][ckey] = cval + except KeyError: + pass + + return node + + def get_nodes(self, + filters: Optional[Union[list[str], tuple[str]]] = None, + only_names=False) -> Union[dict, list[str]]: + if filters: + for f in filters: + if f not in ('temphum', 'relay'): + raise ValueError(f'{self.__class__.__name__}::get_node(): invalid filter {f}') + reslist = [] + resdict = {} + for name in self['nodes'].keys(): + node = self.get_node(name) + if (not filters) or ('temphum' in filters and 'temphum' in node) or ('relay' in filters and 'relay' in node): + if only_names: + reslist.append(name) + else: + resdict[name] = node + return reslist if only_names else resdict diff --git a/src/home/mqtt/_module.py b/src/home/mqtt/_module.py new file mode 100644 index 0000000..80f27bb --- /dev/null +++ b/src/home/mqtt/_module.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import abc +import logging +import threading + +from time import sleep +from ..util import next_tick_gen + +from typing import TYPE_CHECKING, Optional +if TYPE_CHECKING: + from ._node import MqttNode + from ._payload import MqttPayload + + +class MqttModule(abc.ABC): + _tick_interval: int + _initialized: bool + _connected: bool + _ticker: Optional[threading.Thread] + _mqtt_node_ref: Optional[MqttNode] + + def __init__(self, tick_interval=0): + self._tick_interval = tick_interval + self._initialized = False + self._ticker = None + self._logger = logging.getLogger(self.__class__.__name__) + self._connected = False + self._mqtt_node_ref = None + + def on_connect(self, mqtt: MqttNode): + self._connected = True + self._mqtt_node_ref = mqtt + if self._tick_interval: + self._start_ticker() + + def on_disconnect(self, mqtt: MqttNode): + self._connected = False + self._mqtt_node_ref = None + + def is_initialized(self): + return self._initialized + + def set_initialized(self): + self._initialized = True + + def unset_initialized(self): + self._initialized = False + + def tick(self): + pass + + def _tick(self): + g = next_tick_gen(self._tick_interval) + while self._connected: + sleep(next(g)) + if not self._connected: + break + self.tick() + + def _start_ticker(self): + if not self._ticker or not self._ticker.is_alive(): + name_part = f'{self._mqtt_node_ref.id}/' if self._mqtt_node_ref else '' + self._ticker = None + self._ticker = threading.Thread(target=self._tick, + name=f'mqtt:{self.__class__.__name__}/{name_part}ticker') + self._ticker.start() + + def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: + pass diff --git a/src/home/mqtt/mqtt.py b/src/home/mqtt/_mqtt.py index 4acd4f6..746ae2e 100644 --- a/src/home/mqtt/mqtt.py +++ b/src/home/mqtt/_mqtt.py @@ -3,19 +3,24 @@ import paho.mqtt.client as mqtt import ssl import logging -from typing import Tuple -from ..config import config +from ._config import MqttCreds, MqttConfig +from typing import Optional -def username_and_password() -> Tuple[str, str]: - username = config['mqtt']['username'] if 'username' in config['mqtt'] else None - password = config['mqtt']['password'] if 'password' in config['mqtt'] else None - return username, password +class Mqtt: + _connected: bool + _is_server: bool + _mqtt_config: MqttConfig + def __init__(self, + clean_session=True, + client_id='', + creds: Optional[MqttCreds] = None, + is_server=False): + if not client_id: + raise ValueError('client_id must not be empty') -class MqttBase: - def __init__(self, clean_session=True): - self._client = mqtt.Client(client_id=config['mqtt']['client_id'], + self._client = mqtt.Client(client_id=client_id, protocol=mqtt.MQTTv311, clean_session=clean_session) self._client.on_connect = self.on_connect @@ -24,15 +29,17 @@ class MqttBase: self._client.on_log = self.on_log self._client.on_publish = self.on_publish self._loop_started = False - + self._connected = False + self._is_server = is_server + self._mqtt_config = MqttConfig() self._logger = logging.getLogger(self.__class__.__name__) - username, password = username_and_password() - if username and password: - self._logger.debug(f'username={username} password={password}') - self._client.username_pw_set(username, password) + if not creds: + creds = self._mqtt_config.creds() if not is_server else self._mqtt_config.server_creds() + + self._client.username_pw_set(creds.username, creds.password) - def configure_tls(self): + def _configure_tls(self): ca_certs = os.path.realpath(os.path.join( os.path.dirname(os.path.realpath(__file__)), '..', @@ -41,13 +48,14 @@ class MqttBase: 'assets', 'mqtt_ca.crt' )) - self._client.tls_set(ca_certs=ca_certs, cert_reqs=ssl.CERT_REQUIRED, tls_version=ssl.PROTOCOL_TLSv1_2) + self._client.tls_set(ca_certs=ca_certs, + cert_reqs=ssl.CERT_REQUIRED, + tls_version=ssl.PROTOCOL_TLSv1_2) def connect_and_loop(self, loop_forever=True): - host = config['mqtt']['host'] - port = config['mqtt']['port'] - - self._client.connect(host, port, 60) + self._configure_tls() + addr = self._mqtt_config.local_addr() if self._is_server else self._mqtt_config.remote_addr() + self._client.connect(addr.host, addr.port, 60) if loop_forever: self._client.loop_forever() else: @@ -61,9 +69,11 @@ class MqttBase: def on_connect(self, client: mqtt.Client, userdata, flags, rc): self._logger.info("Connected with result code " + str(rc)) + self._connected = True def on_disconnect(self, client: mqtt.Client, userdata, rc): self._logger.info("Disconnected with result code " + str(rc)) + self._connected = False def on_log(self, client: mqtt.Client, userdata, level, buf): level = mqtt.LOGGING_LEVEL[level] if level in mqtt.LOGGING_LEVEL else logging.INFO @@ -73,4 +83,4 @@ class MqttBase: self._logger.debug(msg.topic + ": " + str(msg.payload)) def on_publish(self, client: mqtt.Client, userdata, mid): - self._logger.debug(f'publish done, mid={mid}')
\ No newline at end of file + self._logger.debug(f'publish done, mid={mid}') diff --git a/src/home/mqtt/_node.py b/src/home/mqtt/_node.py new file mode 100644 index 0000000..4e259a4 --- /dev/null +++ b/src/home/mqtt/_node.py @@ -0,0 +1,92 @@ +import logging +import importlib + +from typing import List, TYPE_CHECKING, Optional +from ._payload import MqttPayload +from ._module import MqttModule +if TYPE_CHECKING: + from ._wrapper import MqttWrapper +else: + MqttWrapper = None + + +class MqttNode: + _modules: List[MqttModule] + _module_subscriptions: dict[str, MqttModule] + _node_id: str + _node_secret: str + _payload_callbacks: list[callable] + _wrapper: Optional[MqttWrapper] + + def __init__(self, + node_id: str, + node_secret: Optional[str] = None): + self._modules = [] + self._module_subscriptions = {} + self._node_id = node_id + self._node_secret = node_secret + self._payload_callbacks = [] + self._logger = logging.getLogger(self.__class__.__name__) + self._wrapper = None + + def on_connect(self, wrapper: MqttWrapper): + self._wrapper = wrapper + for module in self._modules: + if not module.is_initialized(): + module.on_connect(self) + module.set_initialized() + + def on_disconnect(self): + self._wrapper = None + for module in self._modules: + module.unset_initialized() + + def on_message(self, topic, payload): + if topic in self._module_subscriptions: + payload = self._module_subscriptions[topic].handle_payload(self, topic, payload) + if isinstance(payload, MqttPayload): + for f in self._payload_callbacks: + f(self, payload) + + def load_module(self, module_name: str, *args, **kwargs) -> MqttModule: + module = importlib.import_module(f'..module.{module_name}', __name__) + if not hasattr(module, 'MODULE_NAME'): + raise RuntimeError(f'MODULE_NAME not found in module {module}') + cl = getattr(module, getattr(module, 'MODULE_NAME')) + instance = cl(*args, **kwargs) + self.add_module(instance) + return instance + + def add_module(self, module: MqttModule): + self._modules.append(module) + if self._wrapper and self._wrapper._connected: + module.on_connect(self) + module.set_initialized() + + def subscribe_module(self, topic: str, module: MqttModule, qos: int = 1): + if not self._wrapper or not self._wrapper._connected: + raise RuntimeError('not connected') + + self._module_subscriptions[topic] = module + self._wrapper.subscribe(self.id, topic, qos) + + def publish(self, + topic: str, + payload: bytes, + qos: int = 1): + self._wrapper.publish(self.id, topic, payload, qos) + + def add_payload_callback(self, callback: callable): + self._payload_callbacks.append(callback) + + @property + def id(self) -> str: + return self._node_id + + @property + def secret(self) -> str: + return self._node_secret + + @secret.setter + def secret(self, secret: str) -> None: + self._node_secret = secret diff --git a/src/home/mqtt/payload/base_payload.py b/src/home/mqtt/_payload.py index 1abd898..58eeae3 100644 --- a/src/home/mqtt/payload/base_payload.py +++ b/src/home/mqtt/_payload.py @@ -1,5 +1,5 @@ -import abc import struct +import abc import re from typing import Optional, Tuple @@ -142,4 +142,4 @@ def _bit_field_params(cl) -> Optional[Tuple[int, ...]]: match = re.match(r'MQTTPayloadBitField_(\d+)_(\d+)_(\d)$', cl.__name__) if match is not None: return tuple([int(match.group(i)) for i in range(1, 4)]) - return None + return None
\ No newline at end of file diff --git a/src/home/mqtt/_util.py b/src/home/mqtt/_util.py new file mode 100644 index 0000000..390d463 --- /dev/null +++ b/src/home/mqtt/_util.py @@ -0,0 +1,15 @@ +import os +import re + +from typing import List + + +def get_modules() -> List[str]: + modules = [] + modules_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'module') + for name in os.listdir(modules_dir): + if os.path.isdir(os.path.join(modules_dir, name)): + continue + name = re.sub(r'\.py$', '', name) + modules.append(name) + return modules diff --git a/src/home/mqtt/_wrapper.py b/src/home/mqtt/_wrapper.py new file mode 100644 index 0000000..f858f88 --- /dev/null +++ b/src/home/mqtt/_wrapper.py @@ -0,0 +1,59 @@ +import paho.mqtt.client as mqtt + +from ._mqtt import Mqtt +from ._node import MqttNode +from ..config import config +from ..util import strgen + + +class MqttWrapper(Mqtt): + _nodes: list[MqttNode] + + def __init__(self, + client_id: str, + topic_prefix='hk', + randomize_client_id=False, + clean_session=True): + if randomize_client_id: + client_id += '_'+strgen(6) + super().__init__(clean_session=clean_session, + client_id=client_id) + self._nodes = [] + self._topic_prefix = topic_prefix + + def on_connect(self, client: mqtt.Client, userdata, flags, rc): + super().on_connect(client, userdata, flags, rc) + for node in self._nodes: + node.on_connect(self) + + def on_disconnect(self, client: mqtt.Client, userdata, rc): + super().on_disconnect(client, userdata, rc) + for node in self._nodes: + node.on_disconnect() + + def on_message(self, client: mqtt.Client, userdata, msg): + try: + topic = msg.topic + for node in self._nodes: + node.on_message(topic[len(f'{self._topic_prefix}/{node.id}/'):], msg.payload) + except Exception as e: + self._logger.exception(str(e)) + + def add_node(self, node: MqttNode): + self._nodes.append(node) + if self._connected: + node.on_connect(self) + + def subscribe(self, + node_id: str, + topic: str, + qos: int): + self._client.subscribe(f'{self._topic_prefix}/{node_id}/{topic}', qos) + + def publish(self, + node_id: str, + topic: str, + payload: bytes, + qos: int): + self._client.publish(f'{self._topic_prefix}/{node_id}/{topic}', payload, qos) + self._client.loop_write() diff --git a/src/home/mqtt/esp.py b/src/home/mqtt/esp.py deleted file mode 100644 index 56ced83..0000000 --- a/src/home/mqtt/esp.py +++ /dev/null @@ -1,106 +0,0 @@ -import re -import paho.mqtt.client as mqtt - -from .mqtt import MqttBase -from typing import Optional, Union -from .payload.esp import ( - OTAPayload, - OTAResultPayload, - DiagnosticsPayload, - InitialDiagnosticsPayload -) - - -class MqttEspDevice: - id: str - secret: Optional[str] - - def __init__(self, id: str, secret: Optional[str] = None): - self.id = id - self.secret = secret - - -class MqttEspBase(MqttBase): - _devices: list[MqttEspDevice] - _message_callback: Optional[callable] - _ota_publish_callback: Optional[callable] - - TOPIC_LEAF = 'esp' - - def __init__(self, - devices: Union[MqttEspDevice, list[MqttEspDevice]], - subscribe_to_updates=True): - super().__init__(clean_session=True) - if not isinstance(devices, list): - devices = [devices] - self._devices = devices - self._message_callback = None - self._ota_publish_callback = None - self._subscribe_to_updates = subscribe_to_updates - self._ota_mid = None - - def on_connect(self, client: mqtt.Client, userdata, flags, rc): - super().on_connect(client, userdata, flags, rc) - - if self._subscribe_to_updates: - for device in self._devices: - topic = f'hk/{device.id}/{self.TOPIC_LEAF}/#' - self._logger.debug(f"subscribing to {topic}") - client.subscribe(topic, qos=1) - - def on_publish(self, client: mqtt.Client, userdata, mid): - if self._ota_mid is not None and mid == self._ota_mid and self._ota_publish_callback: - self._ota_publish_callback() - - def set_message_callback(self, callback: callable): - self._message_callback = callback - - def on_message(self, client: mqtt.Client, userdata, msg): - try: - match = re.match(self.get_mqtt_topics(), msg.topic) - self._logger.debug(f'topic: {msg.topic}') - if not match: - return - - device_id = match.group(1) - subtopic = match.group(2) - - # try: - next(d for d in self._devices if d.id == device_id) - # except StopIteration:h - # return - - message = None - if subtopic == 'stat': - message = DiagnosticsPayload.unpack(msg.payload) - elif subtopic == 'stat1': - message = InitialDiagnosticsPayload.unpack(msg.payload) - elif subtopic == 'otares': - message = OTAResultPayload.unpack(msg.payload) - - if message and self._message_callback: - self._message_callback(device_id, message) - return True - - except Exception as e: - self._logger.exception(str(e)) - - def push_ota(self, - device_id, - filename: str, - publish_callback: callable, - qos: int): - device = next(d for d in self._devices if d.id == device_id) - assert device.secret is not None, 'device secret not specified' - - self._ota_publish_callback = publish_callback - payload = OTAPayload(secret=device.secret, filename=filename) - publish_result = self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/admin/ota', - payload=payload.pack(), - qos=qos) - self._ota_mid = publish_result.mid - self._client.loop_write() - - @classmethod - def get_mqtt_topics(cls, additional_topics: Optional[list[str]] = None): - return rf'^hk/(.*?)/{cls.TOPIC_LEAF}/(stat|stat1|otares'+('|'+('|'.join(additional_topics)) if additional_topics else '')+')$'
\ No newline at end of file diff --git a/src/home/mqtt/payload/esp.py b/src/home/mqtt/module/diagnostics.py index 171cdb9..5db5e99 100644 --- a/src/home/mqtt/payload/esp.py +++ b/src/home/mqtt/module/diagnostics.py @@ -1,39 +1,8 @@ -import hashlib +from .._payload import MqttPayload, MqttPayloadCustomField +from .._node import MqttNode, MqttModule +from typing import Optional -from .base_payload import MqttPayload, MqttPayloadCustomField - - -class OTAResultPayload(MqttPayload): - FORMAT = '=BB' - result: int - error_code: int - - -class OTAPayload(MqttPayload): - secret: str - filename: str - - # structure of returned data: - # - # uint8_t[len(secret)] secret; - # uint8_t[16] md5; - # *uint8_t data - - def pack(self): - buf = bytearray(self.secret.encode()) - m = hashlib.md5() - with open(self.filename, 'rb') as fd: - content = fd.read() - m.update(content) - buf.extend(m.digest()) - buf.extend(content) - return buf - - def unpack(cls, buf: bytes): - raise RuntimeError(f'{cls.__class__.__name__}.unpack: not implemented') - # secret = buf[:12].decode() - # filename = buf[12:].decode() - # return OTAPayload(secret=secret, filename=filename) +MODULE_NAME = 'MqttDiagnosticsModule' class DiagnosticsFlags(MqttPayloadCustomField): @@ -76,3 +45,20 @@ class DiagnosticsPayload(MqttPayload): rssi: int free_heap: int flags: DiagnosticsFlags + + +class MqttDiagnosticsModule(MqttModule): + def on_connect(self, mqtt: MqttNode): + super().on_connect(mqtt) + for topic in ('diag', 'd1ag', 'stat', 'stat1'): + mqtt.subscribe_module(topic, self) + + def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: + message = None + if topic in ('stat', 'diag'): + message = DiagnosticsPayload.unpack(payload) + elif topic in ('stat1', 'd1ag'): + message = InitialDiagnosticsPayload.unpack(payload) + if message: + self._logger.debug(message) + return message diff --git a/src/home/mqtt/module/inverter.py b/src/home/mqtt/module/inverter.py new file mode 100644 index 0000000..d927a06 --- /dev/null +++ b/src/home/mqtt/module/inverter.py @@ -0,0 +1,195 @@ +import time +import json +import datetime +try: + import inverterd +except: + pass + +from typing import Optional +from .._module import MqttModule +from .._node import MqttNode +from .._payload import MqttPayload, bit_field +try: + from home.database import InverterDatabase +except: + pass + +_mult_10 = lambda n: int(n*10) +_div_10 = lambda n: n/10 + + +MODULE_NAME = 'MqttInverterModule' + +STATUS_TOPIC = 'status' +GENERATION_TOPIC = 'generation' + + +class MqttInverterStatusPayload(MqttPayload): + # 46 bytes + FORMAT = 'IHHHHHHBHHHHHBHHHHHHHH' + + PACKER = { + 'grid_voltage': _mult_10, + 'grid_freq': _mult_10, + 'ac_output_voltage': _mult_10, + 'ac_output_freq': _mult_10, + 'battery_voltage': _mult_10, + 'battery_voltage_scc': _mult_10, + 'battery_voltage_scc2': _mult_10, + 'pv1_input_voltage': _mult_10, + 'pv2_input_voltage': _mult_10 + } + UNPACKER = { + 'grid_voltage': _div_10, + 'grid_freq': _div_10, + 'ac_output_voltage': _div_10, + 'ac_output_freq': _div_10, + 'battery_voltage': _div_10, + 'battery_voltage_scc': _div_10, + 'battery_voltage_scc2': _div_10, + 'pv1_input_voltage': _div_10, + 'pv2_input_voltage': _div_10 + } + + time: int + grid_voltage: float + grid_freq: float + ac_output_voltage: float + ac_output_freq: float + ac_output_apparent_power: int + ac_output_active_power: int + output_load_percent: int + battery_voltage: float + battery_voltage_scc: float + battery_voltage_scc2: float + battery_discharge_current: int + battery_charge_current: int + battery_capacity: int + inverter_heat_sink_temp: int + mppt1_charger_temp: int + mppt2_charger_temp: int + pv1_input_power: int + pv2_input_power: int + pv1_input_voltage: float + pv2_input_voltage: float + + # H + mppt1_charger_status: bit_field(0, 16, 2) + mppt2_charger_status: bit_field(0, 16, 2) + battery_power_direction: bit_field(0, 16, 2) + dc_ac_power_direction: bit_field(0, 16, 2) + line_power_direction: bit_field(0, 16, 2) + load_connected: bit_field(0, 16, 1) + + +class MqttInverterGenerationPayload(MqttPayload): + # 8 bytes + FORMAT = 'II' + + time: int + wh: int + + +class MqttInverterModule(MqttModule): + _status_poll_freq: int + _generation_poll_freq: int + _inverter: Optional[inverterd.Client] + _database: Optional[InverterDatabase] + _gen_prev: float + + def __init__(self, status_poll_freq=0, generation_poll_freq=0): + super().__init__(tick_interval=status_poll_freq) + self._status_poll_freq = status_poll_freq + self._generation_poll_freq = generation_poll_freq + + # this defines whether this is a publisher or a subscriber + if status_poll_freq > 0: + self._inverter = inverterd.Client() + self._inverter.connect() + self._inverter.format(inverterd.Format.SIMPLE_JSON) + self._database = None + else: + self._inverter = None + self._database = InverterDatabase() + + self._gen_prev = 0 + + def on_connect(self, mqtt: MqttNode): + super().on_connect(mqtt) + if not self._inverter: + mqtt.subscribe_module(STATUS_TOPIC, self) + mqtt.subscribe_module(GENERATION_TOPIC, self) + + def tick(self): + if not self._inverter: + return + + # read status + now = time.time() + try: + raw = self._inverter.exec('get-status') + except inverterd.InverterError as e: + self._logger.error(f'inverter error: {str(e)}') + # TODO send to server + return + + data = json.loads(raw)['data'] + status = MqttInverterStatusPayload(time=round(now), **data) + self._mqtt_node_ref.publish(STATUS_TOPIC, status.pack()) + + # read today's generation stat + now = time.time() + if self._gen_prev == 0 or now - self._gen_prev >= self._generation_poll_freq: + self._gen_prev = now + today = datetime.date.today() + try: + raw = self._inverter.exec('get-day-generated', (today.year, today.month, today.day)) + except inverterd.InverterError as e: + self._logger.error(f'inverter error: {str(e)}') + # TODO send to server + return + + data = json.loads(raw)['data'] + gen = MqttInverterGenerationPayload(time=round(now), wh=data['wh']) + self._mqtt_node_ref.publish(GENERATION_TOPIC, gen.pack()) + + def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: + home_id = 1 # legacy compat + + if topic == STATUS_TOPIC: + s = MqttInverterStatusPayload.unpack(payload) + self._database.add_status(home_id=home_id, + client_time=s.time, + grid_voltage=int(s.grid_voltage*10), + grid_freq=int(s.grid_freq * 10), + ac_output_voltage=int(s.ac_output_voltage * 10), + ac_output_freq=int(s.ac_output_freq * 10), + ac_output_apparent_power=s.ac_output_apparent_power, + ac_output_active_power=s.ac_output_active_power, + output_load_percent=s.output_load_percent, + battery_voltage=int(s.battery_voltage * 10), + battery_voltage_scc=int(s.battery_voltage_scc * 10), + battery_voltage_scc2=int(s.battery_voltage_scc2 * 10), + battery_discharge_current=s.battery_discharge_current, + battery_charge_current=s.battery_charge_current, + battery_capacity=s.battery_capacity, + inverter_heat_sink_temp=s.inverter_heat_sink_temp, + mppt1_charger_temp=s.mppt1_charger_temp, + mppt2_charger_temp=s.mppt2_charger_temp, + pv1_input_power=s.pv1_input_power, + pv2_input_power=s.pv2_input_power, + pv1_input_voltage=int(s.pv1_input_voltage * 10), + pv2_input_voltage=int(s.pv2_input_voltage * 10), + mppt1_charger_status=s.mppt1_charger_status, + mppt2_charger_status=s.mppt2_charger_status, + battery_power_direction=s.battery_power_direction, + dc_ac_power_direction=s.dc_ac_power_direction, + line_power_direction=s.line_power_direction, + load_connected=s.load_connected) + return s + + elif topic == GENERATION_TOPIC: + gen = MqttInverterGenerationPayload.unpack(payload) + self._database.add_generation(home_id, gen.time, gen.wh) + return gen diff --git a/src/home/mqtt/module/ota.py b/src/home/mqtt/module/ota.py new file mode 100644 index 0000000..cd34332 --- /dev/null +++ b/src/home/mqtt/module/ota.py @@ -0,0 +1,77 @@ +import hashlib + +from typing import Optional +from .._payload import MqttPayload +from .._node import MqttModule, MqttNode + +MODULE_NAME = 'MqttOtaModule' + + +class OtaResultPayload(MqttPayload): + FORMAT = '=BB' + result: int + error_code: int + + +class OtaPayload(MqttPayload): + secret: str + filename: str + + # structure of returned data: + # + # uint8_t[len(secret)] secret; + # uint8_t[16] md5; + # *uint8_t data + + def pack(self): + buf = bytearray(self.secret.encode()) + m = hashlib.md5() + with open(self.filename, 'rb') as fd: + content = fd.read() + m.update(content) + buf.extend(m.digest()) + buf.extend(content) + return buf + + def unpack(cls, buf: bytes): + raise RuntimeError(f'{cls.__class__.__name__}.unpack: not implemented') + # secret = buf[:12].decode() + # filename = buf[12:].decode() + # return OTAPayload(secret=secret, filename=filename) + + +class MqttOtaModule(MqttModule): + _ota_request: Optional[tuple[str, int]] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._ota_request = None + + def on_connect(self, mqtt: MqttNode): + super().on_connect(mqtt) + mqtt.subscribe_module("otares", self) + + if self._ota_request is not None: + filename, qos = self._ota_request + self._ota_request = None + self.do_push_ota(self._mqtt_node_ref.secret, filename, qos) + + def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: + if topic == 'otares': + message = OtaResultPayload.unpack(payload) + self._logger.debug(message) + return message + + def do_push_ota(self, secret: str, filename: str, qos: int): + payload = OtaPayload(secret=secret, filename=filename) + self._mqtt_node_ref.publish('ota', + payload=payload.pack(), + qos=qos) + + def push_ota(self, + filename: str, + qos: int): + if not self._initialized: + self._ota_request = (filename, qos) + else: + self.do_push_ota(filename, qos) diff --git a/src/home/mqtt/module/relay.py b/src/home/mqtt/module/relay.py new file mode 100644 index 0000000..e968031 --- /dev/null +++ b/src/home/mqtt/module/relay.py @@ -0,0 +1,92 @@ +import datetime + +from typing import Optional +from .. import MqttModule, MqttPayload, MqttNode + +MODULE_NAME = 'MqttRelayModule' + + +class MqttPowerSwitchPayload(MqttPayload): + FORMAT = '=12sB' + PACKER = { + 'state': lambda n: int(n), + 'secret': lambda s: s.encode('utf-8') + } + UNPACKER = { + 'state': lambda n: bool(n), + 'secret': lambda s: s.decode('utf-8') + } + + secret: str + state: bool + + +class MqttPowerStatusPayload(MqttPayload): + FORMAT = '=B' + PACKER = { + 'opened': lambda n: int(n), + } + UNPACKER = { + 'opened': lambda n: bool(n), + } + + opened: bool + + +class MqttRelayState: + enabled: bool + update_time: datetime.datetime + rssi: int + fw_version: int + ever_updated: bool + + def __init__(self): + self.ever_updated = False + self.enabled = False + self.rssi = 0 + + def update(self, + enabled: bool, + rssi: int, + fw_version=None): + self.ever_updated = True + self.enabled = enabled + self.rssi = rssi + self.update_time = datetime.datetime.now() + if fw_version: + self.fw_version = fw_version + + +class MqttRelayModule(MqttModule): + _legacy_topics: bool + + def __init__(self, legacy_topics=False, *args, **kwargs): + super().__init__(*args, **kwargs) + self._legacy_topics = legacy_topics + + def on_connect(self, mqtt: MqttNode): + super().on_connect(mqtt) + mqtt.subscribe_module(self._get_switch_topic(), self) + mqtt.subscribe_module('relay/status', self) + + def switchpower(self, + enable: bool): + payload = MqttPowerSwitchPayload(secret=self._mqtt_node_ref.secret, + state=enable) + self._mqtt_node_ref.publish(self._get_switch_topic(), + payload=payload.pack()) + + def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: + message = None + + if topic == self._get_switch_topic(): + message = MqttPowerSwitchPayload.unpack(payload) + elif topic == 'relay/status': + message = MqttPowerStatusPayload.unpack(payload) + + if message is not None: + self._logger.debug(message) + return message + + def _get_switch_topic(self) -> str: + return 'relay/power' if self._legacy_topics else 'relay/switch' diff --git a/src/home/mqtt/module/temphum.py b/src/home/mqtt/module/temphum.py new file mode 100644 index 0000000..fd02cca --- /dev/null +++ b/src/home/mqtt/module/temphum.py @@ -0,0 +1,82 @@ +from .._node import MqttNode +from .._module import MqttModule +from .._payload import MqttPayload +from typing import Optional +from ...temphum import BaseSensor + +two_digits_precision = lambda x: round(x, 2) + +MODULE_NAME = 'MqttTempHumModule' +DATA_TOPIC = 'temphum/data' + + +class MqttTemphumDataPayload(MqttPayload): + FORMAT = '=ddb' + UNPACKER = { + 'temp': two_digits_precision, + 'rh': two_digits_precision + } + + temp: float + rh: float + error: int + + +# class MqttTempHumNodes(HashableEnum): +# KBN_SH_HALL = auto() +# KBN_SH_BATHROOM = auto() +# KBN_SH_LIVINGROOM = auto() +# KBN_SH_BEDROOM = auto() +# +# KBN_BH_2FL = auto() +# KBN_BH_2FL_STREET = auto() +# KBN_BH_1FL_LIVINGROOM = auto() +# KBN_BH_1FL_BEDROOM = auto() +# KBN_BH_1FL_BATHROOM = auto() +# +# KBN_NH_1FL_INV = auto() +# KBN_NH_1FL_CENTER = auto() +# KBN_NH_1LF_KT = auto() +# KBN_NH_1FL_DS = auto() +# KBN_NH_1FS_EZ = auto() +# +# SPB_FLAT120_CABINET = auto() + + +class MqttTempHumModule(MqttModule): + def __init__(self, + sensor: Optional[BaseSensor] = None, + write_to_database=False, + *args, **kwargs): + if sensor is not None: + kwargs['tick_interval'] = 10 + super().__init__(*args, **kwargs) + self._sensor = sensor + + def on_connect(self, mqtt: MqttNode): + super().on_connect(mqtt) + mqtt.subscribe_module(DATA_TOPIC, self) + + def tick(self): + if not self._sensor: + return + + error = 0 + temp = 0 + rh = 0 + try: + temp = self._sensor.temperature() + rh = self._sensor.humidity() + except: + error = 1 + pld = MqttTemphumDataPayload(temp=temp, rh=rh, error=error) + self._mqtt_node_ref.publish(DATA_TOPIC, pld.pack()) + + def handle_payload(self, + mqtt: MqttNode, + topic: str, + payload: bytes) -> Optional[MqttPayload]: + if topic == DATA_TOPIC: + message = MqttTemphumDataPayload.unpack(payload) + self._logger.debug(message) + return message diff --git a/src/home/mqtt/payload/__init__.py b/src/home/mqtt/payload/__init__.py deleted file mode 100644 index eee6709..0000000 --- a/src/home/mqtt/payload/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .base_payload import MqttPayload
\ No newline at end of file diff --git a/src/home/mqtt/payload/inverter.py b/src/home/mqtt/payload/inverter.py deleted file mode 100644 index 09388df..0000000 --- a/src/home/mqtt/payload/inverter.py +++ /dev/null @@ -1,73 +0,0 @@ -import struct - -from .base_payload import MqttPayload, bit_field -from typing import Tuple - -_mult_10 = lambda n: int(n*10) -_div_10 = lambda n: n/10 - - -class Status(MqttPayload): - # 46 bytes - FORMAT = 'IHHHHHHBHHHHHBHHHHHHHH' - - PACKER = { - 'grid_voltage': _mult_10, - 'grid_freq': _mult_10, - 'ac_output_voltage': _mult_10, - 'ac_output_freq': _mult_10, - 'battery_voltage': _mult_10, - 'battery_voltage_scc': _mult_10, - 'battery_voltage_scc2': _mult_10, - 'pv1_input_voltage': _mult_10, - 'pv2_input_voltage': _mult_10 - } - UNPACKER = { - 'grid_voltage': _div_10, - 'grid_freq': _div_10, - 'ac_output_voltage': _div_10, - 'ac_output_freq': _div_10, - 'battery_voltage': _div_10, - 'battery_voltage_scc': _div_10, - 'battery_voltage_scc2': _div_10, - 'pv1_input_voltage': _div_10, - 'pv2_input_voltage': _div_10 - } - - time: int - grid_voltage: float - grid_freq: float - ac_output_voltage: float - ac_output_freq: float - ac_output_apparent_power: int - ac_output_active_power: int - output_load_percent: int - battery_voltage: float - battery_voltage_scc: float - battery_voltage_scc2: float - battery_discharge_current: int - battery_charge_current: int - battery_capacity: int - inverter_heat_sink_temp: int - mppt1_charger_temp: int - mppt2_charger_temp: int - pv1_input_power: int - pv2_input_power: int - pv1_input_voltage: float - pv2_input_voltage: float - - # H - mppt1_charger_status: bit_field(0, 16, 2) - mppt2_charger_status: bit_field(0, 16, 2) - battery_power_direction: bit_field(0, 16, 2) - dc_ac_power_direction: bit_field(0, 16, 2) - line_power_direction: bit_field(0, 16, 2) - load_connected: bit_field(0, 16, 1) - - -class Generation(MqttPayload): - # 8 bytes - FORMAT = 'II' - - time: int - wh: int diff --git a/src/home/mqtt/payload/relay.py b/src/home/mqtt/payload/relay.py deleted file mode 100644 index 4902991..0000000 --- a/src/home/mqtt/payload/relay.py +++ /dev/null @@ -1,22 +0,0 @@ -from .base_payload import MqttPayload -from .esp import ( - OTAResultPayload, - OTAPayload, - InitialDiagnosticsPayload, - DiagnosticsPayload -) - - -class PowerPayload(MqttPayload): - FORMAT = '=12sB' - PACKER = { - 'state': lambda n: int(n), - 'secret': lambda s: s.encode('utf-8') - } - UNPACKER = { - 'state': lambda n: bool(n), - 'secret': lambda s: s.decode('utf-8') - } - - secret: str - state: bool diff --git a/src/home/mqtt/payload/sensors.py b/src/home/mqtt/payload/sensors.py deleted file mode 100644 index f99b307..0000000 --- a/src/home/mqtt/payload/sensors.py +++ /dev/null @@ -1,20 +0,0 @@ -from .base_payload import MqttPayload - -_mult_100 = lambda n: int(n*100) -_div_100 = lambda n: n/100 - - -class Temperature(MqttPayload): - FORMAT = 'IhH' - PACKER = { - 'temp': _mult_100, - 'rh': _mult_100, - } - UNPACKER = { - 'temp': _div_100, - 'rh': _div_100, - } - - time: int - temp: float - rh: float diff --git a/src/home/mqtt/payload/temphum.py b/src/home/mqtt/payload/temphum.py deleted file mode 100644 index c0b744e..0000000 --- a/src/home/mqtt/payload/temphum.py +++ /dev/null @@ -1,15 +0,0 @@ -from .base_payload import MqttPayload - -two_digits_precision = lambda x: round(x, 2) - - -class TempHumDataPayload(MqttPayload): - FORMAT = '=ddb' - UNPACKER = { - 'temp': two_digits_precision, - 'rh': two_digits_precision - } - - temp: float - rh: float - error: int diff --git a/src/home/mqtt/relay.py b/src/home/mqtt/relay.py deleted file mode 100644 index a90f19c..0000000 --- a/src/home/mqtt/relay.py +++ /dev/null @@ -1,71 +0,0 @@ -import paho.mqtt.client as mqtt -import re -import datetime - -from .payload.relay import ( - PowerPayload, -) -from .esp import MqttEspBase - - -class MqttRelay(MqttEspBase): - TOPIC_LEAF = 'relay' - - def set_power(self, device_id, enable: bool, secret=None): - device = next(d for d in self._devices if d.id == device_id) - secret = secret if secret else device.secret - - assert secret is not None, 'device secret not specified' - - payload = PowerPayload(secret=secret, - state=enable) - self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/power', - payload=payload.pack(), - qos=1) - self._client.loop_write() - - def on_message(self, client: mqtt.Client, userdata, msg): - if super().on_message(client, userdata, msg): - return - - try: - match = re.match(self.get_mqtt_topics(['power']), msg.topic) - if not match: - return - - device_id = match.group(1) - subtopic = match.group(2) - - message = None - if subtopic == 'power': - message = PowerPayload.unpack(msg.payload) - - if message and self._message_callback: - self._message_callback(device_id, message) - - except Exception as e: - self._logger.exception(str(e)) - - -class MqttRelayState: - enabled: bool - update_time: datetime.datetime - rssi: int - fw_version: int - ever_updated: bool - - def __init__(self): - self.ever_updated = False - self.enabled = False - self.rssi = 0 - - def update(self, - enabled: bool, - rssi: int, - fw_version=None): - self.ever_updated = True - self.enabled = enabled - self.rssi = rssi - self.update_time = datetime.datetime.now() - if fw_version: - self.fw_version = fw_version diff --git a/src/home/mqtt/temphum.py b/src/home/mqtt/temphum.py deleted file mode 100644 index 44810ef..0000000 --- a/src/home/mqtt/temphum.py +++ /dev/null @@ -1,54 +0,0 @@ -import paho.mqtt.client as mqtt -import re - -from enum import auto -from .payload.temphum import TempHumDataPayload -from .esp import MqttEspBase -from ..util import HashableEnum - - -class MqttTempHumNodes(HashableEnum): - KBN_SH_HALL = auto() - KBN_SH_BATHROOM = auto() - KBN_SH_LIVINGROOM = auto() - KBN_SH_BEDROOM = auto() - - KBN_BH_2FL = auto() - KBN_BH_2FL_STREET = auto() - KBN_BH_1FL_LIVINGROOM = auto() - KBN_BH_1FL_BEDROOM = auto() - KBN_BH_1FL_BATHROOM = auto() - - KBN_NH_1FL_INV = auto() - KBN_NH_1FL_CENTER = auto() - KBN_NH_1LF_KT = auto() - KBN_NH_1FL_DS = auto() - KBN_NH_1FS_EZ = auto() - - SPB_FLAT120_CABINET = auto() - - -class MqttTempHum(MqttEspBase): - TOPIC_LEAF = 'temphum' - - def on_message(self, client: mqtt.Client, userdata, msg): - if super().on_message(client, userdata, msg): - return - - try: - match = re.match(self.get_mqtt_topics(['data']), msg.topic) - if not match: - return - - device_id = match.group(1) - subtopic = match.group(2) - - message = None - if subtopic == 'data': - message = TempHumDataPayload.unpack(msg.payload) - - if message and self._message_callback: - self._message_callback(device_id, message) - - except Exception as e: - self._logger.exception(str(e)) diff --git a/src/home/mqtt/util.py b/src/home/mqtt/util.py deleted file mode 100644 index f71ffd8..0000000 --- a/src/home/mqtt/util.py +++ /dev/null @@ -1,8 +0,0 @@ -import time - - -def poll_tick(freq): - t = time.time() - while True: - t += freq - yield max(t - time.time(), 0) diff --git a/src/home/pio/products.py b/src/home/pio/products.py index 7649078..388da03 100644 --- a/src/home/pio/products.py +++ b/src/home/pio/products.py @@ -16,10 +16,6 @@ _products_dir = os.path.join( def get_products(): products = [] for f in os.listdir(_products_dir): - # temp hack - if f.endswith('-esp01'): - continue - # skip the common dir if f in ('common',): continue diff --git a/src/home/telegram/_botcontext.py b/src/home/telegram/_botcontext.py index f343eeb..a143bfe 100644 --- a/src/home/telegram/_botcontext.py +++ b/src/home/telegram/_botcontext.py @@ -1,6 +1,7 @@ from typing import Optional, List -from telegram import Update, ParseMode, User, CallbackQuery +from telegram import Update, User, CallbackQuery +from telegram.constants import ParseMode from telegram.ext import CallbackContext from ._botdb import BotDatabase @@ -26,25 +27,25 @@ class Context: self._store = store self._user_lang = None - def reply(self, text, markup=None): + async def reply(self, text, markup=None): if markup is None: markup = self._markup_getter(self) kwargs = dict(parse_mode=ParseMode.HTML) if not isinstance(markup, IgnoreMarkup): kwargs['reply_markup'] = markup - return self._update.message.reply_text(text, **kwargs) + return await self._update.message.reply_text(text, **kwargs) - def reply_exc(self, e: Exception) -> None: - self.reply(exc2text(e), markup=IgnoreMarkup()) + async def reply_exc(self, e: Exception) -> None: + await self.reply(exc2text(e), markup=IgnoreMarkup()) - def answer(self, text: str = None): - self.callback_query.answer(text) + async def answer(self, text: str = None): + await self.callback_query.answer(text) - def edit(self, text, markup=None): + async def edit(self, text, markup=None): kwargs = dict(parse_mode=ParseMode.HTML) if not isinstance(markup, IgnoreMarkup): kwargs['reply_markup'] = markup - self.callback_query.edit_message_text(text, **kwargs) + await self.callback_query.edit_message_text(text, **kwargs) @property def text(self) -> str: diff --git a/src/home/telegram/bot.py b/src/home/telegram/bot.py index 10bfe06..7e22263 100644 --- a/src/home/telegram/bot.py +++ b/src/home/telegram/bot.py @@ -5,19 +5,19 @@ import itertools from enum import Enum, auto from functools import wraps -from typing import Optional, Union, Tuple +from typing import Optional, Union, Tuple, Coroutine from telegram import Update, ReplyKeyboardMarkup from telegram.ext import ( - Updater, - Filters, - BaseFilter, + Application, + filters, CommandHandler, MessageHandler, CallbackQueryHandler, CallbackContext, ConversationHandler ) +from telegram.ext.filters import BaseFilter from telegram.error import TimedOut from home.config import config @@ -33,26 +33,26 @@ from ._botcontext import Context db: Optional[BotDatabase] = None _user_filter: Optional[BaseFilter] = None -_cancel_filter = Filters.text(lang.all('cancel')) -_back_filter = Filters.text(lang.all('back')) -_cancel_and_back_filter = Filters.text(lang.all('back') + lang.all('cancel')) +_cancel_filter = filters.Text(lang.all('cancel')) +_back_filter = filters.Text(lang.all('back')) +_cancel_and_back_filter = filters.Text(lang.all('back') + lang.all('cancel')) _logger = logging.getLogger(__name__) -_updater: Optional[Updater] = None +_application: Optional[Application] = None _reporting: Optional[ReportingHelper] = None -_exception_handler: Optional[callable] = None +_exception_handler: Optional[Coroutine] = None _dispatcher = None _markup_getter: Optional[callable] = None -_start_handler_ref: Optional[callable] = None +_start_handler_ref: Optional[Coroutine] = None def text_filter(*args): if not _user_filter: raise RuntimeError('user_filter is not initialized') - return Filters.text(args[0] if isinstance(args[0], list) else [*args]) & _user_filter + return filters.Text(args[0] if isinstance(args[0], list) else [*args]) & _user_filter -def _handler_of_handler(*args, **kwargs): +async def _handler_of_handler(*args, **kwargs): self = None context = None update = None @@ -99,7 +99,7 @@ def _handler_of_handler(*args, **kwargs): if self: _args.insert(0, self) - result = f(*_args, **kwargs) + result = await f(*_args, **kwargs) return result if not return_with_context else (result, ctx) except Exception as e: @@ -107,7 +107,7 @@ def _handler_of_handler(*args, **kwargs): if not _exception_handler(e, ctx) and not isinstance(e, TimedOut): _logger.exception(e) if not ctx.is_callback_context(): - ctx.reply_exc(e) + await ctx.reply_exc(e) else: notify_user(ctx.user_id, exc2text(e)) else: @@ -117,10 +117,10 @@ def _handler_of_handler(*args, **kwargs): def handler(**kwargs): def inner(f): @wraps(f) - def _handler(*args, **inner_kwargs): + async def _handler(*args, **inner_kwargs): if 'argument' in kwargs and kwargs['argument'] == 'message_key': inner_kwargs['argument'] = 'message_key' - return _handler_of_handler(f=f, *args, **inner_kwargs) + return await _handler_of_handler(f=f, *args, **inner_kwargs) messages = [] texts = [] @@ -139,43 +139,43 @@ def handler(**kwargs): new_messages = list(itertools.chain.from_iterable([lang.all(m) for m in messages])) texts += new_messages texts = list(set(texts)) - _updater.dispatcher.add_handler( + _application.add_handler( MessageHandler(text_filter(*texts), _handler), group=0 ) if 'command' in kwargs: - _updater.dispatcher.add_handler(CommandHandler(kwargs['command'], _handler), group=0) + _application.add_handler(CommandHandler(kwargs['command'], _handler), group=0) if 'callback' in kwargs: - _updater.dispatcher.add_handler(CallbackQueryHandler(_handler, pattern=kwargs['callback']), group=0) + _application.add_handler(CallbackQueryHandler(_handler, pattern=kwargs['callback']), group=0) return _handler return inner -def simplehandler(f: callable): +def simplehandler(f: Coroutine): @wraps(f) - def _handler(*args, **kwargs): - return _handler_of_handler(f=f, *args, **kwargs) + async def _handler(*args, **kwargs): + return await _handler_of_handler(f=f, *args, **kwargs) return _handler def callbackhandler(*args, **kwargs): def inner(f): @wraps(f) - def _handler(*args, **kwargs): - return _handler_of_handler(f=f, *args, **kwargs) + async def _handler(*args, **kwargs): + return await _handler_of_handler(f=f, *args, **kwargs) pattern_kwargs = {} if kwargs['callback'] != '*': pattern_kwargs['pattern'] = kwargs['callback'] - _updater.dispatcher.add_handler(CallbackQueryHandler(_handler, **pattern_kwargs), group=0) + _application.add_handler(CallbackQueryHandler(_handler, **pattern_kwargs), group=0) return _handler return inner -def exceptionhandler(f: callable): +async def exceptionhandler(f: callable): global _exception_handler if _exception_handler: _logger.warning('exception handler already set, we will overwrite it') @@ -198,10 +198,10 @@ def convinput(state, is_enter=False, **kwargs): ) @wraps(f) - def _impl(*args, **kwargs): - result, ctx = _handler_of_handler(f=f, *args, **kwargs, return_with_context=True) + async def _impl(*args, **kwargs): + result, ctx = await _handler_of_handler(f=f, *args, **kwargs, return_with_context=True) if result == conversation.END: - start(ctx) + await start(ctx) return result return _impl @@ -252,7 +252,7 @@ class conversation: handlers.append(MessageHandler(text_filter(lang.all(message) if 'messages_lang_completed' not in kwargs else message), self.make_invoker(target_state))) if 'regex' in kwargs: - handlers.append(MessageHandler(Filters.regex(kwargs['regex']) & _user_filter, f)) + handlers.append(MessageHandler(filters.Regex(kwargs['regex']) & _user_filter, f)) if 'command' in kwargs: handlers.append(CommandHandler(kwargs['command'], f, _user_filter)) @@ -327,21 +327,21 @@ class conversation: @staticmethod @simplehandler - def invalid(ctx: Context): - ctx.reply(ctx.lang('invalid_input'), markup=IgnoreMarkup()) + async def invalid(ctx: Context): + await ctx.reply(ctx.lang('invalid_input'), markup=IgnoreMarkup()) # return 0 # FIXME is this needed @simplehandler - def cancel(self, ctx: Context): - start(ctx) + async def cancel(self, ctx: Context): + await start(ctx) self.set_user_state(ctx.user_id, None) return conversation.END @simplehandler - def back(self, ctx: Context): + async def back(self, ctx: Context): cur_state = self.get_user_state(ctx.user_id) if cur_state is None: - start(ctx) + await start(ctx) self.set_user_state(ctx.user_id, None) return conversation.END @@ -411,7 +411,7 @@ class LangConversation(conversation): START, = range(1) @conventer(START, command='lang') - def entry(self, ctx: Context): + async def entry(self, ctx: Context): self._logger.debug(f'current language: {ctx.user_lang}') buttons = [] @@ -419,11 +419,11 @@ class LangConversation(conversation): buttons.append(name) markup = ReplyKeyboardMarkup([buttons, [ctx.lang('cancel')]], one_time_keyboard=False) - ctx.reply(ctx.lang('select_language'), markup=markup) + await ctx.reply(ctx.lang('select_language'), markup=markup) return self.START @convinput(START, messages=lang.languages) - def input(self, ctx: Context): + async def input(self, ctx: Context): selected_lang = None for key, value in languages.items(): if value == ctx.text: @@ -434,30 +434,34 @@ class LangConversation(conversation): raise ValueError('could not find the language') db.set_user_lang(ctx.user_id, selected_lang) - ctx.reply(ctx.lang('saved'), markup=IgnoreMarkup()) + await ctx.reply(ctx.lang('saved'), markup=IgnoreMarkup()) return self.END def initialize(): global _user_filter - global _updater + global _application + # global _updater global _dispatcher # init user_filter - if 'users' in config['bot']: - _logger.info('allowed users: ' + str(config['bot']['users'])) - _user_filter = Filters.user(config['bot']['users']) + _user_ids = config.app_config.get_user_ids() + if len(_user_ids) > 0: + _logger.info('allowed users: ' + str(_user_ids)) + _user_filter = filters.User(_user_ids) else: - _user_filter = Filters.all # not sure if this is correct + _user_filter = filters.ALL # not sure if this is correct - # init updater - _updater = Updater(config['bot']['token'], - request_kwargs={'read_timeout': 6, 'connect_timeout': 7}) + _application = Application.builder()\ + .token(config.app_config.get('bot.token'))\ + .connect_timeout(7)\ + .read_timeout(6)\ + .build() # transparently log all messages - _updater.dispatcher.add_handler(MessageHandler(Filters.all & _user_filter, _logging_message_handler), group=10) - _updater.dispatcher.add_handler(CallbackQueryHandler(_logging_callback_handler), group=10) + # _application.dispatcher.add_handler(MessageHandler(filters.ALL & _user_filter, _logging_message_handler), group=10) + # _application.dispatcher.add_handler(CallbackQueryHandler(_logging_callback_handler), group=10) def run(start_handler=None, any_handler=None): @@ -473,37 +477,38 @@ def run(start_handler=None, any_handler=None): _start_handler_ref = start_handler - _updater.dispatcher.add_handler(LangConversation().get_handler(), group=0) - _updater.dispatcher.add_handler(CommandHandler('start', simplehandler(start_handler), _user_filter)) - _updater.dispatcher.add_handler(MessageHandler(Filters.all & _user_filter, any_handler)) + _application.add_handler(LangConversation().get_handler(), group=0) + _application.add_handler(CommandHandler('start', + callback=simplehandler(start_handler), + filters=_user_filter)) + _application.add_handler(MessageHandler(filters.ALL & _user_filter, any_handler)) - _updater.start_polling() - _updater.idle() + _application.run_polling() def add_conversation(conv: conversation) -> None: - _updater.dispatcher.add_handler(conv.get_handler(), group=0) + _application.add_handler(conv.get_handler(), group=0) def add_handler(h): - _updater.dispatcher.add_handler(h, group=0) + _application.add_handler(h, group=0) -def start(ctx: Context): - return _start_handler_ref(ctx) +async def start(ctx: Context): + return await _start_handler_ref(ctx) -def _default_start_handler(ctx: Context): +async def _default_start_handler(ctx: Context): if 'start_message' not in lang: - return ctx.reply('Please define start_message or override start()') - ctx.reply(ctx.lang('start_message')) + return await ctx.reply('Please define start_message or override start()') + await ctx.reply(ctx.lang('start_message')) @simplehandler -def _default_any_handler(ctx: Context): +async def _default_any_handler(ctx: Context): if 'invalid_command' not in lang: - return ctx.reply('Please define invalid_command or override any()') - ctx.reply(ctx.lang('invalid_command')) + return await ctx.reply('Please define invalid_command or override any()') + await ctx.reply(ctx.lang('invalid_command')) def _logging_message_handler(update: Update, context: CallbackContext): @@ -535,7 +540,7 @@ def notify_all(text_getter: callable, continue text = text_getter(db.get_user_lang(user_id)) - _updater.bot.send_message(chat_id=user_id, + _application.bot.send_message(chat_id=user_id, text=text, parse_mode='HTML') @@ -543,33 +548,33 @@ def notify_all(text_getter: callable, def notify_user(user_id: int, text: Union[str, Exception], **kwargs) -> None: if isinstance(text, Exception): text = exc2text(text) - _updater.bot.send_message(chat_id=user_id, + _application.bot.send_message(chat_id=user_id, text=text, parse_mode='HTML', **kwargs) def send_photo(user_id, **kwargs): - _updater.bot.send_photo(chat_id=user_id, **kwargs) + _application.bot.send_photo(chat_id=user_id, **kwargs) def send_audio(user_id, **kwargs): - _updater.bot.send_audio(chat_id=user_id, **kwargs) + _application.bot.send_audio(chat_id=user_id, **kwargs) def send_file(user_id, **kwargs): - _updater.bot.send_document(chat_id=user_id, **kwargs) + _application.bot.send_document(chat_id=user_id, **kwargs) def edit_message_text(user_id, message_id, *args, **kwargs): - _updater.bot.edit_message_text(chat_id=user_id, + _application.bot.edit_message_text(chat_id=user_id, message_id=message_id, parse_mode='HTML', *args, **kwargs) def delete_message(user_id, message_id): - _updater.bot.delete_message(chat_id=user_id, message_id=message_id) + _application.bot.delete_message(chat_id=user_id, message_id=message_id) def set_database(_db: BotDatabase): diff --git a/src/home/telegram/config.py b/src/home/telegram/config.py new file mode 100644 index 0000000..7a46087 --- /dev/null +++ b/src/home/telegram/config.py @@ -0,0 +1,75 @@ +from ..config import ConfigUnit +from typing import Optional, Union +from abc import ABC +from enum import Enum + + +class TelegramUserListType(Enum): + USERS = 'users' + NOTIFY = 'notify_users' + + +class TelegramUserIdsConfig(ConfigUnit): + NAME = 'telegram_user_ids' + + @staticmethod + def schema() -> Optional[dict]: + return { + 'roottype': 'dict', + 'type': 'integer' + } + + +_user_ids_config = TelegramUserIdsConfig() + + +def _user_id_mapper(user: Union[str, int]) -> int: + if isinstance(user, int): + return user + return _user_ids_config[user] + + +class TelegramChatsConfig(ConfigUnit): + NAME = 'telegram_chats' + + @staticmethod + def schema() -> Optional[dict]: + return { + 'type': 'dict', + 'schema': { + 'id': {'type': 'string', 'required': True}, + 'token': {'type': 'string', 'required': True}, + } + } + + +class TelegramBotConfig(ConfigUnit, ABC): + @staticmethod + def schema() -> Optional[dict]: + return { + 'bot': { + 'type': 'dict', + 'schema': { + 'token': {'type': 'string', 'required': True}, + TelegramUserListType.USERS: {**TelegramBotConfig._userlist_schema(), 'required': True}, + TelegramUserListType.NOTIFY: TelegramBotConfig._userlist_schema(), + } + } + } + + @staticmethod + def _userlist_schema() -> dict: + return {'type': 'list', 'schema': {'type': ['string', 'int']}} + + @staticmethod + def custom_validator(data): + for ult in TelegramUserListType: + users = data['bot'][ult.value] + for user in users: + if isinstance(user, str): + if user not in _user_ids_config: + raise ValueError(f'user {user} not found in {TelegramUserIdsConfig.NAME}') + + def get_user_ids(self, + ult: TelegramUserListType = TelegramUserListType.USERS) -> list[int]: + return list(map(_user_id_mapper, self['bot'][ult.value]))
\ No newline at end of file diff --git a/src/home/temphum/__init__.py b/src/home/temphum/__init__.py index 55a7e1f..46d14e6 100644 --- a/src/home/temphum/__init__.py +++ b/src/home/temphum/__init__.py @@ -1,18 +1 @@ -from .base import SensorType, TempHumSensor -from .si7021 import Si7021 -from .dht12 import DHT12 - -__all__ = [ - 'SensorType', - 'TempHumSensor', - 'create_sensor' -] - - -def create_sensor(type: SensorType, bus: int) -> TempHumSensor: - if type == SensorType.Si7021: - return Si7021(bus) - elif type == SensorType.DHT12: - return DHT12(bus) - else: - raise ValueError('unexpected sensor type') +from .base import SensorType, BaseSensor diff --git a/src/home/temphum/base.py b/src/home/temphum/base.py index e774433..602cab7 100644 --- a/src/home/temphum/base.py +++ b/src/home/temphum/base.py @@ -1,25 +1,19 @@ -import smbus - -from abc import abstractmethod, ABC +from abc import ABC from enum import Enum -class TempHumSensor: - @abstractmethod +class BaseSensor(ABC): + def __init__(self, bus: int): + super().__init__() + self.bus = smbus.SMBus(bus) + def humidity(self) -> float: pass - @abstractmethod def temperature(self) -> float: pass -class I2CTempHumSensor(TempHumSensor, ABC): - def __init__(self, bus: int): - super().__init__() - self.bus = smbus.SMBus(bus) - - class SensorType(Enum): Si7021 = 'si7021' - DHT12 = 'dht12' + DHT12 = 'dht12'
\ No newline at end of file diff --git a/src/home/temphum/dht12.py b/src/home/temphum/dht12.py deleted file mode 100644 index d495766..0000000 --- a/src/home/temphum/dht12.py +++ /dev/null @@ -1,22 +0,0 @@ -from .base import I2CTempHumSensor - - -class DHT12(I2CTempHumSensor): - i2c_addr = 0x5C - - def _measure(self): - raw = self.bus.read_i2c_block_data(self.i2c_addr, 0, 5) - if (raw[0] + raw[1] + raw[2] + raw[3]) & 0xff != raw[4]: - raise ValueError("checksum error") - return raw - - def temperature(self) -> float: - raw = self._measure() - temp = raw[2] + (raw[3] & 0x7f) * 0.1 - if raw[3] & 0x80: - temp *= -1 - return temp - - def humidity(self) -> float: - raw = self._measure() - return raw[0] + raw[1] * 0.1 diff --git a/src/home/temphum/i2c.py b/src/home/temphum/i2c.py new file mode 100644 index 0000000..7d8e2e3 --- /dev/null +++ b/src/home/temphum/i2c.py @@ -0,0 +1,52 @@ +import abc +import smbus + +from .base import BaseSensor, SensorType + + +class I2CSensor(BaseSensor, abc.ABC): + def __init__(self, bus: int): + super().__init__() + self.bus = smbus.SMBus(bus) + + +class DHT12(I2CSensor): + i2c_addr = 0x5C + + def _measure(self): + raw = self.bus.read_i2c_block_data(self.i2c_addr, 0, 5) + if (raw[0] + raw[1] + raw[2] + raw[3]) & 0xff != raw[4]: + raise ValueError("checksum error") + return raw + + def temperature(self) -> float: + raw = self._measure() + temp = raw[2] + (raw[3] & 0x7f) * 0.1 + if raw[3] & 0x80: + temp *= -1 + return temp + + def humidity(self) -> float: + raw = self._measure() + return raw[0] + raw[1] * 0.1 + + +class Si7021(I2CSensor): + i2c_addr = 0x40 + + def temperature(self) -> float: + raw = self.bus.read_i2c_block_data(self.i2c_addr, 0xE3, 2) + return 175.72 * (raw[0] << 8 | raw[1]) / 65536.0 - 46.85 + + def humidity(self) -> float: + raw = self.bus.read_i2c_block_data(self.i2c_addr, 0xE5, 2) + return 125.0 * (raw[0] << 8 | raw[1]) / 65536.0 - 6.0 + + +def create_sensor(type: SensorType, bus: int) -> BaseSensor: + if type == SensorType.Si7021: + return Si7021(bus) + elif type == SensorType.DHT12: + return DHT12(bus) + else: + raise ValueError('unexpected sensor type') diff --git a/src/home/temphum/si7021.py b/src/home/temphum/si7021.py deleted file mode 100644 index 6289e15..0000000 --- a/src/home/temphum/si7021.py +++ /dev/null @@ -1,13 +0,0 @@ -from .base import I2CTempHumSensor - - -class Si7021(I2CTempHumSensor): - i2c_addr = 0x40 - - def temperature(self) -> float: - raw = self.bus.read_i2c_block_data(self.i2c_addr, 0xE3, 2) - return 175.72 * (raw[0] << 8 | raw[1]) / 65536.0 - 46.85 - - def humidity(self) -> float: - raw = self.bus.read_i2c_block_data(self.i2c_addr, 0xE5, 2) - return 125.0 * (raw[0] << 8 | raw[1]) / 65536.0 - 6.0 diff --git a/src/home/util.py b/src/home/util.py index 93a9d8f..35505bc 100644 --- a/src/home/util.py +++ b/src/home/util.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import socket import time @@ -6,17 +8,57 @@ import traceback import logging import string import random +import re from enum import Enum from datetime import datetime from typing import Tuple, Optional, List from zlib import adler32 -Addr = Tuple[str, int] # network address type (host, port) - logger = logging.getLogger(__name__) +def validate_ipv4_or_hostname(address: str, raise_exception: bool = False) -> bool: + if re.match(r'^(\d{1,3}\.){3}\d{1,3}$', address): + parts = address.split('.') + if all(0 <= int(part) < 256 for part in parts): + return True + else: + if raise_exception: + raise ValueError(f"invalid IPv4 address: {address}") + return False + + if re.match(r'^[a-zA-Z0-9.-]+$', address): + return True + else: + if raise_exception: + raise ValueError(f"invalid hostname: {address}") + return False + + +class Addr: + host: str + port: int + + def __init__(self, host: str, port: int): + self.host = host + self.port = port + + @staticmethod + def fromstring(addr: str) -> Addr: + if addr.count(':') != 1: + raise ValueError('invalid host:port format') + + host, port = addr.split(':') + validate_ipv4_or_hostname(host, raise_exception=True) + + port = int(port) + if not 0 <= port <= 65535: + raise ValueError(f'invalid port {port}') + + return Addr(host, port) + + # https://stackoverflow.com/questions/312443/how-do-you-split-a-list-into-evenly-sized-chunks def chunks(lst, n): """Yield successive n-sized chunks from lst.""" @@ -45,21 +87,6 @@ def ipv4_valid(ip: str) -> bool: return False -def parse_addr(addr: str) -> Addr: - if addr.count(':') != 1: - raise ValueError('invalid host:port format') - - host, port = addr.split(':') - if not ipv4_valid(host): - raise ValueError('invalid ipv4 address') - - port = int(port) - if not 0 <= port <= 65535: - raise ValueError('invalid port') - - return host, port - - def strgen(n: int): return ''.join(random.choices(string.ascii_letters + string.digits, k=n)) @@ -193,4 +220,11 @@ def filesize_fmt(num, suffix="B") -> str: class HashableEnum(Enum): def hash(self) -> int: - return adler32(self.name.encode())
\ No newline at end of file + return adler32(self.name.encode()) + + +def next_tick_gen(freq): + t = time.time() + while True: + t += freq + yield max(t - time.time(), 0)
\ No newline at end of file diff --git a/src/inverter_bot.py b/src/inverter_bot.py index fd5acf3..ecf01fc 100755 --- a/src/inverter_bot.py +++ b/src/inverter_bot.py @@ -4,14 +4,16 @@ import re import datetime import json import itertools +import sys from inverterd import Format, InverterError from html import escape from typing import Optional, Tuple, Union from home.util import chunks -from home.config import config +from home.config import config, AppConfigUnit from home.telegram import bot +from home.telegram.config import TelegramBotConfig, TelegramUserListType from home.inverter import ( wrapper_instance as inverter, beautify_table, @@ -24,12 +26,17 @@ from home.inverter.types import ( ACMode, OutputSourcePriority ) -from home.database.inverter_time_formats import * +from home.database.inverter_time_formats import FormatDate from home.api.types import BotType from home.api import WebAPIClient from telegram import ReplyKeyboardMarkup, InlineKeyboardMarkup, InlineKeyboardButton -monitor: Optional[InverterMonitor] = None + +if __name__ != '__main__': + print(f'this script can not be imported as module', file=sys.stderr) + sys.exit(1) + + db = None LT = escape('<=') flags_map = { @@ -42,9 +49,56 @@ flags_map = { 'alarm_on_on_primary_source_interrupt': 'ALRM', 'fault_code_record': 'FTCR', } - logger = logging.getLogger(__name__) -config.load('inverter_bot') + + +class InverterBotConfig(AppConfigUnit, TelegramBotConfig): + NAME = 'inverter_bot' + + @staticmethod + def schema() -> Optional[dict]: + acmode_item_schema = { + 'thresholds': { + 'type': 'list', + 'required': True, + 'schema': { + 'type': 'list', + 'min': 40, + 'max': 60 + }, + }, + 'initial_current': {'type': 'integer'} + } + + return { + **super(TelegramBotConfig).schema(), + 'ac_mode': { + 'type': 'dict', + 'required': True, + 'schema': { + 'generator': acmode_item_schema, + 'utilities': acmode_item_schema + } + }, + 'monitor': { + 'type': 'dict', + 'required': True, + 'schema': { + 'vlow': {'type': 'integer', 'required': True}, + 'vcrit': {'type': 'integer', 'required': True}, + 'gen_currents': {'type': 'list', 'schema': {'type': 'integer'}, 'required': True}, + 'gen_raise_intervals': {'type': 'list', 'schema': {'type': 'integer'}, 'required': True}, + 'gen_cur30_v_limit': {'type': 'float', 'required': True}, + 'gen_cur20_v_limit': {'type': 'float', 'required': True}, + 'gen_cur10_v_limit': {'type': 'float', 'required': True}, + 'gen_floating_v': {'type': 'integer', 'required': True}, + 'gen_floating_time_max': {'type': 'integer', 'required': True} + } + } + } + + +config.load_app(InverterBotConfig) bot.initialize() bot.lang.ru( @@ -863,28 +917,27 @@ class InverterStore(bot.BotDatabase): self.commit() -if __name__ == '__main__': - inverter.init(host=config['inverter']['ip'], port=config['inverter']['port']) +inverter.init(host=config['inverter']['ip'], port=config['inverter']['port']) - bot.set_database(InverterStore()) - bot.enable_logging(BotType.INVERTER) +bot.set_database(InverterStore()) +bot.enable_logging(BotType.INVERTER) - bot.add_conversation(SettingsConversation(enable_back=True)) - bot.add_conversation(ConsumptionConversation(enable_back=True)) +bot.add_conversation(SettingsConversation(enable_back=True)) +bot.add_conversation(ConsumptionConversation(enable_back=True)) - monitor = InverterMonitor() - monitor.set_charging_event_handler(monitor_charging) - monitor.set_battery_event_handler(monitor_battery) - monitor.set_util_event_handler(monitor_util) - monitor.set_error_handler(monitor_error) - monitor.set_osp_need_change_callback(osp_change_cb) +monitor = InverterMonitor() +monitor.set_charging_event_handler(monitor_charging) +monitor.set_battery_event_handler(monitor_battery) +monitor.set_util_event_handler(monitor_util) +monitor.set_error_handler(monitor_error) +monitor.set_osp_need_change_callback(osp_change_cb) - setacmode(getacmode()) +setacmode(getacmode()) - if not config.get('monitor.disabled'): - logging.info('starting monitor') - monitor.start() +if not config.get('monitor.disabled'): + logging.info('starting monitor') + monitor.start() - bot.run() +bot.run() - monitor.stop() +monitor.stop() diff --git a/src/inverter_mqtt_receiver.py b/src/inverter_mqtt_receiver.py deleted file mode 100755 index d40647e..0000000 --- a/src/inverter_mqtt_receiver.py +++ /dev/null @@ -1,74 +0,0 @@ -#!/usr/bin/env python3 -import paho.mqtt.client as mqtt -import re - -from home.mqtt import MqttBase -from home.mqtt.payload.inverter import Status, Generation -from home.database import InverterDatabase -from home.config import config - - -class MqttReceiver(MqttBase): - def __init__(self): - super().__init__(clean_session=False) - self.database = InverterDatabase() - - def on_connect(self, client: mqtt.Client, userdata, flags, rc): - super().on_connect(client, userdata, flags, rc) - self._logger.info("subscribing to hk/#") - client.subscribe('hk/#', qos=1) - - def on_message(self, client: mqtt.Client, userdata, msg): - super().on_message(client, userdata, msg) - try: - match = re.match(r'(?:home|hk)/(\d+)/(status|gen)', msg.topic) - if not match: - return - - # FIXME string home_id must be supported - home_id, what = int(match.group(1)), match.group(2) - if what == 'gen': - gen = Generation.unpack(msg.payload) - self.database.add_generation(home_id, gen.time, gen.wh) - - elif what == 'status': - s = Status.unpack(msg.payload) - self.database.add_status(home_id, - client_time=s.time, - grid_voltage=int(s.grid_voltage*10), - grid_freq=int(s.grid_freq * 10), - ac_output_voltage=int(s.ac_output_voltage * 10), - ac_output_freq=int(s.ac_output_freq * 10), - ac_output_apparent_power=s.ac_output_apparent_power, - ac_output_active_power=s.ac_output_active_power, - output_load_percent=s.output_load_percent, - battery_voltage=int(s.battery_voltage * 10), - battery_voltage_scc=int(s.battery_voltage_scc * 10), - battery_voltage_scc2=int(s.battery_voltage_scc2 * 10), - battery_discharge_current=s.battery_discharge_current, - battery_charge_current=s.battery_charge_current, - battery_capacity=s.battery_capacity, - inverter_heat_sink_temp=s.inverter_heat_sink_temp, - mppt1_charger_temp=s.mppt1_charger_temp, - mppt2_charger_temp=s.mppt2_charger_temp, - pv1_input_power=s.pv1_input_power, - pv2_input_power=s.pv2_input_power, - pv1_input_voltage=int(s.pv1_input_voltage * 10), - pv2_input_voltage=int(s.pv2_input_voltage * 10), - mppt1_charger_status=s.mppt1_charger_status, - mppt2_charger_status=s.mppt2_charger_status, - battery_power_direction=s.battery_power_direction, - dc_ac_power_direction=s.dc_ac_power_direction, - line_power_direction=s.line_power_direction, - load_connected=s.load_connected) - - except Exception as e: - self._logger.exception(str(e)) - - -if __name__ == '__main__': - config.load('inverter_mqtt_receiver') - - server = MqttReceiver() - server.connect_and_loop() - diff --git a/src/inverter_mqtt_sender.py b/src/inverter_mqtt_sender.py deleted file mode 100755 index fb2a2d8..0000000 --- a/src/inverter_mqtt_sender.py +++ /dev/null @@ -1,72 +0,0 @@ -#!/usr/bin/env python3 -import time -import datetime -import json -import inverterd - -from home.config import config -from home.mqtt import MqttBase, poll_tick -from home.mqtt.payload.inverter import Status, Generation - - -class MqttClient(MqttBase): - def __init__(self): - super().__init__() - - self._home_id = config['mqtt']['home_id'] - - self._inverter = inverterd.Client() - self._inverter.connect() - self._inverter.format(inverterd.Format.SIMPLE_JSON) - - def poll_inverter(self): - freq = int(config['mqtt']['inverter']['poll_freq']) - gen_freq = int(config['mqtt']['inverter']['generation_poll_freq']) - - g = poll_tick(freq) - gen_prev = 0 - while True: - time.sleep(next(g)) - - # read status - now = time.time() - try: - raw = self._inverter.exec('get-status') - except inverterd.InverterError as e: - self._logger.error(f'inverter error: {str(e)}') - # TODO send to server - continue - - data = json.loads(raw)['data'] - status = Status(time=round(now), **data) # FIXME this will crash with 99% probability - - self._client.publish(f'hk/{self._home_id}/status', - payload=status.pack(), - qos=1) - - # read today's generation stat - now = time.time() - if gen_prev == 0 or now - gen_prev >= gen_freq: - gen_prev = now - today = datetime.date.today() - try: - raw = self._inverter.exec('get-day-generated', (today.year, today.month, today.day)) - except inverterd.InverterError as e: - self._logger.error(f'inverter error: {str(e)}') - # TODO send to server - continue - - data = json.loads(raw)['data'] - gen = Generation(time=round(now), wh=data['wh']) - self._client.publish(f'hk/{self._home_id}/gen', - payload=gen.pack(), - qos=1) - - -if __name__ == '__main__': - config.load('inverter_mqtt_sender') - - client = MqttClient() - client.configure_tls() - client.connect_and_loop(loop_forever=False) - client.poll_inverter()
\ No newline at end of file diff --git a/src/inverter_mqtt_util.py b/src/inverter_mqtt_util.py new file mode 100755 index 0000000..791bf80 --- /dev/null +++ b/src/inverter_mqtt_util.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python3 +from argparse import ArgumentParser +from home.config import config, app_config +from home.mqtt import MqttWrapper, MqttNode + + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument('mode', type=str, choices=('sender', 'receiver'), nargs=1) + + config.load_app('inverter_mqtt_util', parser=parser) + arg = parser.parse_args() + mode = arg.mode[0] + + mqtt = MqttWrapper(client_id=f'inverter_mqtt_{mode}', + clean_session=mode != 'receiver') + node = MqttNode(node_id='inverter') + module_kwargs = {} + if mode == 'sender': + module_kwargs['status_poll_freq'] = int(app_config['poll_freq']) + module_kwargs['generation_poll_freq'] = int(app_config['generation_poll_freq']) + node.load_module('inverter', **module_kwargs) + mqtt.add_node(node) + + mqtt.connect_and_loop() diff --git a/src/ipcam_server.py b/src/ipcam_server.py index 2c4915d..a54cd35 100755 --- a/src/ipcam_server.py +++ b/src/ipcam_server.py @@ -556,7 +556,7 @@ logger = logging.getLogger(__name__) # -------------------- if __name__ == '__main__': - config.load('ipcam_server') + config.load_app('ipcam_server') open_database() diff --git a/src/mqtt_node_util.py b/src/mqtt_node_util.py new file mode 100755 index 0000000..ce954ae --- /dev/null +++ b/src/mqtt_node_util.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +import os.path + +from time import sleep +from typing import Optional +from argparse import ArgumentParser, ArgumentError + +from home.config import config +from home.mqtt import MqttNode, MqttWrapper, get_mqtt_modules +from home.mqtt import MqttNodesConfig + +mqtt_node: Optional[MqttNode] = None +mqtt: Optional[MqttWrapper] = None + + +if __name__ == '__main__': + nodes_config = MqttNodesConfig() + + parser = ArgumentParser() + parser.add_argument('--node-id', type=str, required=True, choices=nodes_config.get_nodes(only_names=True)) + parser.add_argument('--modules', type=str, choices=get_mqtt_modules(), nargs='*', + help='mqtt modules to include') + parser.add_argument('--switch-relay', choices=[0, 1], type=int, + help='send relay state') + parser.add_argument('--push-ota', type=str, metavar='OTA_FILENAME', + help='push OTA, receives path to firmware.bin') + + config.load_app(parser=parser, no_config=True) + arg = parser.parse_args() + + if arg.switch_relay is not None and 'relay' not in arg.modules: + raise ArgumentError(None, '--relay is only allowed when \'relay\' module included in --modules') + + mqtt = MqttWrapper(randomize_client_id=True, + client_id='mqtt_node_util') + mqtt_node = MqttNode(node_id=arg.node_id, + node_secret=nodes_config.get_node(arg.node_id)['password']) + + mqtt.add_node(mqtt_node) + + # must-have modules + ota_module = mqtt_node.load_module('ota') + mqtt_node.load_module('diagnostics') + + if arg.modules: + for m in arg.modules: + module_instance = mqtt_node.load_module(m) + if m == 'relay' and arg.switch_relay is not None: + module_instance.switchpower(arg.switch_relay == 1) + + try: + mqtt.connect_and_loop(loop_forever=False) + + if arg.push_ota: + if not os.path.exists(arg.push_ota): + raise OSError(f'--push-ota: file \"{arg.push_ota}\" does not exists') + ota_module.push_ota(arg.push_ota, 1) + + while True: + sleep(0.1) + + except KeyboardInterrupt: + mqtt.disconnect() diff --git a/src/openwrt_log_analyzer.py b/src/openwrt_log_analyzer.py index d31c3bf..35b755f 100755 --- a/src/openwrt_log_analyzer.py +++ b/src/openwrt_log_analyzer.py @@ -54,7 +54,7 @@ def main(mac: str, if __name__ == '__main__': - config.load('openwrt_log_analyzer') + config.load_app('openwrt_log_analyzer') for ap in config['openwrt_log_analyzer']['aps']: state_file = config['simple_state']['file'] state_file = state_file.replace('.txt', f'-{ap}.txt') diff --git a/src/openwrt_logger.py b/src/openwrt_logger.py index 3b19de2..97fe7a9 100755 --- a/src/openwrt_logger.py +++ b/src/openwrt_logger.py @@ -46,7 +46,7 @@ if __name__ == '__main__': parser.add_argument('--access-point', type=int, required=True, help='access point number') - arg = config.load('openwrt_logger', parser=parser) + arg = config.load_app('openwrt_logger', parser=parser) state = SimpleState(file=config['simple_state']['file'].replace('{ap}', str(arg.access_point)), default={'seek': 0, 'size': 0}) diff --git a/src/pio_ini.py b/src/pio_ini.py index 19dd707..920c3e5 100755 --- a/src/pio_ini.py +++ b/src/pio_ini.py @@ -54,12 +54,17 @@ def bsd_parser(product_config: dict, arg_kwargs['type'] = int elif kwargs['type'] == 'int': arg_kwargs['type'] = int + elif kwargs['type'] == 'bool': + arg_kwargs['action'] = 'store_true' + arg_kwargs['required'] = False else: raise TypeError(f'unsupported type {kwargs["type"]} for define {define_name}') else: arg_kwargs['action'] = 'store_true' - parser.add_argument(f'--{define_name}', required=True, **arg_kwargs) + if 'required' not in arg_kwargs: + arg_kwargs['required'] = True + parser.add_argument(f'--{define_name}', **arg_kwargs) bsd_walk(product_config, f) @@ -76,6 +81,9 @@ def bsd_get(product_config: dict, enums.append(f'CONFIG_{define_name}') defines[f'CONFIG_{define_name}'] = f'HOMEKIT_{attr_value.upper()}' return + if kwargs['type'] == 'bool': + defines[f'CONFIG_{define_name}'] = True + return defines[f'CONFIG_{define_name}'] = str(attr_value) bsd_walk(product_config, f) return defines, enums diff --git a/src/polaris_kettle_bot.py b/src/polaris_kettle_bot.py index 088707d..80baef3 100755 --- a/src/polaris_kettle_bot.py +++ b/src/polaris_kettle_bot.py @@ -10,7 +10,7 @@ import paho.mqtt.client as mqtt from home.telegram import bot from home.api.types import BotType -from home.mqtt import MqttBase +from home.mqtt import Mqtt from home.config import config from home.util import chunks from syncleo import ( @@ -41,7 +41,7 @@ from telegram.ext import ( ) logger = logging.getLogger(__name__) -config.load('polaris_kettle_bot') +config.load_app('polaris_kettle_bot') primary_choices = (70, 80, 90, 100) all_choices = range( @@ -204,7 +204,7 @@ class KettleInfo: class KettleController(threading.Thread, - MqttBase, + Mqtt, DeviceListener, IncomingMessageListener, KettleInfoListener, @@ -224,7 +224,7 @@ class KettleController(threading.Thread, def __init__(self): # basic setup - MqttBase.__init__(self, clean_session=False) + Mqtt.__init__(self, clean_session=False) threading.Thread.__init__(self) self._logger = logging.getLogger(self.__class__.__name__) diff --git a/src/polaris_kettle_util.py b/src/polaris_kettle_util.py index 81326dd..12c4388 100755 --- a/src/polaris_kettle_util.py +++ b/src/polaris_kettle_util.py @@ -8,7 +8,7 @@ import paho.mqtt.client as mqtt from typing import Optional from argparse import ArgumentParser from queue import SimpleQueue -from home.mqtt import MqttBase +from home.mqtt import Mqtt from home.config import config from syncleo import ( Kettle, @@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) control_tasks = SimpleQueue() -class MqttServer(MqttBase): +class MqttServer(Mqtt): def __init__(self): super().__init__(clean_session=False) @@ -75,7 +75,7 @@ def main(): parser.add_argument('-t', '--temperature', dest='temp', type=int, default=tempmax, choices=range(tempmin, tempmax+tempstep, tempstep)) - arg = config.load('polaris_kettle_util', use_cli=True, parser=parser) + arg = config.load_app('polaris_kettle_util', use_cli=True, parser=parser) if arg.mode == 'mqtt': server = MqttServer() diff --git a/src/pump_bot.py b/src/pump_bot.py index de925db..25f06fd 100755 --- a/src/pump_bot.py +++ b/src/pump_bot.py @@ -2,14 +2,34 @@ from enum import Enum from typing import Optional from telegram import ReplyKeyboardMarkup, User +from time import time +from datetime import datetime -from home.config import config +from home.config import config, is_development_mode from home.telegram import bot from home.telegram._botutil import user_any_name from home.relay.sunxi_h3_client import RelayClient from home.api.types import BotType +from home.mqtt import MqttNode, MqttWrapper, MqttPayload +from home.mqtt.module.relay import MqttPowerStatusPayload, MqttRelayModule +from home.mqtt.module.temphum import MqttTemphumDataPayload +from home.mqtt.module.diagnostics import InitialDiagnosticsPayload, DiagnosticsPayload -config.load('pump_bot') + +config.load_app('pump_bot') + +mqtt: Optional[MqttWrapper] = None +mqtt_node: Optional[MqttNode] = None +mqtt_relay_module: Optional[MqttRelayModule] = None +time_format = '%d.%m.%Y, %H:%M:%S' + +watering_mcu_status = { + 'last_time': 0, + 'last_boot_time': 0, + 'relay_opened': False, + 'ambient_temp': 0.0, + 'ambient_rh': 0.0, +} bot.initialize() bot.lang.ru( @@ -18,17 +38,27 @@ bot.lang.ru( enable="Включить", enable_silently="Включить тихо", - enabled="Включен ✅", + enabled="Насос включен ✅", disable="Выключить", disable_silently="Выключить тихо", - disabled="Выключен ❌", + disabled="Насос выключен ❌", + + start_watering="Включить полив", + stop_watering="Отключить полив", + + status="Статус насоса", + watering_status="Статус полива", - status="Статус", done="Готово 👌", + sent="Команда отправлена", + user_action_notification='Пользователь <a href="tg://user?id=%d">%s</a> <b>%s</b> насос.', + user_watering_notification='Пользователь <a href="tg://user?id=%d">%s</a> <b>%s</b> полив.', user_action_on="включил", user_action_off="выключил", + user_action_watering_on="включил", + user_action_watering_off="выключил", ) bot.lang.en( start_message="Select command on the keyboard", @@ -36,23 +66,35 @@ bot.lang.en( enable="Turn ON", enable_silently="Turn ON silently", - enabled="Turned ON ✅", + enabled="The pump is turned ON ✅", disable="Turn OFF", disable_silently="Turn OFF silently", - disabled="Turned OFF ❌", + disabled="The pump is turned OFF ❌", + + start_watering="Start watering", + stop_watering="Stop watering", + + status="Pump status", + watering_status="Watering status", - status="Status", done="Done 👌", + sent="Request sent", + user_action_notification='User <a href="tg://user?id=%d">%s</a> turned the pump <b>%s</b>.', + user_watering_notification='User <a href="tg://user?id=%d">%s</a> <b>%s</b> the watering.', user_action_on="ON", user_action_off="OFF", + user_action_watering_on="started", + user_action_watering_off="stopped", ) class UserAction(Enum): ON = 'on' OFF = 'off' + WATERING_ON = 'watering_on' + WATERING_OFF = 'watering_off' def get_relay() -> RelayClient: @@ -75,11 +117,24 @@ def off(ctx: bot.Context, silent=False) -> None: notify(ctx.user, UserAction.OFF) +def watering_on(ctx: bot.Context) -> None: + mqtt_relay_module.switchpower(True, config.get('mqtt_water_relay.secret')) + ctx.reply(ctx.lang('sent')) + notify(ctx.user, UserAction.WATERING_ON) + + +def watering_off(ctx: bot.Context) -> None: + mqtt_relay_module.switchpower(False, config.get('mqtt_water_relay.secret')) + ctx.reply(ctx.lang('sent')) + notify(ctx.user, UserAction.WATERING_OFF) + + def notify(user: User, action: UserAction) -> None: + notification_key = 'user_watering_notification' if action in (UserAction.WATERING_ON, UserAction.WATERING_OFF) else 'user_action_notification' def text_getter(lang: str): action_name = bot.lang.get(f'user_action_{action.value}', lang) user_name = user_any_name(user) - return 'ℹ ' + bot.lang.get('user_action_notification', lang, + return 'ℹ ' + bot.lang.get(notification_key, lang, user.id, user_name, action_name) bot.notify_all(text_getter, exclude=(user.id,)) @@ -100,6 +155,16 @@ def disable_handler(ctx: bot.Context) -> None: off(ctx) +@bot.handler(message='start_watering') +def start_watering(ctx: bot.Context) -> None: + watering_on(ctx) + + +@bot.handler(message='stop_watering') +def stop_watering(ctx: bot.Context) -> None: + watering_off(ctx) + + @bot.handler(message='disable_silently') def disable_s_handler(ctx: bot.Context) -> None: off(ctx, True) @@ -112,20 +177,79 @@ def status(ctx: bot.Context) -> None: ) +def _get_timestamp_as_string(timestamp: int) -> str: + if timestamp != 0: + return datetime.fromtimestamp(timestamp).strftime(time_format) + else: + return 'unknown' + + +@bot.handler(message='watering_status') +def watering_status(ctx: bot.Context) -> None: + buf = '' + if 0 < watering_mcu_status["last_time"] < time()-1800: + buf += '<b>WARNING! long time no reports from mcu! maybe something\'s wrong</b>\n' + buf += f'last report time: <b>{_get_timestamp_as_string(watering_mcu_status["last_time"])}</b>\n' + if watering_mcu_status["last_boot_time"] != 0: + buf += f'boot time: <b>{_get_timestamp_as_string(watering_mcu_status["last_boot_time"])}</b>\n' + buf += 'relay opened: <b>' + ('yes' if watering_mcu_status['relay_opened'] else 'no') + '</b>\n' + buf += f'ambient temp & humidity: <b>{watering_mcu_status["ambient_temp"]} °C, {watering_mcu_status["ambient_rh"]}%</b>' + ctx.reply(buf) + + @bot.defaultreplymarkup def markup(ctx: Optional[bot.Context]) -> Optional[ReplyKeyboardMarkup]: - buttons = [ - [ctx.lang('enable'), ctx.lang('disable')], - ] - + buttons = [] if ctx.user_id in config['bot']['silent_users']: buttons.append([ctx.lang('enable_silently'), ctx.lang('disable_silently')]) - - buttons.append([ctx.lang('status')]) + buttons.append([ctx.lang('enable'), ctx.lang('disable'), ctx.lang('status')],) + buttons.append([ctx.lang('start_watering'), ctx.lang('stop_watering'), ctx.lang('watering_status')]) return ReplyKeyboardMarkup(buttons, one_time_keyboard=False) +def mqtt_payload_callback(mqtt_node: MqttNode, payload: MqttPayload): + global watering_mcu_status + + types_the_node_can_send = ( + InitialDiagnosticsPayload, + DiagnosticsPayload, + MqttTemphumDataPayload, + MqttPowerStatusPayload + ) + for cl in types_the_node_can_send: + if isinstance(payload, cl): + watering_mcu_status['last_time'] = int(time()) + break + + if isinstance(payload, InitialDiagnosticsPayload): + watering_mcu_status['last_boot_time'] = int(time()) + + elif isinstance(payload, MqttTemphumDataPayload): + watering_mcu_status['ambient_temp'] = payload.temp + watering_mcu_status['ambient_rh'] = payload.rh + + elif isinstance(payload, MqttPowerStatusPayload): + watering_mcu_status['relay_opened'] = payload.opened + + if __name__ == '__main__': + mqtt = MqttWrapper() + mqtt_node = MqttNode(node_id=config.get('mqtt_water_relay.node_id')) + if is_development_mode(): + mqtt_node.load_module('diagnostics') + + mqtt_node.load_module('temphum') + mqtt_relay_module = mqtt_node.load_module('relay') + + mqtt_node.add_payload_callback(mqtt_payload_callback) + + mqtt.connect_and_loop(loop_forever=False) + bot.enable_logging(BotType.PUMP) bot.run() + + try: + mqtt.disconnect() + except: + pass diff --git a/src/pump_mqtt_bot.py b/src/pump_mqtt_bot.py index d3b6de4..4036d3a 100755 --- a/src/pump_mqtt_bot.py +++ b/src/pump_mqtt_bot.py @@ -8,13 +8,12 @@ from telegram import ReplyKeyboardMarkup, User from home.config import config from home.telegram import bot from home.telegram._botutil import user_any_name -from home.mqtt.esp import MqttEspDevice -from home.mqtt import MqttRelay, MqttRelayState -from home.mqtt.payload import MqttPayload -from home.mqtt.payload.relay import InitialDiagnosticsPayload, DiagnosticsPayload +from home.mqtt import MqttNode, MqttPayload +from home.mqtt.module.relay import MqttRelayState +from home.mqtt.module.diagnostics import InitialDiagnosticsPayload, DiagnosticsPayload -config.load('pump_mqtt_bot') +config.load_app('pump_mqtt_bot') bot.initialize() bot.lang.ru( @@ -70,7 +69,7 @@ bot.lang.en( ) -mqtt_relay: Optional[MqttRelay] = None +mqtt: Optional[MqttNode] = None relay_state = MqttRelayState() @@ -99,14 +98,14 @@ def notify(user: User, action: UserAction) -> None: @bot.handler(message='enable') def enable_handler(ctx: bot.Context) -> None: - mqtt_relay.set_power(config['mqtt']['home_id'], True) + mqtt.set_power(config['mqtt']['home_id'], True) ctx.reply(ctx.lang('done')) notify(ctx.user, UserAction.ON) @bot.handler(message='disable') def disable_handler(ctx: bot.Context) -> None: - mqtt_relay.set_power(config['mqtt']['home_id'], False) + mqtt.set_power(config['mqtt']['home_id'], False) ctx.reply(ctx.lang('done')) notify(ctx.user, UserAction.OFF) @@ -157,13 +156,12 @@ def markup(ctx: Optional[bot.Context]) -> Optional[ReplyKeyboardMarkup]: if __name__ == '__main__': - mqtt_relay = MqttRelay(devices=MqttEspDevice(id=config['mqtt']['home_id'], - secret=config['mqtt']['home_secret'])) - mqtt_relay.set_message_callback(on_mqtt_message) - mqtt_relay.configure_tls() - mqtt_relay.connect_and_loop(loop_forever=False) + mqtt = MqttRelay(devices=MqttEspDevice(id=config['mqtt']['home_id'], + secret=config['mqtt']['home_secret'])) + mqtt.set_message_callback(on_mqtt_message) + mqtt.connect_and_loop(loop_forever=False) # bot.enable_logging(BotType.PUMP_MQTT) bot.run(start_handler=start) - mqtt_relay.disconnect() + mqtt.disconnect() diff --git a/src/relay_mqtt_bot.py b/src/relay_mqtt_bot.py index ebbff82..9de8c7e 100755 --- a/src/relay_mqtt_bot.py +++ b/src/relay_mqtt_bot.py @@ -1,18 +1,62 @@ #!/usr/bin/env python3 +import sys + from enum import Enum -from typing import Optional +from typing import Optional, Union from telegram import ReplyKeyboardMarkup from functools import partial -from home.config import config +from home.config import config, AppConfigUnit, Translation from home.telegram import bot -from home.mqtt import MqttRelay, MqttRelayState -from home.mqtt.esp import MqttEspDevice -from home.mqtt.payload import MqttPayload -from home.mqtt.payload.relay import InitialDiagnosticsPayload, DiagnosticsPayload +from home.telegram.config import TelegramBotConfig +from home.mqtt import MqttPayload, MqttNode, MqttWrapper, MqttModule +from home.mqtt import MqttNodesConfig +from home.mqtt.module.relay import MqttRelayModule, MqttRelayState +from home.mqtt.module.diagnostics import InitialDiagnosticsPayload, DiagnosticsPayload + + +if __name__ != '__main__': + print(f'this script can not be imported as module', file=sys.stderr) + sys.exit(1) + + +mqtt_nodes_config = MqttNodesConfig() + + +class RelayMqttBotConfig(AppConfigUnit, TelegramBotConfig): + NAME = 'relay_mqtt_bot' + + _strings: Translation + + def __init__(self): + super().__init__() + self._strings = Translation('mqtt_nodes') + + @staticmethod + def schema() -> Optional[dict]: + return { + **super(TelegramBotConfig).schema(), + 'relay_nodes': { + 'type': 'list', + 'required': True, + 'schema': { + 'type': 'string' + } + }, + } + @staticmethod + def custom_validator(data): + relay_node_names = mqtt_nodes_config.get_nodes(filters=('relay',), only_names=True) + for node in data['relay_nodes']: + if node not in relay_node_names: + raise ValueError(f'unknown relay node "{node}"') -config.load('relay_mqtt_bot') + def get_relay_name_translated(self, lang: str, relay_name: str) -> str: + return self._strings.get(lang)[relay_name]['relay'] + + +config.load_app(RelayMqttBotConfig) bot.initialize() bot.lang.ru( @@ -34,7 +78,10 @@ status_emoji = { 'on': '✅', 'off': '❌' } -mqtt_relay: Optional[MqttRelay] = None + + +mqtt: MqttWrapper +relay_nodes: dict[str, Union[MqttRelayModule, MqttModule]] = {} relay_states: dict[str, MqttRelayState] = {} @@ -43,70 +90,75 @@ class UserAction(Enum): OFF = 'off' -def on_mqtt_message(home_id, message: MqttPayload): +def on_mqtt_message(node: MqttNode, + message: MqttPayload): if isinstance(message, InitialDiagnosticsPayload) or isinstance(message, DiagnosticsPayload): kwargs = dict(rssi=message.rssi, enabled=message.flags.state) if isinstance(message, InitialDiagnosticsPayload): kwargs['fw_version'] = message.fw_version - if home_id not in relay_states: - relay_states[home_id] = MqttRelayState() - relay_states[home_id].update(**kwargs) + if node.id not in relay_states: + relay_states[node.id] = MqttRelayState() + relay_states[node.id].update(**kwargs) -def enable_handler(home_id: str, ctx: bot.Context) -> None: - mqtt_relay.set_power(home_id, True) - ctx.reply(ctx.lang('done')) +async def enable_handler(node_id: str, ctx: bot.Context) -> None: + relay_nodes[node_id].switchpower(True) + await ctx.reply(ctx.lang('done')) -def disable_handler(home_id: str, ctx: bot.Context) -> None: - mqtt_relay.set_power(home_id, False) - ctx.reply(ctx.lang('done')) +async def disable_handler(node_id: str, ctx: bot.Context) -> None: + relay_nodes[node_id].switchpower(False) + await ctx.reply(ctx.lang('done')) -def start(ctx: bot.Context) -> None: - ctx.reply(ctx.lang('start_message')) +async def start(ctx: bot.Context) -> None: + await ctx.reply(ctx.lang('start_message')) @bot.exceptionhandler -def exception_handler(e: Exception, ctx: bot.Context) -> bool: +async def exception_handler(e: Exception, ctx: bot.Context) -> bool: return False @bot.defaultreplymarkup def markup(ctx: Optional[bot.Context]) -> Optional[ReplyKeyboardMarkup]: buttons = [] - for device_id, data in config['relays'].items(): - labels = data['labels'] - type_emoji = type_emojis[data['type']] - row = [f'{type_emoji}{status_emoji[i.value]} {labels[ctx.user_lang]}' + for node_id in config.app_config['relay_nodes']: + node_data = mqtt_nodes_config.get_node(node_id) + type_emoji = type_emojis[node_data['relay']['device_type']] + row = [f'{type_emoji}{status_emoji[i.value]} {config.app_config.get_relay_name_translated(ctx.user_lang, node_id)}' for i in UserAction] buttons.append(row) return ReplyKeyboardMarkup(buttons, one_time_keyboard=False) -if __name__ == '__main__': - devices = [] - for device_id, data in config['relays'].items(): - devices.append(MqttEspDevice(id=device_id, - secret=data['secret'])) - labels = data['labels'] - bot.lang.ru(**{device_id: labels['ru']}) - bot.lang.en(**{device_id: labels['en']}) - - type_emoji = type_emojis[data['type']] - - for action in UserAction: - messages = [] - for _lang, _label in labels.items(): - messages.append(f'{type_emoji}{status_emoji[action.value]} {labels[_lang]}') - bot.handler(texts=messages)(partial(enable_handler if action == UserAction.ON else disable_handler, device_id)) - - mqtt_relay = MqttRelay(devices=devices) - mqtt_relay.set_message_callback(on_mqtt_message) - mqtt_relay.configure_tls() - mqtt_relay.connect_and_loop(loop_forever=False) - - # bot.enable_logging(BotType.RELAY_MQTT) - bot.run(start_handler=start) - - mqtt_relay.disconnect() +devices = [] +mqtt = MqttWrapper(client_id='relay_mqtt_bot') +for node_id in config.app_config['relay_nodes']: + node_data = mqtt_nodes_config.get_node(node_id) + mqtt_node = MqttNode(node_id=node_id, + node_secret=node_data['password']) + module_kwargs = {} + try: + if node_data['relay']['legacy_topics']: + module_kwargs['legacy_topics'] = True + except KeyError: + pass + relay_nodes[node_id] = mqtt_node.load_module('relay', **module_kwargs) + mqtt_node.add_payload_callback(on_mqtt_message) + mqtt.add_node(mqtt_node) + + type_emoji = type_emojis[node_data['relay']['device_type']] + + for action in UserAction: + messages = [] + for _lang in Translation.LANGUAGES: + _label = config.app_config.get_relay_name_translated(_lang, node_id) + messages.append(f'{type_emoji}{status_emoji[action.value]} {_label}') + bot.handler(texts=messages)(partial(enable_handler if action == UserAction.ON else disable_handler, node_id)) + +mqtt.connect_and_loop(loop_forever=False) + +bot.run(start_handler=start) + +mqtt.disconnect() diff --git a/src/relay_mqtt_http_proxy.py b/src/relay_mqtt_http_proxy.py index 098facc..2bc2c4a 100755 --- a/src/relay_mqtt_http_proxy.py +++ b/src/relay_mqtt_http_proxy.py @@ -1,17 +1,19 @@ #!/usr/bin/env python3 from home import http from home.config import config -from home.mqtt import MqttRelay, MqttRelayState -from home.mqtt.esp import MqttEspDevice -from home.mqtt.payload import MqttPayload -from home.mqtt.payload.relay import InitialDiagnosticsPayload, DiagnosticsPayload -from typing import Optional +from home.mqtt import MqttPayload, MqttWrapper, MqttNode, MqttModule +from home.mqtt.module.relay import MqttRelayState, MqttRelayModule +from home.mqtt.module.diagnostics import InitialDiagnosticsPayload, DiagnosticsPayload +from typing import Optional, Union -mqtt_relay: Optional[MqttRelay] = None +mqtt: Optional[MqttWrapper] = None +mqtt_nodes: dict[str, MqttNode] = {} +relay_modules: dict[str, Union[MqttRelayModule, MqttModule]] = {} relay_states: dict[str, MqttRelayState] = {} -def on_mqtt_message(device_id, message: MqttPayload): +def on_mqtt_message(node: MqttNode, + message: MqttPayload): if isinstance(message, InitialDiagnosticsPayload) or isinstance(message, DiagnosticsPayload): kwargs = dict(rssi=message.rssi, enabled=message.flags.state) if device_id not in relay_states: @@ -29,17 +31,22 @@ class RelayMqttHttpProxy(http.HTTPServer): async def _relay_on_off(self, enable: Optional[bool], req: http.Request): - device_id = req.match_info['id'] - device_secret = req.query['secret'] + node_id = req.match_info['id'] + node_secret = req.query['secret'] + + node = mqtt_nodes[node_id] + relay_module = relay_modules[node_id] if enable is None: - if device_id in relay_states and relay_states[device_id].ever_updated: - cur_state = relay_states[device_id].enabled + if node_id in relay_states and relay_states[node_id].ever_updated: + cur_state = relay_states[node_id].enabled else: cur_state = False enable = not cur_state - mqtt_relay.set_power(device_id, enable, device_secret) + if not node.secret: + node.secret = node_secret + relay_module.switchpower(enable) return self.ok() async def relay_on(self, req: http.Request): @@ -53,15 +60,21 @@ class RelayMqttHttpProxy(http.HTTPServer): if __name__ == '__main__': - config.load('relay_mqtt_http_proxy') + config.load_app('relay_mqtt_http_proxy') + + mqtt = MqttWrapper() + for device_id, data in config['relays'].items(): + mqtt_node = MqttNode(node_id=device_id) + relay_modules[device_id] = mqtt_node.load_module('relay') + mqtt_nodes[device_id] = mqtt_node + mqtt_node.add_payload_callback(on_mqtt_message) + mqtt.add_node(mqtt_node) + mqtt_node.add_payload_callback(on_mqtt_message) - mqtt_relay = MqttRelay(devices=[MqttEspDevice(id=device_id) for device_id in config.get('relay.devices')]) - mqtt_relay.configure_tls() - mqtt_relay.set_message_callback(on_mqtt_message) - mqtt_relay.connect_and_loop(loop_forever=False) + mqtt.connect_and_loop(loop_forever=False) proxy = RelayMqttHttpProxy(config.get_addr('server.listen')) try: proxy.run() except KeyboardInterrupt: - mqtt_relay.disconnect() + mqtt.disconnect() diff --git a/src/sensors_bot.py b/src/sensors_bot.py index dc081b0..152dd24 100755 --- a/src/sensors_bot.py +++ b/src/sensors_bot.py @@ -23,7 +23,7 @@ from home.api.types import ( TemperatureSensorLocation ) -config.load('sensors_bot') +config.load_app('sensors_bot') bot.initialize() bot.lang.ru( diff --git a/src/sensors_mqtt_sender.py b/src/sensors_mqtt_sender.py deleted file mode 100755 index 87a28ca..0000000 --- a/src/sensors_mqtt_sender.py +++ /dev/null @@ -1,58 +0,0 @@ -#!/usr/bin/env python3 -import time -import json - -from home.util import parse_addr, MySimpleSocketClient -from home.mqtt import MqttBase, poll_tick -from home.mqtt.payload.sensors import Temperature -from home.config import config - - -class MqttClient(MqttBase): - def __init__(self): - super().__init__(self) - self._home_id = config['mqtt']['home_id'] - - def poll(self): - freq = int(config['mqtt']['sensors']['poll_freq']) - self._logger.debug(f'freq={freq}') - - g = poll_tick(freq) - while True: - time.sleep(next(g)) - for k, v in config['mqtt']['sensors']['si7021'].items(): - host, port = parse_addr(v['addr']) - self.publish_si7021(host, port, k) - - def publish_si7021(self, host: str, port: int, name: str): - self._logger.debug(f"publish_si7021/{name}: {host}:{port}") - - try: - now = time.time() - socket = MySimpleSocketClient(host, port) - - socket.write('read') - response = json.loads(socket.read().strip()) - - temp = response['temp'] - humidity = response['humidity'] - - self._logger.debug(f'publish_si7021/{name}: temp={temp} humidity={humidity}') - - pld = Temperature(time=round(now), - temp=temp, - rh=humidity) - self._client.publish(f'hk/{self._home_id}/si7021/{name}', - payload=pld.pack(), - qos=1) - except Exception as e: - self._logger.exception(e) - - -if __name__ == '__main__': - config.load('sensors_mqtt_sender') - - client = MqttClient() - client.configure_tls() - client.connect_and_loop(loop_forever=False) - client.poll() diff --git a/src/sound_bot.py b/src/sound_bot.py index 186337a..32371bd 100755 --- a/src/sound_bot.py +++ b/src/sound_bot.py @@ -14,7 +14,7 @@ from home.api.types import SoundSensorLocation, BotType from home.api.errors import ApiResponseError from home.media import SoundNodeClient, SoundRecordClient, SoundRecordFile, CameraNodeClient from home.soundsensor import SoundSensorServerGuardClient -from home.util import parse_addr, chunks, filesize_fmt +from home.util import Addr, chunks, filesize_fmt from home.telegram import bot @@ -23,11 +23,11 @@ from telegram import ReplyKeyboardMarkup, InlineKeyboardMarkup, InlineKeyboardBu from PIL import Image -config.load('sound_bot') +config.load_app('sound_bot') nodes = {} for nodename, nodecfg in config['nodes'].items(): - nodes[nodename] = parse_addr(nodecfg['addr']) + nodes[nodename] = Addr.fromstring(nodecfg['addr']) bot.initialize() bot.lang.ru( @@ -142,13 +142,13 @@ cam_client_links: Dict[str, CameraNodeClient] = {} def node_client(node: str) -> SoundNodeClient: if node not in node_client_links: - node_client_links[node] = SoundNodeClient(parse_addr(config['nodes'][node]['addr'])) + node_client_links[node] = SoundNodeClient(Addr.fromstring(config['nodes'][node]['addr'])) return node_client_links[node] def camera_client(cam: str) -> CameraNodeClient: if cam not in node_client_links: - cam_client_links[cam] = CameraNodeClient(parse_addr(config['cameras'][cam]['addr'])) + cam_client_links[cam] = CameraNodeClient(Addr.fromstring(config['cameras'][cam]['addr'])) return cam_client_links[cam] @@ -188,7 +188,7 @@ def manual_recording_allowed(user_id: int) -> bool: def guard_client() -> SoundSensorServerGuardClient: - return SoundSensorServerGuardClient(parse_addr(config['bot']['guard_server'])) + return SoundSensorServerGuardClient(Addr.fromstring(config['bot']['guard_server'])) # message renderers diff --git a/src/sound_node.py b/src/sound_node.py index 9d53362..b0b4a67 100755 --- a/src/sound_node.py +++ b/src/sound_node.py @@ -77,7 +77,7 @@ if __name__ == '__main__': if not os.getegid() == 0: raise RuntimeError("Must be run as root.") - config.load('sound_node') + config.load_app('sound_node') storage = SoundRecordStorage(config['node']['storage']) diff --git a/src/sound_sensor_node.py b/src/sound_sensor_node.py index d9a8999..404fdf4 100755 --- a/src/sound_sensor_node.py +++ b/src/sound_sensor_node.py @@ -4,7 +4,7 @@ import os import sys from home.config import config -from home.util import parse_addr +from home.util import Addr from home.soundsensor import SoundSensorNode logger = logging.getLogger(__name__) @@ -14,14 +14,14 @@ if __name__ == '__main__': if not os.getegid() == 0: sys.exit('Must be run as root.') - config.load('sound_sensor_node') + config.load_app('sound_sensor_node') kwargs = {} if 'delay' in config['node']: kwargs['delay'] = config['node']['delay'] if 'server_addr' in config['node']: - server_addr = parse_addr(config['node']['server_addr']) + server_addr = Addr.fromstring(config['node']['server_addr']) else: server_addr = None diff --git a/src/sound_sensor_server.py b/src/sound_sensor_server.py index aa62608..b660210 100755 --- a/src/sound_sensor_server.py +++ b/src/sound_sensor_server.py @@ -6,7 +6,7 @@ from time import sleep from typing import Optional, List, Dict, Tuple from functools import partial from home.config import config -from home.util import parse_addr +from home.util import Addr from home.api import WebAPIClient, RequestParams from home.api.types import SoundSensorLocation from home.soundsensor import SoundSensorServer, SoundSensorHitHandler @@ -159,7 +159,7 @@ def api_error_handler(exc, name, req: RequestParams): if __name__ == '__main__': - config.load('sound_sensor_server') + config.load_app('sound_sensor_server') hc = HitCounter() api = WebAPIClient(timeout=(10, 60)) @@ -172,12 +172,12 @@ if __name__ == '__main__': sound_nodes = {} if 'sound_nodes' in config: for nodename, nodecfg in config['sound_nodes'].items(): - sound_nodes[nodename] = parse_addr(nodecfg['addr']) + sound_nodes[nodename] = Addr.fromstring(nodecfg['addr']) camera_nodes = {} if 'camera_nodes' in config: for nodename, nodecfg in config['camera_nodes'].items(): - camera_nodes[nodename] = parse_addr(nodecfg['addr']) + camera_nodes[nodename] = Addr.fromstring(nodecfg['addr']) if sound_nodes: record_clients[MediaNodeType.SOUND] = SoundRecordClient(sound_nodes, diff --git a/src/ssh_tunnels_config_util.py b/src/ssh_tunnels_config_util.py index 3b2ba6e..963c01b 100755 --- a/src/ssh_tunnels_config_util.py +++ b/src/ssh_tunnels_config_util.py @@ -3,12 +3,12 @@ from home.config import config if __name__ == '__main__': - config.load('ssh_tunnels_config_util') + config.load_app('ssh_tunnels_config_util') network_prefix = config['network'] hostnames = [] - for k, v in config.items(): + for k, v in config.app_config.get().items(): if type(v) is not dict: continue hostnames.append(k) diff --git a/src/temphum_mqtt_node.py b/src/temphum_mqtt_node.py new file mode 100755 index 0000000..c3d1975 --- /dev/null +++ b/src/temphum_mqtt_node.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +import asyncio +import json +import logging + +from typing import Optional + +from home.config import config +from home.temphum import SensorType, BaseSensor +from home.temphum.i2c import create_sensor + +logger = logging.getLogger(__name__) +sensor: Optional[BaseSensor] = None +lock = asyncio.Lock() +delay = 0.01 + + +async def get_measurements(): + async with lock: + await asyncio.sleep(delay) + + temp = sensor.temperature() + rh = sensor.humidity() + + return rh, temp + + +async def handle_client(reader, writer): + request = None + while request != 'quit': + try: + request = await reader.read(255) + if request == b'\x04': + break + request = request.decode('utf-8').strip() + except Exception: + break + + if request == 'read': + try: + rh, temp = await asyncio.wait_for(get_measurements(), timeout=3) + data = dict(humidity=rh, temp=temp) + except asyncio.TimeoutError as e: + logger.exception(e) + data = dict(error='i2c call timed out') + else: + data = dict(error='invalid request') + + writer.write((json.dumps(data) + '\r\n').encode('utf-8')) + try: + await writer.drain() + except ConnectionResetError: + pass + + writer.close() + + +async def run_server(host, port): + server = await asyncio.start_server(handle_client, host, port) + async with server: + logger.info('Server started.') + await server.serve_forever() + + +if __name__ == '__main__': + config.load_app() + + if 'measure_delay' in config['sensor']: + delay = float(config['sensor']['measure_delay']) + + sensor = create_sensor(SensorType(config['sensor']['type']), + int(config['sensor']['bus'])) + + try: + host, port = config.get_addr('server.listen') + asyncio.run(run_server(host, port)) + except KeyboardInterrupt: + logging.info('Exiting...') diff --git a/src/sensors_mqtt_receiver.py b/src/temphum_mqtt_receiver.py index a377ddd..2b30800 100755 --- a/src/sensors_mqtt_receiver.py +++ b/src/temphum_mqtt_receiver.py @@ -2,21 +2,11 @@ import paho.mqtt.client as mqtt import re -from home.mqtt import MqttBase from home.config import config -from home.mqtt.payload.sensors import Temperature -from home.api.types import TemperatureSensorLocation -from home.database import SensorsDatabase +from home.mqtt import MqttWrapper, MqttNode -def get_sensor_type(sensor: str) -> TemperatureSensorLocation: - for item in TemperatureSensorLocation: - if sensor == item.name.lower(): - return item - raise ValueError(f'unexpected sensor value: {sensor}') - - -class MqttServer(MqttBase): +class MqttServer(Mqtt): def __init__(self): super().__init__(clean_session=False) self.database = SensorsDatabase() @@ -47,7 +37,11 @@ class MqttServer(MqttBase): if __name__ == '__main__': - config.load('sensors_mqtt_receiver') + config.load_app('temphum_mqtt_receiver') + + mqtt = MqttWrapper(clean_session=False) + node = MqttNode(node_id='+') + node.load_module('temphum', write_to_database=True) + mqtt.add_node(node) - server = MqttServer() - server.connect_and_loop() + mqtt.connect_and_loop()
\ No newline at end of file diff --git a/src/temphum_smbus_util.py b/src/temphum_smbus_util.py index 0f90835..c06bacd 100755 --- a/src/temphum_smbus_util.py +++ b/src/temphum_smbus_util.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 from argparse import ArgumentParser -from home.temphum import SensorType, create_sensor +from home.temphum import SensorType +from home.temphum.i2c import create_sensor if __name__ == '__main__': diff --git a/src/temphumd.py b/src/temphumd.py index f4d1fca..c3d1975 100755 --- a/src/temphumd.py +++ b/src/temphumd.py @@ -6,10 +6,11 @@ import logging from typing import Optional from home.config import config -from home.temphum import SensorType, create_sensor, TempHumSensor +from home.temphum import SensorType, BaseSensor +from home.temphum.i2c import create_sensor logger = logging.getLogger(__name__) -sensor: Optional[TempHumSensor] = None +sensor: Optional[BaseSensor] = None lock = asyncio.Lock() delay = 0.01 @@ -62,7 +63,7 @@ async def run_server(host, port): if __name__ == '__main__': - config.load() + config.load_app() if 'measure_delay' in config['sensor']: delay = float(config['sensor']['measure_delay']) diff --git a/src/test_new_config.py b/src/test_new_config.py new file mode 100755 index 0000000..db9eae3 --- /dev/null +++ b/src/test_new_config.py @@ -0,0 +1,12 @@ +#!/usr/bin/env python3 +from home.config import config +from home.mqtt import MqttNodesConfig +from home.telegram.config import TelegramUserIdsConfig +from pprint import pprint + + +if __name__ == '__main__': + config.load_app(name=False) + + c = TelegramUserIdsConfig() + pprint(c.get())
\ No newline at end of file diff --git a/src/web_api.py b/src/web_api.py index 0ddc6bd..0aa994a 100755 --- a/src/web_api.py +++ b/src/web_api.py @@ -231,7 +231,7 @@ if __name__ == '__main__': _app_name = 'web_api' if is_development_mode(): _app_name += '_dev' - config.load(_app_name) + config.load_app(_app_name) loop = asyncio.get_event_loop() diff --git a/systemd/inverter_mqtt_receiver.service b/systemd/inverter_mqtt_receiver.service new file mode 100644 index 0000000..fedf11f --- /dev/null +++ b/systemd/inverter_mqtt_receiver.service @@ -0,0 +1,13 @@ +[Unit] +Description=Inverter MQTT receiver +After=clickhouse-server.service + +[Service] +User=user +Group=user +Restart=on-failure +ExecStart=/home/user/homekit/src/inverter_mqtt_util.py receiver +WorkingDirectory=/home/user + +[Install] +WantedBy=multi-user.target
\ No newline at end of file diff --git a/systemd/inverter_mqtt_sender.service b/systemd/inverter_mqtt_sender.service index e3925f6..34272bb 100644 --- a/systemd/inverter_mqtt_sender.service +++ b/systemd/inverter_mqtt_sender.service @@ -6,7 +6,7 @@ After=inverterd.service User=user Group=user Restart=on-failure -ExecStart=/home/user/homekit/src/inverter_mqtt_sender.py +ExecStart=/home/user/homekit/src/inverter_mqtt_util.py sender WorkingDirectory=/home/user [Install] diff --git a/systemd/ipcam_rtsp2hls@.service b/systemd/ipcam_rtsp2hls@.service index addd819..efcdd6a 100644 --- a/systemd/ipcam_rtsp2hls@.service +++ b/systemd/ipcam_rtsp2hls@.service @@ -9,6 +9,8 @@ User=user Group=user EnvironmentFile=/etc/ipcam_rtsp2hls.conf.d/%i.conf ExecStart=/home/user/homekit/tools/ipcam_rtsp2hls.sh --name %i --user $USER --password $PASSWORD --ip $IP --port $PORT $ARGS +Restart=on-failure +RestartSec=3 [Install] WantedBy=multi-user.target diff --git a/systemd/sensors_mqtt_receiver.service b/systemd/sensors_mqtt_receiver.service index e67c112..5b9ff6a 100644 --- a/systemd/sensors_mqtt_receiver.service +++ b/systemd/sensors_mqtt_receiver.service @@ -1,12 +1,12 @@ [Unit] -Description=sensors mqtt receiver +Description=temphum mqtt receiver After=network.target [Service] User=user Group=user Restart=on-failure -ExecStart=python3 /home/user/home/src/sensors_mqtt_receiver.py +ExecStart=python3 /home/user/home/src/temphum_mqtt_receiver.py WorkingDirectory=/home/user [Install] diff --git a/systemd/sensors_mqtt_sender.service b/systemd/sensors_mqtt_sender.service deleted file mode 100644 index a271d72..0000000 --- a/systemd/sensors_mqtt_sender.service +++ /dev/null @@ -1,13 +0,0 @@ -[Unit] -Description=Sensors MQTT sender -After=temphumd.service - -[Service] -User=user -Group=user -Restart=on-failure -ExecStart=/home/user/homekit/src/sensors_mqtt_sender.py -WorkingDirectory=/home/user - -[Install] -WantedBy=multi-user.target
\ No newline at end of file diff --git a/test/mqtt_relay_server_util.py b/test/mqtt_relay_server_util.py new file mode 100755 index 0000000..ac6a9ae --- /dev/null +++ b/test/mqtt_relay_server_util.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python3 +import sys +import os.path +sys.path.extend([ + os.path.realpath( + os.path.join(os.path.dirname(os.path.join(__file__)), '..') + ) +]) + +from src.home.config import config +from src.home.mqtt.relay import MQTTRelayClient + + +if __name__ == '__main__': + config.load_app('test_mqtt_relay_server') + relay = MQTTRelayClient('test') + relay.connect_and_loop() diff --git a/test/mqtt_relay_util.py b/test/mqtt_relay_util.py new file mode 100755 index 0000000..0d8c764 --- /dev/null +++ b/test/mqtt_relay_util.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +import sys +import os.path +sys.path.extend([ + os.path.realpath( + os.path.join(os.path.dirname(os.path.join(__file__)), '..') + ) +]) + +from argparse import ArgumentParser +from src.home.config import config +from src.home.mqtt.relay import MQTTRelayController + + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument('--on', action='store_true') + parser.add_argument('--off', action='store_true') + parser.add_argument('--stat', action='store_true') + + config.load_app('test_mqtt_relay', parser=parser) + arg = parser.parse_args() + + relay = MQTTRelayController('test') + relay.connect_and_loop(loop_forever=False) + + if arg.on: + relay.set_power(True) + + elif arg.off: + relay.set_power(False) + + elif arg.stat: + relay.send_stat(dict( + state=False, + signal=-59, + fw_v=1.0 + ))
\ No newline at end of file diff --git a/test/test_amixer.py b/test/test_amixer.py index c8bd546..464941e 100755 --- a/test/test_amixer.py +++ b/test/test_amixer.py @@ -28,7 +28,7 @@ if __name__ == '__main__': parser.add_argument('--decr', type=str) # parser.add_argument('--dump-config', action='store_true') - args = config.load('test_amixer', parser=parser) + args = config.load_app('test_amixer', parser=parser) # if args.dump_config: # print(config.data) diff --git a/test/test_api.py b/test/test_api.py index 1f6361c..e80eb4c 100755 --- a/test/test_api.py +++ b/test/test_api.py @@ -13,7 +13,7 @@ from src.home.config import config if __name__ == '__main__': - config.load('test_api') + config.load_app('test_api') api = WebAPIClient() print(api.log_bot_request(BotType.ADMIN, 1, "test_api.py")) diff --git a/test/test_esp32_cam.py b/test/test_esp32_cam.py index 27ce379..6a4ad25 100755 --- a/test/test_esp32_cam.py +++ b/test/test_esp32_cam.py @@ -10,7 +10,7 @@ sys.path.extend([ from pprint import pprint from argparse import ArgumentParser from time import sleep -from src.home.util import parse_addr +from src.home.util import Addr from src.home.camera import esp32 from src.home.config import config @@ -21,8 +21,8 @@ if __name__ == '__main__': parser.add_argument('--status', action='store_true', help='print status and exit') - arg = config.load(False, parser=parser) - cam = esp32.WebClient(addr=parse_addr(arg.addr)) + arg = config.load_app(False, parser=parser) + cam = esp32.WebClient(addr=Addr.fromstring(arg.addr)) if arg.status: status = cam.getstatus() diff --git a/test/test_inverter_monitor.py b/test/test_inverter_monitor.py index 3b1c6b0..621c0e9 100755 --- a/test/test_inverter_monitor.py +++ b/test/test_inverter_monitor.py @@ -372,5 +372,5 @@ def main(): if __name__ == '__main__': - config.load('test_inverter_monitor') + config.load_app('test_inverter_monitor') main() diff --git a/test/test_ipcam_server_cleanup.py b/test/test_ipcam_server_cleanup.py index b7eb23a..5f313a4 100644 --- a/test/test_ipcam_server_cleanup.py +++ b/test/test_ipcam_server_cleanup.py @@ -77,5 +77,5 @@ def cleanup_job(): if __name__ == '__main__': - config.load('ipcam_server') + config.load_app('ipcam_server') cleanup_job() diff --git a/test/test_record_upload.py b/test/test_record_upload.py index cbd3ca2..835504f 100755 --- a/test/test_record_upload.py +++ b/test/test_record_upload.py @@ -13,7 +13,7 @@ import time from src.home.api import WebAPIClient, RequestParams from src.home.config import config from src.home.media import SoundRecordClient -from src.home.util import parse_addr +from src.home.util import Addr logger = logging.getLogger(__name__) @@ -64,11 +64,11 @@ def api_success_handler(response, name, req: RequestParams): if __name__ == '__main__': - config.load('test_record_upload') + config.load_app('test_record_upload') nodes = {} for name, addr in config['nodes'].items(): - nodes[name] = parse_addr(addr) + nodes[name] = Addr(addr) record = SoundRecordClient(nodes, error_handler=record_error, finished_handler=record_finished, diff --git a/test/test_send_fake_sound_hit.py b/test/test_send_fake_sound_hit.py index 9660c45..61886cd 100755 --- a/test/test_send_fake_sound_hit.py +++ b/test/test_send_fake_sound_hit.py @@ -8,7 +8,7 @@ sys.path.extend([ ]) from argparse import ArgumentParser -from src.home.util import send_datagram, stringify, parse_addr +from src.home.util import send_datagram, stringify, Addr if __name__ == '__main__': @@ -22,4 +22,4 @@ if __name__ == '__main__': args = parser.parse_args() - send_datagram(stringify([args.name, args.hits]), parse_addr(args.server)) + send_datagram(stringify([args.name, args.hits]), Addr.fromstring(args.server)) diff --git a/test/test_sound_server_api.py b/test/test_sound_server_api.py index e68c6f8..5295a5d 100755 --- a/test/test_sound_server_api.py +++ b/test/test_sound_server_api.py @@ -56,7 +56,7 @@ def hits_sender(): if __name__ == '__main__': - config.load('test_api') + config.load_app('test_api') hc = HitCounter() api = WebAPIClient() diff --git a/test/test_telegram_aio_send_photo.py b/test/test_telegram_aio_send_photo.py index 705e534..4d05c03 100644 --- a/test/test_telegram_aio_send_photo.py +++ b/test/test_telegram_aio_send_photo.py @@ -20,7 +20,7 @@ async def main(): if __name__ == '__main__': - config.load('test_telegram_aio_send_photo') + config.load_app('test_telegram_aio_send_photo') loop = asyncio.get_event_loop() asyncio.ensure_future(main()) diff --git a/tools/mcuota.py b/tools/mcuota.py deleted file mode 100755 index 46968a8..0000000 --- a/tools/mcuota.py +++ /dev/null @@ -1,98 +0,0 @@ -#!/usr/bin/env python3 -import sys -import os.path -sys.path.extend([ - os.path.realpath( - os.path.join(os.path.dirname(os.path.join(__file__)), '..') - ) -]) - -from time import sleep -from argparse import ArgumentParser -from src.home.config import config -from src.home.mqtt import MqttRelay -from src.home.mqtt.esp import MqttEspDevice - - -def guess_filename(product: str, build_target: str): - return os.path.join( - products_dir, - product, - '.pio', - 'build', - build_target, - 'firmware.bin' - ) - - -def relayctl_publish_ota(filename: str, - device_id: str, - home_secret: str, - qos: int): - global stop - - def published(): - global stop - stop = True - - mqtt_relay = MqttRelay(devices=MqttEspDevice(id=device_id, secret=home_secret)) - mqtt_relay.configure_tls() - mqtt_relay.connect_and_loop(loop_forever=False) - mqtt_relay.push_ota(device_id, filename, published, qos) - while not stop: - sleep(0.1) - mqtt_relay.disconnect() - - -stop = False -products = { - 'relayctl': { - 'build_target': 'esp12e', - 'callback': relayctl_publish_ota - } -} - -products_dir = os.path.join( - os.path.dirname(__file__), - '..', - 'platformio' -) - - -def main(): - parser = ArgumentParser() - parser.add_argument('--filename', type=str) - parser.add_argument('--device-id', type=str, required=True) - parser.add_argument('--product', type=str, required=True) - parser.add_argument('--qos', type=int, default=1) - - config.load('mcuota_push', parser=parser) - arg = parser.parse_args() - - if arg.product not in products: - raise ValueError(f'invalid product: \'{arg.product}\' not found') - - if arg.device_id not in config['mqtt']['home_secrets']: - raise ValueError(f'home_secret for home {arg.device_id} not found in config!') - - filename = arg.filename if arg.filename else guess_filename(arg.product, products[arg.product]['build_target']) - if not os.path.exists(filename): - raise OSError(f'file \'{filename}\' does not exists') - - print('Please confirm following OTA params.') - print('') - print(f' Device ID: {arg.device_id}') - print(f' Product: {arg.product}') - print(f'Firmware file: {filename}') - print('') - input('Press any key to continue or Ctrl+C to abort.') - - products[arg.product]['callback'](filename, arg.device_id, config['mqtt']['home_secrets'][arg.device_id], qos=arg.qos) - - -if __name__ == '__main__': - try: - main() - except Exception as e: - print(str(e), file=sys.stderr) - sys.exit(1) diff --git a/tools/mcuota.sh b/tools/mcuota.sh deleted file mode 100755 index b2e7910..0000000 --- a/tools/mcuota.sh +++ /dev/null @@ -1,14 +0,0 @@ -#!/bin/bash - -DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" - -. "$DIR/lib.bash" - -if [ -d "$DIR/../venv" ]; then - echoinfo "activating python venv" - . "$DIR/../venv/bin/activate" -else - echowarn "python venv not found" -fi - -"$DIR/mcuota.py" "$@"
\ No newline at end of file |