summaryrefslogtreecommitdiff
path: root/src/home/mqtt
diff options
context:
space:
mode:
Diffstat (limited to 'src/home/mqtt')
-rw-r--r--src/home/mqtt/__init__.py10
-rw-r--r--src/home/mqtt/_module.py43
-rw-r--r--src/home/mqtt/_node.py111
-rw-r--r--src/home/mqtt/_wrapper.py55
-rw-r--r--src/home/mqtt/module/diagnostics.py7
-rw-r--r--src/home/mqtt/module/ota.py15
-rw-r--r--src/home/mqtt/module/relay.py8
-rw-r--r--src/home/mqtt/module/temphum.py13
-rw-r--r--src/home/mqtt/mqtt.py20
-rw-r--r--src/home/mqtt/relay.py59
-rw-r--r--src/home/mqtt/util.py31
11 files changed, 176 insertions, 196 deletions
diff --git a/src/home/mqtt/__init__.py b/src/home/mqtt/__init__.py
index c95061f..8633437 100644
--- a/src/home/mqtt/__init__.py
+++ b/src/home/mqtt/__init__.py
@@ -1,9 +1,5 @@
-from .mqtt import MqttBase, MqttPayload, MqttPayloadCustomField
+from .mqtt import Mqtt, MqttPayload, MqttPayloadCustomField
from ._node import MqttNode
from ._module import MqttModule
-from .util import (
- poll_tick,
- get_modules as get_mqtt_modules,
- import_module as import_mqtt_module,
- add_module as add_mqtt_module
-) \ No newline at end of file
+from ._wrapper import MqttWrapper
+from .util import get_modules as get_mqtt_modules \ No newline at end of file
diff --git a/src/home/mqtt/_module.py b/src/home/mqtt/_module.py
index 840534e..80f27bb 100644
--- a/src/home/mqtt/_module.py
+++ b/src/home/mqtt/_module.py
@@ -2,6 +2,10 @@ from __future__ import annotations
import abc
import logging
+import threading
+
+from time import sleep
+from ..util import next_tick_gen
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
@@ -10,16 +14,29 @@ if TYPE_CHECKING:
class MqttModule(abc.ABC):
- tick_interval: int
+ _tick_interval: int
_initialized: bool
+ _connected: bool
+ _ticker: Optional[threading.Thread]
+ _mqtt_node_ref: Optional[MqttNode]
def __init__(self, tick_interval=0):
- self.tick_interval = tick_interval
+ self._tick_interval = tick_interval
self._initialized = False
+ self._ticker = None
self._logger = logging.getLogger(self.__class__.__name__)
+ self._connected = False
+ self._mqtt_node_ref = None
- def init(self, mqtt: MqttNode):
- pass
+ def on_connect(self, mqtt: MqttNode):
+ self._connected = True
+ self._mqtt_node_ref = mqtt
+ if self._tick_interval:
+ self._start_ticker()
+
+ def on_disconnect(self, mqtt: MqttNode):
+ self._connected = False
+ self._mqtt_node_ref = None
def is_initialized(self):
return self._initialized
@@ -30,8 +47,24 @@ class MqttModule(abc.ABC):
def unset_initialized(self):
self._initialized = False
- def tick(self, mqtt: MqttNode):
+ def tick(self):
pass
+ def _tick(self):
+ g = next_tick_gen(self._tick_interval)
+ while self._connected:
+ sleep(next(g))
+ if not self._connected:
+ break
+ self.tick()
+
+ def _start_ticker(self):
+ if not self._ticker or not self._ticker.is_alive():
+ name_part = f'{self._mqtt_node_ref.id}/' if self._mqtt_node_ref else ''
+ self._ticker = None
+ self._ticker = threading.Thread(target=self._tick,
+ name=f'mqtt:{self.__class__.__name__}/{name_part}ticker')
+ self._ticker.start()
+
def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]:
pass
diff --git a/src/home/mqtt/_node.py b/src/home/mqtt/_node.py
index f34da0c..ddf5ba2 100644
--- a/src/home/mqtt/_node.py
+++ b/src/home/mqtt/_node.py
@@ -1,103 +1,80 @@
-import paho.mqtt.client as mqtt
+import logging
+import importlib
-from .mqtt import MqttBase
-from typing import List, Optional
-from ._module import MqttModule
+from typing import List, TYPE_CHECKING, Optional
from ._payload import MqttPayload
+from ._module import MqttModule
+if TYPE_CHECKING:
+ from ._wrapper import MqttWrapper
+else:
+ MqttWrapper = None
-class MqttNode(MqttBase):
+class MqttNode:
_modules: List[MqttModule]
_module_subscriptions: dict[str, MqttModule]
_node_id: str
_payload_callbacks: list[callable]
- # _devices: list[MqttEspDevice]
- # _message_callback: Optional[callable]
- # _ota_publish_callback: Optional[callable]
+ _wrapper: Optional[MqttWrapper]
- def __init__(self,
- node_id: str,
- # devices: Union[MqttEspDevice, list[MqttEspDevice]]
- ):
- super().__init__(clean_session=True)
+ def __init__(self, node_id: str):
self._modules = []
self._module_subscriptions = {}
self._node_id = node_id
self._payload_callbacks = []
- # if not isinstance(devices, list):
- # devices = [devices]
- # self._devices = devices
- # self._message_callback = None
- # self._ota_publish_callback = None
- # self._ota_mid = None
+ self._logger = logging.getLogger(self.__class__.__name__)
+ self._wrapper = None
- def on_connect(self, client: mqtt.Client, userdata, flags, rc):
- super().on_connect(client, userdata, flags, rc)
+ def on_connect(self, wrapper: MqttWrapper):
+ self._wrapper = wrapper
for module in self._modules:
if not module.is_initialized():
- module.init(self)
+ module.on_connect(self)
module.set_initialized()
- def on_disconnect(self, client: mqtt.Client, userdata, rc):
- super().on_disconnect(client, userdata, rc)
+ def on_disconnect(self):
+ self._wrapper = None
for module in self._modules:
module.unset_initialized()
- def on_publish(self, client: mqtt.Client, userdata, mid):
- pass # FIXME
- # if self._ota_mid is not None and mid == self._ota_mid and self._ota_publish_callback:
- # self._ota_publish_callback()
-
- def on_message(self, client: mqtt.Client, userdata, msg):
- try:
- topic = msg.topic
- actual_topic = topic[len(f'hk/{self._node_id}/'):]
-
- if actual_topic in self._module_subscriptions:
- payload = self._module_subscriptions[actual_topic].handle_payload(self, actual_topic, msg.payload)
- if isinstance(payload, MqttPayload):
- for f in self._payload_callbacks:
- f(payload)
+ def on_message(self, topic, payload):
+ if topic in self._module_subscriptions:
+ payload = self._module_subscriptions[topic].handle_payload(self, topic, payload)
+ if isinstance(payload, MqttPayload):
+ for f in self._payload_callbacks:
+ f(payload)
- except Exception as e:
- self._logger.exception(str(e))
-
- # def push_ota(self,
- # device_id,
- # filename: str,
- # publish_callback: callable,
- # qos: int):
- # device = next(d for d in self._devices if d.id == device_id)
- # assert device.secret is not None, 'device secret not specified'
- #
- # self._ota_publish_callback = publish_callback
- # payload = OtaPayload(secret=device.secret, filename=filename)
- # publish_result = self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/admin/ota',
- # payload=payload.pack(),
- # qos=qos)
- # self._ota_mid = publish_result.mid
- # self._client.loop_write()
- #
- # @classmethod
- # def get_mqtt_topics(cls, additional_topics: Optional[list[str]] = None):
- # return rf'^hk/(.*?)/{cls.TOPIC_LEAF}/(stat|stat1|otares'+('|'+('|'.join(additional_topics)) if additional_topics else '')+')$'
+ def load_module(self, module_name: str, *args, **kwargs) -> MqttModule:
+ module = importlib.import_module(f'..module.{module_name}', __name__)
+ if not hasattr(module, 'MODULE_NAME'):
+ raise RuntimeError(f'MODULE_NAME not found in module {module}')
+ cl = getattr(module, getattr(module, 'MODULE_NAME'))
+ instance = cl(*args, **kwargs)
+ self.add_module(instance)
+ return instance
def add_module(self, module: MqttModule):
self._modules.append(module)
- if self._connected:
- module.init(self)
+ if self._wrapper and self._wrapper._connected:
+ module.on_connect(self)
module.set_initialized()
def subscribe_module(self, topic: str, module: MqttModule, qos: int = 1):
+ if not self._wrapper or not self._wrapper._connected:
+ raise RuntimeError('not connected')
+
self._module_subscriptions[topic] = module
- self._client.subscribe(f'hk/{self._node_id}/{topic}', qos)
+ self._wrapper.subscribe(self.id, topic, qos)
def publish(self,
topic: str,
payload: bytes,
qos: int = 1):
- self._client.publish(f'hk/{self._node_id}/{topic}', payload, qos)
- self._client.loop_write()
+ self._wrapper.publish(self.id, topic, payload, qos)
def add_payload_callback(self, callback: callable):
- self._payload_callbacks.append(callback) \ No newline at end of file
+ self._payload_callbacks.append(callback)
+
+ @property
+ def id(self) -> str:
+ return self._node_id
diff --git a/src/home/mqtt/_wrapper.py b/src/home/mqtt/_wrapper.py
new file mode 100644
index 0000000..41f9d89
--- /dev/null
+++ b/src/home/mqtt/_wrapper.py
@@ -0,0 +1,55 @@
+import paho.mqtt.client as mqtt
+
+from .mqtt import Mqtt
+from ._node import MqttNode
+from ..config import config
+from ..util import strgen
+
+
+class MqttWrapper(Mqtt):
+ _nodes: list[MqttNode]
+
+ def __init__(self, topic_prefix='hk', randomize_client_id=False):
+ client_id = config['mqtt']['client_id']
+ if randomize_client_id:
+ client_id += '_'+strgen(6)
+ super().__init__(clean_session=True, client_id=client_id)
+ self._nodes = []
+ self._topic_prefix = topic_prefix
+
+ def on_connect(self, client: mqtt.Client, userdata, flags, rc):
+ super().on_connect(client, userdata, flags, rc)
+ for node in self._nodes:
+ node.on_connect(self)
+
+ def on_disconnect(self, client: mqtt.Client, userdata, rc):
+ super().on_disconnect(client, userdata, rc)
+ for node in self._nodes:
+ node.on_disconnect()
+
+ def on_message(self, client: mqtt.Client, userdata, msg):
+ try:
+ topic = msg.topic
+ for node in self._nodes:
+ node.on_message(topic[len(f'{self._topic_prefix}/{node.id}/'):], msg.payload)
+ except Exception as e:
+ self._logger.exception(str(e))
+
+ def add_node(self, node: MqttNode):
+ self._nodes.append(node)
+ if self._connected:
+ node.on_connect(self)
+
+ def subscribe(self,
+ node_id: str,
+ topic: str,
+ qos: int):
+ self._client.subscribe(f'{self._topic_prefix}/{node_id}/{topic}', qos)
+
+ def publish(self,
+ node_id: str,
+ topic: str,
+ payload: bytes,
+ qos: int):
+ self._client.publish(f'{self._topic_prefix}/{node_id}/{topic}', payload, qos)
+ self._client.loop_write()
diff --git a/src/home/mqtt/module/diagnostics.py b/src/home/mqtt/module/diagnostics.py
index c31cce2..fa6cc8e 100644
--- a/src/home/mqtt/module/diagnostics.py
+++ b/src/home/mqtt/module/diagnostics.py
@@ -48,14 +48,17 @@ class DiagnosticsPayload(MqttPayload):
class MqttDiagnosticsModule(MqttModule):
- def init(self, mqtt: MqttNode):
+ def on_connect(self, mqtt: MqttNode):
+ super().on_connect(mqtt)
for topic in ('diag', 'd1ag', 'stat', 'stat1'):
mqtt.subscribe_module(topic, self)
def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]:
+ message = None
if topic in ('stat', 'diag'):
message = DiagnosticsPayload.unpack(payload)
elif topic in ('stat1', 'd1ag'):
message = InitialDiagnosticsPayload.unpack(payload)
- self._logger.debug(message)
+ if message:
+ self._logger.debug(message)
return message
diff --git a/src/home/mqtt/module/ota.py b/src/home/mqtt/module/ota.py
index 5a1a309..e71cccc 100644
--- a/src/home/mqtt/module/ota.py
+++ b/src/home/mqtt/module/ota.py
@@ -42,18 +42,15 @@ class OtaPayload(MqttPayload):
class MqttOtaModule(MqttModule):
_ota_request: Optional[tuple[str, str, int]]
- _mqtt_ref: Optional[MqttNode]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._ota_request = None
- self._mqtt_ref = None
- def init(self, mqtt: MqttNode):
+ def on_connect(self, mqtt: MqttNode):
+ super().on_connect(mqtt)
mqtt.subscribe_module("otares", self)
- self._mqtt_ref = mqtt
-
if self._ota_request is not None:
secret, filename, qos = self._ota_request
self._ota_request = None
@@ -67,9 +64,9 @@ class MqttOtaModule(MqttModule):
def do_push_ota(self, secret: str, filename: str, qos: int):
payload = OtaPayload(secret=secret, filename=filename)
- self._mqtt_ref.publish('ota',
- payload=payload.pack(),
- qos=qos)
+ self._mqtt_node_ref.publish('ota',
+ payload=payload.pack(),
+ qos=qos)
def push_ota(self,
secret: str,
@@ -78,4 +75,4 @@ class MqttOtaModule(MqttModule):
if not self._initialized:
self._ota_request = (secret, filename, qos)
else:
- self.do_push_ota(secret, filename, qos) \ No newline at end of file
+ self.do_push_ota(secret, filename, qos)
diff --git a/src/home/mqtt/module/relay.py b/src/home/mqtt/module/relay.py
index bf22bfe..ae88ddb 100644
--- a/src/home/mqtt/module/relay.py
+++ b/src/home/mqtt/module/relay.py
@@ -58,16 +58,16 @@ class MqttRelayState:
class MqttRelayModule(MqttModule):
- def init(self, mqtt: MqttNode):
+ def on_connect(self, mqtt: MqttNode):
+ super().on_connect(mqtt)
mqtt.subscribe_module('relay/switch', self)
mqtt.subscribe_module('relay/status', self)
- @staticmethod
- def switchpower(mqtt: MqttNode,
+ def switchpower(self,
enable: bool,
secret: str):
payload = MqttPowerSwitchPayload(secret=secret, state=enable)
- mqtt.publish('relay/switch', payload=payload.pack())
+ self._mqtt_node_ref.publish('relay/switch', payload=payload.pack())
def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]:
message = None
diff --git a/src/home/mqtt/module/temphum.py b/src/home/mqtt/module/temphum.py
index 0e43f1b..9cdfedb 100644
--- a/src/home/mqtt/module/temphum.py
+++ b/src/home/mqtt/module/temphum.py
@@ -4,6 +4,7 @@ from .._module import MqttModule
from .._payload import MqttPayload
from ...util import HashableEnum
from typing import Optional
+from ...temphum import BaseSensor
two_digits_precision = lambda x: round(x, 2)
@@ -44,9 +45,17 @@ class MqttTempHumNodes(HashableEnum):
class MqttTempHumModule(MqttModule):
- def init(self, mqtt: MqttNode):
+ def __init__(self, sensor: Optional[BaseSensor] = None, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._sensor = sensor
+
+ def on_connect(self, mqtt: MqttNode):
+ super().on_connect(mqtt)
mqtt.subscribe_module('temphum/data', self)
+ def tick(self):
+ pass
+
def handle_payload(self,
mqtt: MqttNode,
topic: str,
@@ -54,4 +63,4 @@ class MqttTempHumModule(MqttModule):
if topic == 'temphum/data':
message = MqttTemphumDataPayload.unpack(payload)
self._logger.debug(message)
- return message \ No newline at end of file
+ return message
diff --git a/src/home/mqtt/mqtt.py b/src/home/mqtt/mqtt.py
index fad5d26..ba32889 100644
--- a/src/home/mqtt/mqtt.py
+++ b/src/home/mqtt/mqtt.py
@@ -5,6 +5,7 @@ import logging
from ..config import config
from ._payload import *
+from typing import Optional
def username_and_password() -> Tuple[str, str]:
@@ -13,11 +14,13 @@ def username_and_password() -> Tuple[str, str]:
return username, password
-class MqttBase:
+class Mqtt:
_connected: bool
- def __init__(self, clean_session=True):
- self._client = mqtt.Client(client_id=config['mqtt']['client_id'],
+ def __init__(self,
+ clean_session=True,
+ client_id: Optional[str] = None):
+ self._client = mqtt.Client(client_id=config['mqtt']['client_id'] if not client_id else client_id,
protocol=mqtt.MQTTv311,
clean_session=clean_session)
self._client.on_connect = self.on_connect
@@ -81,14 +84,3 @@ class MqttBase:
def on_publish(self, client: mqtt.Client, userdata, mid):
self._logger.debug(f'publish done, mid={mid}')
-
-
-class MqttEspDevice:
- id: str
- secret: Optional[str]
-
- def __init__(self,
- node_id: str,
- secret: Optional[str] = None):
- self.id = node_id
- self.secret = secret
diff --git a/src/home/mqtt/relay.py b/src/home/mqtt/relay.py
deleted file mode 100644
index cf657f7..0000000
--- a/src/home/mqtt/relay.py
+++ /dev/null
@@ -1,59 +0,0 @@
-#!/usr/bin/env python3
-import paho.mqtt.client as mqtt
-import re
-import logging
-
-from .mqtt import MQTTBase
-
-
-class MQTTRelayClient(MQTTBase):
- _home_id: str
-
- def __init__(self, home_id: str):
- super().__init__(clean_session=True)
- self._home_id = home_id
-
- def on_connect(self, client: mqtt.Client, userdata, flags, rc):
- super().on_connect(client, userdata, flags, rc)
-
- topic = f'home/{self._home_id}/#'
- self._logger.info(f"subscribing to {topic}")
-
- client.subscribe(topic, qos=1)
-
- def on_message(self, client: mqtt.Client, userdata, msg):
- try:
- match = re.match(r'^home/(.*?)/relay/(stat|power)(?:/(.+))?$', msg.topic)
- self._logger.info(f'topic: {msg.topic}')
- if not match:
- return
-
- name = match.group(1)
- subtopic = match.group(2)
-
- if name != self._home_id:
- return
-
- if subtopic == 'stat':
- stat_name, stat_value = match.group(3).split('/')
- self._logger.info(f'stat: {stat_name} = {stat_value}')
-
- except Exception as e:
- self._logger.exception(str(e))
-
-
-class MQTTRelayController(MQTTBase):
- _home_id: str
-
- def __init__(self, home_id: str):
- super().__init__(clean_session=True)
- self._home_id = home_id
-
- def set_power(self, enable: bool):
- self._client.publish(f'home/{self._home_id}/relay/power',
- payload=int(enable),
- qos=1)
- self._client.loop_write()
-
- def send_stat(self, stat: dict):
- pass
diff --git a/src/home/mqtt/util.py b/src/home/mqtt/util.py
index 91b6baf..390d463 100644
--- a/src/home/mqtt/util.py
+++ b/src/home/mqtt/util.py
@@ -1,38 +1,15 @@
-import time
import os
import re
-import importlib
-from ._node import MqttNode
-from . import MqttModule
from typing import List
-def poll_tick(freq):
- t = time.time()
- while True:
- t += freq
- yield max(t - time.time(), 0)
-
-
def get_modules() -> List[str]:
modules = []
- for name in os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'module')):
+ modules_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'module')
+ for name in os.listdir(modules_dir):
+ if os.path.isdir(os.path.join(modules_dir, name)):
+ continue
name = re.sub(r'\.py$', '', name)
modules.append(name)
return modules
-
-
-def import_module(module: str):
- return importlib.import_module(
- f'..module.{module}', __name__)
-
-
-def add_module(mqtt_node: MqttNode, module: str) -> MqttModule:
- module = import_module(module)
- if not hasattr(module, 'MODULE_NAME'):
- raise RuntimeError(f'MODULE_NAME not found in module {module}')
- cl = getattr(module, getattr(module, 'MODULE_NAME'))
- instance = cl()
- mqtt_node.add_module(instance)
- return instance \ No newline at end of file