This commit is contained in:
Evgeny Zinoviev 2023-05-31 09:22:00 +03:00
parent b02a9c5473
commit c976495222
24 changed files with 443 additions and 394 deletions

View File

@ -55,13 +55,9 @@ Mqtt::Mqtt() {
}
}
// if (ota.readyToRestart) {
// restartTimer.once(1, restart);
// } else {
reconnectTimer.once(2, [&]() {
reconnect();
});
// }
reconnectTimer.once(2, [&]() {
reconnect();
});
});
client.onSubscribe([&](uint16_t packetId, const SubscribeReturncode* returncodes, size_t len) {
@ -79,7 +75,7 @@ Mqtt::Mqtt() {
PRINTF("mqtt: message received, topic=%s, qos=%d, dup=%d, retain=%d, len=%ul, index=%ul, total=%ul\n",
topic, properties.qos, (int)properties.dup, (int)properties.retain, len, index, total);
const char *ptr = topic + nodeId.length() + 10;
const char *ptr = topic + nodeId.length() + 4;
String relevantTopic(ptr);
auto it = moduleSubscriptions.find(relevantTopic);
@ -87,7 +83,7 @@ Mqtt::Mqtt() {
auto module = it->second;
module->handlePayload(*this, relevantTopic, properties.packetId, payload, len, index, total);
} else {
PRINTF("error: module subscription for topic %s not found\n", topic);
PRINTF("error: module subscription for topic %s not found\n", relevantTopic.c_str());
}
});

View File

@ -1,42 +0,0 @@
#!/usr/bin/env python3
from typing import Optional
from argparse import ArgumentParser
from enum import Enum
from home.config import config
from home.mqtt import MqttRelay
from home.mqtt.esp import MqttEspBase
from home.mqtt.temphum import MqttTempHum
from home.mqtt.esp import MqttEspDevice
mqtt_client: Optional[MqttEspBase] = None
class NodeType(Enum):
RELAY = 'relay'
TEMPHUM = 'temphum'
if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument('--device-id', type=str, required=True)
parser.add_argument('--type', type=str, required=True,
choices=[i.name.lower() for i in NodeType])
config.load('mqtt_util', parser=parser)
arg = parser.parse_args()
mqtt_node_type = NodeType(arg.type)
devices = MqttEspDevice(id=arg.device_id)
if mqtt_node_type == NodeType.RELAY:
mqtt_client = MqttRelay(devices=devices)
elif mqtt_node_type == NodeType.TEMPHUM:
mqtt_client = MqttTempHum(devices=devices)
mqtt_client.set_message_callback(lambda device_id, payload: print(payload))
mqtt_client.configure_tls()
try:
mqtt_client.connect_and_loop()
except KeyboardInterrupt:
mqtt_client.disconnect()

View File

@ -12,6 +12,7 @@ __map__ = {
__all__ = list(itertools.chain(*__map__.values()))
def __getattr__(name):
if name in __all__:
for file, names in __map__.items():

View File

@ -1,4 +1,8 @@
from .mqtt import MqttBase
from .util import poll_tick
from .relay import MqttRelay, MqttRelayState
from .temphum import MqttTempHum
from .mqtt import MqttBase, MqttPayload, MqttPayloadCustomField
from ._node import MqttNode
from ._module import MqttModule
from .util import (
poll_tick,
get_modules as get_mqtt_modules,
import_module as import_mqtt_module
)

33
src/home/mqtt/_module.py Normal file
View File

@ -0,0 +1,33 @@
from __future__ import annotations
import abc
import logging
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ._node import MqttNode
class MqttModule(abc.ABC):
tick_interval: int
_initialized: bool
def __init__(self, tick_interval=0):
self.tick_interval = tick_interval
self._initialized = False
self._logger = logging.getLogger(self.__class__.__name__)
def init(self, mqtt: MqttNode):
pass
def is_initialized(self):
return self._initialized
def set_initialized(self):
self._initialized = True
def tick(self, mqtt: MqttNode):
pass
def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes):
pass

87
src/home/mqtt/_node.py Normal file
View File

@ -0,0 +1,87 @@
import paho.mqtt.client as mqtt
from .mqtt import MqttBase
from typing import List
from ._module import MqttModule
class MqttNode(MqttBase):
_modules: List[MqttModule]
_module_subscriptions: dict[str, MqttModule]
_node_id: str
# _devices: list[MqttEspDevice]
# _message_callback: Optional[callable]
# _ota_publish_callback: Optional[callable]
def __init__(self,
node_id: str,
# devices: Union[MqttEspDevice, list[MqttEspDevice]],
subscribe_to_updates=True):
super().__init__(clean_session=True)
self._modules = []
self._module_subscriptions = {}
self._node_id = node_id
# if not isinstance(devices, list):
# devices = [devices]
# self._devices = devices
# self._message_callback = None
# self._ota_publish_callback = None
# self._subscribe_to_updates = subscribe_to_updates
# self._ota_mid = None
def on_connect(self, client: mqtt.Client, userdata, flags, rc):
super().on_connect(client, userdata, flags, rc)
for module in self._modules:
if not module.is_initialized():
module.init(self)
module.set_initialized()
def on_publish(self, client: mqtt.Client, userdata, mid):
pass # FIXME
# if self._ota_mid is not None and mid == self._ota_mid and self._ota_publish_callback:
# self._ota_publish_callback()
def on_message(self, client: mqtt.Client, userdata, msg):
try:
topic = msg.topic
actual_topic = topic[len(f'hk/{self._node_id}/'):]
if actual_topic in self._module_subscriptions:
self._module_subscriptions[actual_topic].handle_payload(self, actual_topic, msg.payload)
except Exception as e:
self._logger.exception(str(e))
# def push_ota(self,
# device_id,
# filename: str,
# publish_callback: callable,
# qos: int):
# device = next(d for d in self._devices if d.id == device_id)
# assert device.secret is not None, 'device secret not specified'
#
# self._ota_publish_callback = publish_callback
# payload = OtaPayload(secret=device.secret, filename=filename)
# publish_result = self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/admin/ota',
# payload=payload.pack(),
# qos=qos)
# self._ota_mid = publish_result.mid
# self._client.loop_write()
#
# @classmethod
# def get_mqtt_topics(cls, additional_topics: Optional[list[str]] = None):
# return rf'^hk/(.*?)/{cls.TOPIC_LEAF}/(stat|stat1|otares'+('|'+('|'.join(additional_topics)) if additional_topics else '')+')$'
def add_module(self, module: MqttModule):
self._modules.append(module)
if self._connected:
module.init(self)
module.set_initialized()
def subscribe_module(self, topic: str, module: MqttModule, qos: int = 1):
self._module_subscriptions[topic] = module
self._client.subscribe(f'hk/{self._node_id}/{topic}', qos)
def publish(self, topic: str, payload: bytes, qos: int = 1):
self._client.publish(f'hk/{self._node_id}/{topic}', payload, qos)
self._client.loop_write()

View File

@ -1,5 +1,5 @@
import abc
import struct
import abc
import re
from typing import Optional, Tuple
@ -142,4 +142,4 @@ def _bit_field_params(cl) -> Optional[Tuple[int, ...]]:
match = re.match(r'MQTTPayloadBitField_(\d+)_(\d+)_(\d)$', cl.__name__)
if match is not None:
return tuple([int(match.group(i)) for i in range(1, 4)])
return None
return None

View File

@ -1,106 +0,0 @@
import re
import paho.mqtt.client as mqtt
from .mqtt import MqttBase
from typing import Optional, Union
from .payload.esp import (
OTAPayload,
OTAResultPayload,
DiagnosticsPayload,
InitialDiagnosticsPayload
)
class MqttEspDevice:
id: str
secret: Optional[str]
def __init__(self, id: str, secret: Optional[str] = None):
self.id = id
self.secret = secret
class MqttEspBase(MqttBase):
_devices: list[MqttEspDevice]
_message_callback: Optional[callable]
_ota_publish_callback: Optional[callable]
TOPIC_LEAF = 'esp'
def __init__(self,
devices: Union[MqttEspDevice, list[MqttEspDevice]],
subscribe_to_updates=True):
super().__init__(clean_session=True)
if not isinstance(devices, list):
devices = [devices]
self._devices = devices
self._message_callback = None
self._ota_publish_callback = None
self._subscribe_to_updates = subscribe_to_updates
self._ota_mid = None
def on_connect(self, client: mqtt.Client, userdata, flags, rc):
super().on_connect(client, userdata, flags, rc)
if self._subscribe_to_updates:
for device in self._devices:
topic = f'hk/{device.id}/{self.TOPIC_LEAF}/#'
self._logger.debug(f"subscribing to {topic}")
client.subscribe(topic, qos=1)
def on_publish(self, client: mqtt.Client, userdata, mid):
if self._ota_mid is not None and mid == self._ota_mid and self._ota_publish_callback:
self._ota_publish_callback()
def set_message_callback(self, callback: callable):
self._message_callback = callback
def on_message(self, client: mqtt.Client, userdata, msg):
try:
match = re.match(self.get_mqtt_topics(), msg.topic)
self._logger.debug(f'topic: {msg.topic}')
if not match:
return
device_id = match.group(1)
subtopic = match.group(2)
# try:
next(d for d in self._devices if d.id == device_id)
# except StopIteration:h
# return
message = None
if subtopic == 'stat':
message = DiagnosticsPayload.unpack(msg.payload)
elif subtopic == 'stat1':
message = InitialDiagnosticsPayload.unpack(msg.payload)
elif subtopic == 'otares':
message = OTAResultPayload.unpack(msg.payload)
if message and self._message_callback:
self._message_callback(device_id, message)
return True
except Exception as e:
self._logger.exception(str(e))
def push_ota(self,
device_id,
filename: str,
publish_callback: callable,
qos: int):
device = next(d for d in self._devices if d.id == device_id)
assert device.secret is not None, 'device secret not specified'
self._ota_publish_callback = publish_callback
payload = OTAPayload(secret=device.secret, filename=filename)
publish_result = self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/admin/ota',
payload=payload.pack(),
qos=qos)
self._ota_mid = publish_result.mid
self._client.loop_write()
@classmethod
def get_mqtt_topics(cls, additional_topics: Optional[list[str]] = None):
return rf'^hk/(.*?)/{cls.TOPIC_LEAF}/(stat|stat1|otares'+('|'+('|'.join(additional_topics)) if additional_topics else '')+')$'

View File

@ -1,39 +1,7 @@
import hashlib
from ..mqtt import MqttPayload, MqttPayloadCustomField
from .._node import MqttNode, MqttModule
from .base_payload import MqttPayload, MqttPayloadCustomField
class OTAResultPayload(MqttPayload):
FORMAT = '=BB'
result: int
error_code: int
class OTAPayload(MqttPayload):
secret: str
filename: str
# structure of returned data:
#
# uint8_t[len(secret)] secret;
# uint8_t[16] md5;
# *uint8_t data
def pack(self):
buf = bytearray(self.secret.encode())
m = hashlib.md5()
with open(self.filename, 'rb') as fd:
content = fd.read()
m.update(content)
buf.extend(m.digest())
buf.extend(content)
return buf
def unpack(cls, buf: bytes):
raise RuntimeError(f'{cls.__class__.__name__}.unpack: not implemented')
# secret = buf[:12].decode()
# filename = buf[12:].decode()
# return OTAPayload(secret=secret, filename=filename)
MODULE_NAME = 'MqttDiagnosticsModule'
class DiagnosticsFlags(MqttPayloadCustomField):
@ -76,3 +44,16 @@ class DiagnosticsPayload(MqttPayload):
rssi: int
free_heap: int
flags: DiagnosticsFlags
class MqttDiagnosticsModule(MqttModule):
def init(self, mqtt: MqttNode):
for topic in ('diag', 'd1ag', 'stat', 'stat1'):
mqtt.subscribe_module(topic, self)
def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes):
if topic in ('stat', 'diag'):
message = DiagnosticsPayload.unpack(payload)
elif topic in ('stat1', 'd1ag'):
message = InitialDiagnosticsPayload.unpack(payload)
self._logger.debug(message)

View File

@ -1,7 +1,7 @@
import struct
from .base_payload import MqttPayload, bit_field
from typing import Tuple
from .._node import MqttNode
from .._payload import MqttPayload, bit_field
_mult_10 = lambda n: int(n*10)
_div_10 = lambda n: n/10
@ -71,3 +71,7 @@ class Generation(MqttPayload):
time: int
wh: int
class MqttInverterModule(MqttNode):
pass

View File

@ -0,0 +1,65 @@
import hashlib
from ..mqtt import MqttPayload
from .._node import MqttModule, MqttNode
MODULE_NAME = 'MqttOtaModule'
class OtaResultPayload(MqttPayload):
FORMAT = '=BB'
result: int
error_code: int
class OtaPayload(MqttPayload):
secret: str
filename: str
# structure of returned data:
#
# uint8_t[len(secret)] secret;
# uint8_t[16] md5;
# *uint8_t data
def pack(self):
buf = bytearray(self.secret.encode())
m = hashlib.md5()
with open(self.filename, 'rb') as fd:
content = fd.read()
m.update(content)
buf.extend(m.digest())
buf.extend(content)
return buf
def unpack(cls, buf: bytes):
raise RuntimeError(f'{cls.__class__.__name__}.unpack: not implemented')
# secret = buf[:12].decode()
# filename = buf[12:].decode()
# return OTAPayload(secret=secret, filename=filename)
class MqttOtaModule(MqttModule):
def init(self, mqtt: MqttNode):
mqtt.subscribe_module("otares", self)
def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes):
if topic == 'otares':
message = OtaResultPayload.unpack(payload)
self._logger.debug(message)
# def push_ota(self,
# node_id,
# filename: str,
# publish_callback: callable,
# qos: int):
# device = next(d for d in self._devices if d.id == device_id)
# assert device.secret is not None, 'device secret not specified'
#
# self._ota_publish_callback = publish_callback
# payload = OtaPayload(secret=device.secret, filename=filename)
# publish_result = self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/admin/ota',
# payload=payload.pack(),
# qos=qos)
# self._ota_mid = publish_result.mid
# self._client.loop_write()

View File

@ -0,0 +1,65 @@
import paho.mqtt.client as mqtt
import re
import datetime
from .. import MqttModule, MqttPayload, MqttNode
MODULE_NAME = 'MqttRelayModule'
class MqttPowerSwitchPayload(MqttPayload):
FORMAT = '=12sB'
PACKER = {
'state': lambda n: int(n),
'secret': lambda s: s.encode('utf-8')
}
UNPACKER = {
'state': lambda n: bool(n),
'secret': lambda s: s.decode('utf-8')
}
secret: str
state: bool
class MqttRelayState:
enabled: bool
update_time: datetime.datetime
rssi: int
fw_version: int
ever_updated: bool
def __init__(self):
self.ever_updated = False
self.enabled = False
self.rssi = 0
def update(self,
enabled: bool,
rssi: int,
fw_version=None):
self.ever_updated = True
self.enabled = enabled
self.rssi = rssi
self.update_time = datetime.datetime.now()
if fw_version:
self.fw_version = fw_version
class MqttRelayModule(MqttModule):
def init(self, mqtt: MqttNode):
mqtt.subscribe_module('relay/switch', self)
@staticmethod
def switchpower(mqtt: MqttNode,
enable: bool,
secret: str):
payload = MqttPowerSwitchPayload(secret=secret, state=enable)
mqtt.publish('relay/switch', payload=payload.pack())
def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes):
if topic != 'relay/switch':
return
message = MqttPowerSwitchPayload.unpack(payload)
self._logger.debug(message)

View File

@ -0,0 +1,55 @@
from enum import auto
from .._node import MqttNode
from .._module import MqttModule
from .._payload import MqttPayload
from ...util import HashableEnum
two_digits_precision = lambda x: round(x, 2)
MODULE_NAME = 'MqttTempHumModule'
class TempHumDataPayload(MqttPayload):
FORMAT = '=ddb'
UNPACKER = {
'temp': two_digits_precision,
'rh': two_digits_precision
}
temp: float
rh: float
error: int
class MqttTempHumNodes(HashableEnum):
KBN_SH_HALL = auto()
KBN_SH_BATHROOM = auto()
KBN_SH_LIVINGROOM = auto()
KBN_SH_BEDROOM = auto()
KBN_BH_2FL = auto()
KBN_BH_2FL_STREET = auto()
KBN_BH_1FL_LIVINGROOM = auto()
KBN_BH_1FL_BEDROOM = auto()
KBN_BH_1FL_BATHROOM = auto()
KBN_NH_1FL_INV = auto()
KBN_NH_1FL_CENTER = auto()
KBN_NH_1LF_KT = auto()
KBN_NH_1FL_DS = auto()
KBN_NH_1FS_EZ = auto()
SPB_FLAT120_CABINET = auto()
class MqttTempHumModule(MqttModule):
def init(self, mqtt: MqttNode):
mqtt.subscribe_module('temphum/data', self)
def handle_payload(self,
mqtt: MqttNode,
topic: str,
payload: bytes):
if topic == 'temphum/data':
message = TempHumDataPayload.unpack(payload)
self._logger.debug(message)

View File

@ -3,8 +3,8 @@ import paho.mqtt.client as mqtt
import ssl
import logging
from typing import Tuple
from ..config import config
from ._payload import *
def username_and_password() -> Tuple[str, str]:
@ -14,6 +14,8 @@ def username_and_password() -> Tuple[str, str]:
class MqttBase:
_connected: bool
def __init__(self, clean_session=True):
self._client = mqtt.Client(client_id=config['mqtt']['client_id'],
protocol=mqtt.MQTTv311,
@ -24,6 +26,7 @@ class MqttBase:
self._client.on_log = self.on_log
self._client.on_publish = self.on_publish
self._loop_started = False
self._connected = False
self._logger = logging.getLogger(self.__class__.__name__)
@ -41,7 +44,9 @@ class MqttBase:
'assets',
'mqtt_ca.crt'
))
self._client.tls_set(ca_certs=ca_certs, cert_reqs=ssl.CERT_REQUIRED, tls_version=ssl.PROTOCOL_TLSv1_2)
self._client.tls_set(ca_certs=ca_certs,
cert_reqs=ssl.CERT_REQUIRED,
tls_version=ssl.PROTOCOL_TLSv1_2)
def connect_and_loop(self, loop_forever=True):
host = config['mqtt']['host']
@ -61,9 +66,11 @@ class MqttBase:
def on_connect(self, client: mqtt.Client, userdata, flags, rc):
self._logger.info("Connected with result code " + str(rc))
self._connected = True
def on_disconnect(self, client: mqtt.Client, userdata, rc):
self._logger.info("Disconnected with result code " + str(rc))
self._connected = False
def on_log(self, client: mqtt.Client, userdata, level, buf):
level = mqtt.LOGGING_LEVEL[level] if level in mqtt.LOGGING_LEVEL else logging.INFO
@ -73,4 +80,15 @@ class MqttBase:
self._logger.debug(msg.topic + ": " + str(msg.payload))
def on_publish(self, client: mqtt.Client, userdata, mid):
self._logger.debug(f'publish done, mid={mid}')
self._logger.debug(f'publish done, mid={mid}')
class MqttEspDevice:
id: str
secret: Optional[str]
def __init__(self,
node_id: str,
secret: Optional[str] = None):
self.id = node_id
self.secret = secret

View File

@ -1 +0,0 @@
from .base_payload import MqttPayload

View File

@ -1,22 +0,0 @@
from .base_payload import MqttPayload
from .esp import (
OTAResultPayload,
OTAPayload,
InitialDiagnosticsPayload,
DiagnosticsPayload
)
class PowerPayload(MqttPayload):
FORMAT = '=12sB'
PACKER = {
'state': lambda n: int(n),
'secret': lambda s: s.encode('utf-8')
}
UNPACKER = {
'state': lambda n: bool(n),
'secret': lambda s: s.decode('utf-8')
}
secret: str
state: bool

View File

@ -1,20 +0,0 @@
from .base_payload import MqttPayload
_mult_100 = lambda n: int(n*100)
_div_100 = lambda n: n/100
class Temperature(MqttPayload):
FORMAT = 'IhH'
PACKER = {
'temp': _mult_100,
'rh': _mult_100,
}
UNPACKER = {
'temp': _div_100,
'rh': _div_100,
}
time: int
temp: float
rh: float

View File

@ -1,15 +0,0 @@
from .base_payload import MqttPayload
two_digits_precision = lambda x: round(x, 2)
class TempHumDataPayload(MqttPayload):
FORMAT = '=ddb'
UNPACKER = {
'temp': two_digits_precision,
'rh': two_digits_precision
}
temp: float
rh: float
error: int

View File

@ -1,71 +0,0 @@
import paho.mqtt.client as mqtt
import re
import datetime
from .payload.relay import (
PowerPayload,
)
from .esp import MqttEspBase
class MqttRelay(MqttEspBase):
TOPIC_LEAF = 'relay'
def set_power(self, device_id, enable: bool, secret=None):
device = next(d for d in self._devices if d.id == device_id)
secret = secret if secret else device.secret
assert secret is not None, 'device secret not specified'
payload = PowerPayload(secret=secret,
state=enable)
self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/power',
payload=payload.pack(),
qos=1)
self._client.loop_write()
def on_message(self, client: mqtt.Client, userdata, msg):
if super().on_message(client, userdata, msg):
return
try:
match = re.match(self.get_mqtt_topics(['power']), msg.topic)
if not match:
return
device_id = match.group(1)
subtopic = match.group(2)
message = None
if subtopic == 'power':
message = PowerPayload.unpack(msg.payload)
if message and self._message_callback:
self._message_callback(device_id, message)
except Exception as e:
self._logger.exception(str(e))
class MqttRelayState:
enabled: bool
update_time: datetime.datetime
rssi: int
fw_version: int
ever_updated: bool
def __init__(self):
self.ever_updated = False
self.enabled = False
self.rssi = 0
def update(self,
enabled: bool,
rssi: int,
fw_version=None):
self.ever_updated = True
self.enabled = enabled
self.rssi = rssi
self.update_time = datetime.datetime.now()
if fw_version:
self.fw_version = fw_version

View File

@ -1,54 +0,0 @@
import paho.mqtt.client as mqtt
import re
from enum import auto
from .payload.temphum import TempHumDataPayload
from .esp import MqttEspBase
from ..util import HashableEnum
class MqttTempHumNodes(HashableEnum):
KBN_SH_HALL = auto()
KBN_SH_BATHROOM = auto()
KBN_SH_LIVINGROOM = auto()
KBN_SH_BEDROOM = auto()
KBN_BH_2FL = auto()
KBN_BH_2FL_STREET = auto()
KBN_BH_1FL_LIVINGROOM = auto()
KBN_BH_1FL_BEDROOM = auto()
KBN_BH_1FL_BATHROOM = auto()
KBN_NH_1FL_INV = auto()
KBN_NH_1FL_CENTER = auto()
KBN_NH_1LF_KT = auto()
KBN_NH_1FL_DS = auto()
KBN_NH_1FS_EZ = auto()
SPB_FLAT120_CABINET = auto()
class MqttTempHum(MqttEspBase):
TOPIC_LEAF = 'temphum'
def on_message(self, client: mqtt.Client, userdata, msg):
if super().on_message(client, userdata, msg):
return
try:
match = re.match(self.get_mqtt_topics(['data']), msg.topic)
if not match:
return
device_id = match.group(1)
subtopic = match.group(2)
message = None
if subtopic == 'data':
message = TempHumDataPayload.unpack(msg.payload)
if message and self._message_callback:
self._message_callback(device_id, message)
except Exception as e:
self._logger.exception(str(e))

View File

@ -1,4 +1,9 @@
import time
import os
import re
import importlib
from typing import List
def poll_tick(freq):
@ -6,3 +11,16 @@ def poll_tick(freq):
while True:
t += freq
yield max(t - time.time(), 0)
def get_modules() -> List[str]:
modules = []
for name in os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'module')):
name = re.sub(r'\.py$', '', name)
modules.append(name)
return modules
def import_module(module: str):
return importlib.import_module(
f'..module.{module}', __name__)

View File

@ -16,10 +16,6 @@ _products_dir = os.path.join(
def get_products():
products = []
for f in os.listdir(_products_dir):
# temp hack
if f.endswith('-esp01'):
continue
# skip the common dir
if f in ('common',):
continue

58
src/mqtt_node_util.py Executable file
View File

@ -0,0 +1,58 @@
#!/usr/bin/env python3
from typing import Optional
from argparse import ArgumentParser, ArgumentError
from home.config import config
from home.mqtt import MqttNode, get_mqtt_modules, import_mqtt_module, MqttModule
mqtt: Optional[MqttNode] = None
def add_module(module: str) -> MqttModule:
module = import_mqtt_module(module)
if not hasattr(module, 'MODULE_NAME'):
raise RuntimeError(f'MODULE_NAME not found in module {m}')
cl = getattr(module, getattr(module, 'MODULE_NAME'))
instance = cl()
mqtt.add_module(instance)
return instance
if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument('--node-id', type=str, required=True)
parser.add_argument('--modules', type=str, choices=get_mqtt_modules(), nargs='*',
help='mqtt modules to include')
parser.add_argument('--switch-relay', choices=[0, 1], type=int,
help='send relay state')
parser.add_argument('--switch-relay-secret', type=str,
help='secret password to switch relay')
config.load('mqtt_util', parser=parser)
arg = parser.parse_args()
if (arg.switch_relay is not None or arg.switch_relay_secret is not None) and 'relay' not in arg.modules:
raise ArgumentError(None, '--relay is only allowed when \'relay\' module included in --modules')
if (arg.switch_relay is not None and arg.switch_relay_secret is None) or (arg.switch_relay is None and arg.switch_relay_secret is not None):
raise ArgumentError(None, 'both --switch-relay and --switch-relay-secret are required')
mqtt = MqttNode(node_id=arg.node_id)
# must-have modules
add_module('ota')
add_module('diagnostics')
if arg.modules:
for m in arg.modules:
module_instance = add_module(m)
if m == 'relay' and arg.switch_relay is not None:
module_instance.switchpower(mqtt,
arg.switch_relay == 1,
arg.switch_relay_secret)
mqtt.configure_tls()
try:
mqtt.connect_and_loop()
except KeyboardInterrupt:
mqtt.disconnect()

View File

@ -8,10 +8,9 @@ from telegram import ReplyKeyboardMarkup, User
from home.config import config
from home.telegram import bot
from home.telegram._botutil import user_any_name
from home.mqtt.esp import MqttEspDevice
from home.mqtt import MqttRelay, MqttRelayState
from home.mqtt.payload import MqttPayload
from home.mqtt.payload.relay import InitialDiagnosticsPayload, DiagnosticsPayload
from home.mqtt import MqttEspDevice, MqttPayload
from home.mqtt.module.relay import MqttRelayState
from home.mqtt.module.diagnostics import InitialDiagnosticsPayload, DiagnosticsPayload
config.load('pump_mqtt_bot')