diff options
Diffstat (limited to 'src/home/mqtt')
-rw-r--r-- | src/home/mqtt/_node.py | 16 | ||||
-rw-r--r-- | src/home/mqtt/_wrapper.py | 8 | ||||
-rw-r--r-- | src/home/mqtt/module/ota.py | 11 | ||||
-rw-r--r-- | src/home/mqtt/module/relay.py | 6 | ||||
-rw-r--r-- | src/home/mqtt/module/temphum.py | 1 |
5 files changed, 29 insertions, 13 deletions
diff --git a/src/home/mqtt/_node.py b/src/home/mqtt/_node.py index ddf5ba2..4e259a4 100644 --- a/src/home/mqtt/_node.py +++ b/src/home/mqtt/_node.py @@ -14,13 +14,17 @@ class MqttNode: _modules: List[MqttModule] _module_subscriptions: dict[str, MqttModule] _node_id: str + _node_secret: str _payload_callbacks: list[callable] _wrapper: Optional[MqttWrapper] - def __init__(self, node_id: str): + def __init__(self, + node_id: str, + node_secret: Optional[str] = None): self._modules = [] self._module_subscriptions = {} self._node_id = node_id + self._node_secret = node_secret self._payload_callbacks = [] self._logger = logging.getLogger(self.__class__.__name__) self._wrapper = None @@ -42,7 +46,7 @@ class MqttNode: payload = self._module_subscriptions[topic].handle_payload(self, topic, payload) if isinstance(payload, MqttPayload): for f in self._payload_callbacks: - f(payload) + f(self, payload) def load_module(self, module_name: str, *args, **kwargs) -> MqttModule: module = importlib.import_module(f'..module.{module_name}', __name__) @@ -78,3 +82,11 @@ class MqttNode: @property def id(self) -> str: return self._node_id + + @property + def secret(self) -> str: + return self._node_secret + + @secret.setter + def secret(self, secret: str) -> None: + self._node_secret = secret diff --git a/src/home/mqtt/_wrapper.py b/src/home/mqtt/_wrapper.py index 41f9d89..0b32197 100644 --- a/src/home/mqtt/_wrapper.py +++ b/src/home/mqtt/_wrapper.py @@ -9,11 +9,15 @@ from ..util import strgen class MqttWrapper(Mqtt): _nodes: list[MqttNode] - def __init__(self, topic_prefix='hk', randomize_client_id=False): + def __init__(self, + topic_prefix='hk', + randomize_client_id=False, + clean_session=True): client_id = config['mqtt']['client_id'] if randomize_client_id: client_id += '_'+strgen(6) - super().__init__(clean_session=True, client_id=client_id) + super().__init__(clean_session=clean_session, + client_id=client_id) self._nodes = [] self._topic_prefix = topic_prefix diff --git a/src/home/mqtt/module/ota.py b/src/home/mqtt/module/ota.py index e71cccc..70c5475 100644 --- a/src/home/mqtt/module/ota.py +++ b/src/home/mqtt/module/ota.py @@ -41,7 +41,7 @@ class OtaPayload(MqttPayload): class MqttOtaModule(MqttModule): - _ota_request: Optional[tuple[str, str, int]] + _ota_request: Optional[tuple[str, int]] def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -52,9 +52,9 @@ class MqttOtaModule(MqttModule): mqtt.subscribe_module("otares", self) if self._ota_request is not None: - secret, filename, qos = self._ota_request + filename, qos = self._ota_request self._ota_request = None - self.do_push_ota(secret, filename, qos) + self.do_push_ota(self._mqtt_node_ref.secret, filename, qos) def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: if topic == 'otares': @@ -69,10 +69,9 @@ class MqttOtaModule(MqttModule): qos=qos) def push_ota(self, - secret: str, filename: str, qos: int): if not self._initialized: - self._ota_request = (secret, filename, qos) + self._ota_request = (filename, qos) else: - self.do_push_ota(secret, filename, qos) + self.do_push_ota(filename, qos) diff --git a/src/home/mqtt/module/relay.py b/src/home/mqtt/module/relay.py index ae88ddb..5383fb6 100644 --- a/src/home/mqtt/module/relay.py +++ b/src/home/mqtt/module/relay.py @@ -64,9 +64,9 @@ class MqttRelayModule(MqttModule): mqtt.subscribe_module('relay/status', self) def switchpower(self, - enable: bool, - secret: str): - payload = MqttPowerSwitchPayload(secret=secret, state=enable) + enable: bool): + payload = MqttPowerSwitchPayload(secret=self._mqtt_node_ref.secret, + state=enable) self._mqtt_node_ref.publish('relay/switch', payload=payload.pack()) def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes) -> Optional[MqttPayload]: diff --git a/src/home/mqtt/module/temphum.py b/src/home/mqtt/module/temphum.py index 83ae34d..0e22793 100644 --- a/src/home/mqtt/module/temphum.py +++ b/src/home/mqtt/module/temphum.py @@ -48,6 +48,7 @@ class MqttTemphumDataPayload(MqttPayload): class MqttTempHumModule(MqttModule): def __init__(self, sensor: Optional[BaseSensor] = None, + write_to_database=False, *args, **kwargs): if sensor is not None: kwargs['tick_interval'] = 10 |