summaryrefslogtreecommitdiff
path: root/src/home/mqtt
diff options
context:
space:
mode:
Diffstat (limited to 'src/home/mqtt')
-rw-r--r--src/home/mqtt/_node.py16
-rw-r--r--src/home/mqtt/_wrapper.py8
-rw-r--r--src/home/mqtt/module/ota.py11
-rw-r--r--src/home/mqtt/module/relay.py6
-rw-r--r--src/home/mqtt/module/temphum.py1
5 files changed, 29 insertions, 13 deletions
diff --git a/src/home/mqtt/_node.py b/src/home/mqtt/_node.py
index ddf5ba2..4e259a4 100644
--- a/src/home/mqtt/_node.py
+++ b/src/home/mqtt/_node.py
@@ -14,13 +14,17 @@ class MqttNode:
_modules: List[MqttModule]
_module_subscriptions: dict[str, MqttModule]
_node_id: str
+ _node_secret: str
_payload_callbacks: list[callable]
_wrapper: Optional[MqttWrapper]
- def __init__(self, node_id: str):
+ def __init__(self,
+ node_id: str,
+ node_secret: Optional[str] = None):
self._modules = []
self._module_subscriptions = {}
self._node_id = node_id
+ self._node_secret = node_secret
self._payload_callbacks = []
self._logger = logging.getLogger(self.__class__.__name__)
self._wrapper = None
@@ -42,7 +46,7 @@ class MqttNode:
payload = self._module_subscriptions[topic].handle_payload(self, topic, payload)
if isinstance(payload, MqttPayload):
for f in self._payload_callbacks:
- f(payload)
+ f(self, payload)
def load_module(self, module_name: str, *args, **kwargs) -> MqttModule:
module = importlib.import_module(f'..module.{module_name}', __name__)
@@ -78,3 +82,11 @@ class MqttNode:
@property
def id(self) -> str:
return self._node_id
+
+ @property
+ def secret(self) -> str:
+ return self._node_secret
+
+ @secret.setter
+ def secret(self, secret: str) -> None:
+ self._node_secret = secret
diff --git a/src/home/mqtt/_wrapper.py b/src/home/mqtt/_wrapper.py
index 41f9d89..0b32197 100644
--- a/src/home/mqtt/_wrapper.py
+++ b/src/home/mqtt/_wrapper.py
@@ -9,11 +9,15 @@ from ..util import strgen
class MqttWrapper(Mqtt):
_nodes: list[MqttNode]
- def __init__(self, topic_prefix='hk', randomize_client_id=False):
+ def __init__(self,
+ topic_prefix='hk',
+ randomize_client_id=False,
+ clean_session=True):
client_id = config['mqtt']['client_id']
if randomize_client_id:
client_id += '_'+strgen(6)
- super().__init__(clean_session=True, client_id=client_id)
+ super().__init__(clean_session=clean_session,
+ client_id=client_id)
self._nodes = []
self._topic_prefix = topic_prefix
diff --git a/src/home/mqtt/module/ota.py b/src/home/mqtt/module/ota.py
index e71cccc..70c5475 100644
--- a/src/home/mqtt/module/ota.py
+++ b/src/home/mqtt/module/ota.py
@@ -41,7 +41,7 @@ class OtaPayload(MqttPayload):
class MqttOtaModule(MqttModule):
- _ota_request: Optional[tuple[str, str, int]]
+ _ota_request: Optional[tuple[str, int]]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -52,9 +52,9 @@ class MqttOtaModule(MqttModule):
mqtt.subscribe_module("otares", self)
if self._ota_request is not None:
- secret, filename, qos = self._ota_request
+ filename, qos = self._ota_request
self._ota_request = None
- self.do_push_ota(secret, filename, qos)
+ self.do_push_ota(self._mqtt_node_ref.secret, filename, qos)
def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]:
if topic == 'otares':
@@ -69,10 +69,9 @@ class MqttOtaModule(MqttModule):
qos=qos)
def push_ota(self,
- secret: str,
filename: str,
qos: int):
if not self._initialized:
- self._ota_request = (secret, filename, qos)
+ self._ota_request = (filename, qos)
else:
- self.do_push_ota(secret, filename, qos)
+ self.do_push_ota(filename, qos)
diff --git a/src/home/mqtt/module/relay.py b/src/home/mqtt/module/relay.py
index ae88ddb..5383fb6 100644
--- a/src/home/mqtt/module/relay.py
+++ b/src/home/mqtt/module/relay.py
@@ -64,9 +64,9 @@ class MqttRelayModule(MqttModule):
mqtt.subscribe_module('relay/status', self)
def switchpower(self,
- enable: bool,
- secret: str):
- payload = MqttPowerSwitchPayload(secret=secret, state=enable)
+ enable: bool):
+ payload = MqttPowerSwitchPayload(secret=self._mqtt_node_ref.secret,
+ state=enable)
self._mqtt_node_ref.publish('relay/switch', payload=payload.pack())
def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]:
diff --git a/src/home/mqtt/module/temphum.py b/src/home/mqtt/module/temphum.py
index 83ae34d..0e22793 100644
--- a/src/home/mqtt/module/temphum.py
+++ b/src/home/mqtt/module/temphum.py
@@ -48,6 +48,7 @@ class MqttTemphumDataPayload(MqttPayload):
class MqttTempHumModule(MqttModule):
def __init__(self,
sensor: Optional[BaseSensor] = None,
+ write_to_database=False,
*args, **kwargs):
if sensor is not None:
kwargs['tick_interval'] = 10