wip
This commit is contained in:
parent
b02a9c5473
commit
c976495222
@ -55,13 +55,9 @@ Mqtt::Mqtt() {
|
||||
}
|
||||
}
|
||||
|
||||
// if (ota.readyToRestart) {
|
||||
// restartTimer.once(1, restart);
|
||||
// } else {
|
||||
reconnectTimer.once(2, [&]() {
|
||||
reconnect();
|
||||
});
|
||||
// }
|
||||
reconnectTimer.once(2, [&]() {
|
||||
reconnect();
|
||||
});
|
||||
});
|
||||
|
||||
client.onSubscribe([&](uint16_t packetId, const SubscribeReturncode* returncodes, size_t len) {
|
||||
@ -79,7 +75,7 @@ Mqtt::Mqtt() {
|
||||
PRINTF("mqtt: message received, topic=%s, qos=%d, dup=%d, retain=%d, len=%ul, index=%ul, total=%ul\n",
|
||||
topic, properties.qos, (int)properties.dup, (int)properties.retain, len, index, total);
|
||||
|
||||
const char *ptr = topic + nodeId.length() + 10;
|
||||
const char *ptr = topic + nodeId.length() + 4;
|
||||
String relevantTopic(ptr);
|
||||
|
||||
auto it = moduleSubscriptions.find(relevantTopic);
|
||||
@ -87,7 +83,7 @@ Mqtt::Mqtt() {
|
||||
auto module = it->second;
|
||||
module->handlePayload(*this, relevantTopic, properties.packetId, payload, len, index, total);
|
||||
} else {
|
||||
PRINTF("error: module subscription for topic %s not found\n", topic);
|
||||
PRINTF("error: module subscription for topic %s not found\n", relevantTopic.c_str());
|
||||
}
|
||||
});
|
||||
|
||||
|
@ -1,42 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
from typing import Optional
|
||||
from argparse import ArgumentParser
|
||||
from enum import Enum
|
||||
|
||||
from home.config import config
|
||||
from home.mqtt import MqttRelay
|
||||
from home.mqtt.esp import MqttEspBase
|
||||
from home.mqtt.temphum import MqttTempHum
|
||||
from home.mqtt.esp import MqttEspDevice
|
||||
|
||||
mqtt_client: Optional[MqttEspBase] = None
|
||||
|
||||
|
||||
class NodeType(Enum):
|
||||
RELAY = 'relay'
|
||||
TEMPHUM = 'temphum'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('--device-id', type=str, required=True)
|
||||
parser.add_argument('--type', type=str, required=True,
|
||||
choices=[i.name.lower() for i in NodeType])
|
||||
|
||||
config.load('mqtt_util', parser=parser)
|
||||
arg = parser.parse_args()
|
||||
|
||||
mqtt_node_type = NodeType(arg.type)
|
||||
devices = MqttEspDevice(id=arg.device_id)
|
||||
|
||||
if mqtt_node_type == NodeType.RELAY:
|
||||
mqtt_client = MqttRelay(devices=devices)
|
||||
elif mqtt_node_type == NodeType.TEMPHUM:
|
||||
mqtt_client = MqttTempHum(devices=devices)
|
||||
|
||||
mqtt_client.set_message_callback(lambda device_id, payload: print(payload))
|
||||
mqtt_client.configure_tls()
|
||||
try:
|
||||
mqtt_client.connect_and_loop()
|
||||
except KeyboardInterrupt:
|
||||
mqtt_client.disconnect()
|
@ -12,6 +12,7 @@ __map__ = {
|
||||
|
||||
__all__ = list(itertools.chain(*__map__.values()))
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
if name in __all__:
|
||||
for file, names in __map__.items():
|
||||
|
@ -1,4 +1,8 @@
|
||||
from .mqtt import MqttBase
|
||||
from .util import poll_tick
|
||||
from .relay import MqttRelay, MqttRelayState
|
||||
from .temphum import MqttTempHum
|
||||
from .mqtt import MqttBase, MqttPayload, MqttPayloadCustomField
|
||||
from ._node import MqttNode
|
||||
from ._module import MqttModule
|
||||
from .util import (
|
||||
poll_tick,
|
||||
get_modules as get_mqtt_modules,
|
||||
import_module as import_mqtt_module
|
||||
)
|
33
src/home/mqtt/_module.py
Normal file
33
src/home/mqtt/_module.py
Normal file
@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import logging
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from ._node import MqttNode
|
||||
|
||||
|
||||
class MqttModule(abc.ABC):
|
||||
tick_interval: int
|
||||
_initialized: bool
|
||||
|
||||
def __init__(self, tick_interval=0):
|
||||
self.tick_interval = tick_interval
|
||||
self._initialized = False
|
||||
self._logger = logging.getLogger(self.__class__.__name__)
|
||||
|
||||
def init(self, mqtt: MqttNode):
|
||||
pass
|
||||
|
||||
def is_initialized(self):
|
||||
return self._initialized
|
||||
|
||||
def set_initialized(self):
|
||||
self._initialized = True
|
||||
|
||||
def tick(self, mqtt: MqttNode):
|
||||
pass
|
||||
|
||||
def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes):
|
||||
pass
|
87
src/home/mqtt/_node.py
Normal file
87
src/home/mqtt/_node.py
Normal file
@ -0,0 +1,87 @@
|
||||
import paho.mqtt.client as mqtt
|
||||
|
||||
from .mqtt import MqttBase
|
||||
from typing import List
|
||||
from ._module import MqttModule
|
||||
|
||||
|
||||
class MqttNode(MqttBase):
|
||||
_modules: List[MqttModule]
|
||||
_module_subscriptions: dict[str, MqttModule]
|
||||
_node_id: str
|
||||
# _devices: list[MqttEspDevice]
|
||||
# _message_callback: Optional[callable]
|
||||
# _ota_publish_callback: Optional[callable]
|
||||
|
||||
def __init__(self,
|
||||
node_id: str,
|
||||
# devices: Union[MqttEspDevice, list[MqttEspDevice]],
|
||||
subscribe_to_updates=True):
|
||||
super().__init__(clean_session=True)
|
||||
self._modules = []
|
||||
self._module_subscriptions = {}
|
||||
self._node_id = node_id
|
||||
# if not isinstance(devices, list):
|
||||
# devices = [devices]
|
||||
# self._devices = devices
|
||||
# self._message_callback = None
|
||||
# self._ota_publish_callback = None
|
||||
# self._subscribe_to_updates = subscribe_to_updates
|
||||
# self._ota_mid = None
|
||||
|
||||
def on_connect(self, client: mqtt.Client, userdata, flags, rc):
|
||||
super().on_connect(client, userdata, flags, rc)
|
||||
for module in self._modules:
|
||||
if not module.is_initialized():
|
||||
module.init(self)
|
||||
module.set_initialized()
|
||||
|
||||
def on_publish(self, client: mqtt.Client, userdata, mid):
|
||||
pass # FIXME
|
||||
# if self._ota_mid is not None and mid == self._ota_mid and self._ota_publish_callback:
|
||||
# self._ota_publish_callback()
|
||||
|
||||
def on_message(self, client: mqtt.Client, userdata, msg):
|
||||
try:
|
||||
topic = msg.topic
|
||||
actual_topic = topic[len(f'hk/{self._node_id}/'):]
|
||||
|
||||
if actual_topic in self._module_subscriptions:
|
||||
self._module_subscriptions[actual_topic].handle_payload(self, actual_topic, msg.payload)
|
||||
|
||||
except Exception as e:
|
||||
self._logger.exception(str(e))
|
||||
|
||||
# def push_ota(self,
|
||||
# device_id,
|
||||
# filename: str,
|
||||
# publish_callback: callable,
|
||||
# qos: int):
|
||||
# device = next(d for d in self._devices if d.id == device_id)
|
||||
# assert device.secret is not None, 'device secret not specified'
|
||||
#
|
||||
# self._ota_publish_callback = publish_callback
|
||||
# payload = OtaPayload(secret=device.secret, filename=filename)
|
||||
# publish_result = self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/admin/ota',
|
||||
# payload=payload.pack(),
|
||||
# qos=qos)
|
||||
# self._ota_mid = publish_result.mid
|
||||
# self._client.loop_write()
|
||||
#
|
||||
# @classmethod
|
||||
# def get_mqtt_topics(cls, additional_topics: Optional[list[str]] = None):
|
||||
# return rf'^hk/(.*?)/{cls.TOPIC_LEAF}/(stat|stat1|otares'+('|'+('|'.join(additional_topics)) if additional_topics else '')+')$'
|
||||
|
||||
def add_module(self, module: MqttModule):
|
||||
self._modules.append(module)
|
||||
if self._connected:
|
||||
module.init(self)
|
||||
module.set_initialized()
|
||||
|
||||
def subscribe_module(self, topic: str, module: MqttModule, qos: int = 1):
|
||||
self._module_subscriptions[topic] = module
|
||||
self._client.subscribe(f'hk/{self._node_id}/{topic}', qos)
|
||||
|
||||
def publish(self, topic: str, payload: bytes, qos: int = 1):
|
||||
self._client.publish(f'hk/{self._node_id}/{topic}', payload, qos)
|
||||
self._client.loop_write()
|
@ -1,5 +1,5 @@
|
||||
import abc
|
||||
import struct
|
||||
import abc
|
||||
import re
|
||||
|
||||
from typing import Optional, Tuple
|
||||
@ -142,4 +142,4 @@ def _bit_field_params(cl) -> Optional[Tuple[int, ...]]:
|
||||
match = re.match(r'MQTTPayloadBitField_(\d+)_(\d+)_(\d)$', cl.__name__)
|
||||
if match is not None:
|
||||
return tuple([int(match.group(i)) for i in range(1, 4)])
|
||||
return None
|
||||
return None
|
@ -1,106 +0,0 @@
|
||||
import re
|
||||
import paho.mqtt.client as mqtt
|
||||
|
||||
from .mqtt import MqttBase
|
||||
from typing import Optional, Union
|
||||
from .payload.esp import (
|
||||
OTAPayload,
|
||||
OTAResultPayload,
|
||||
DiagnosticsPayload,
|
||||
InitialDiagnosticsPayload
|
||||
)
|
||||
|
||||
|
||||
class MqttEspDevice:
|
||||
id: str
|
||||
secret: Optional[str]
|
||||
|
||||
def __init__(self, id: str, secret: Optional[str] = None):
|
||||
self.id = id
|
||||
self.secret = secret
|
||||
|
||||
|
||||
class MqttEspBase(MqttBase):
|
||||
_devices: list[MqttEspDevice]
|
||||
_message_callback: Optional[callable]
|
||||
_ota_publish_callback: Optional[callable]
|
||||
|
||||
TOPIC_LEAF = 'esp'
|
||||
|
||||
def __init__(self,
|
||||
devices: Union[MqttEspDevice, list[MqttEspDevice]],
|
||||
subscribe_to_updates=True):
|
||||
super().__init__(clean_session=True)
|
||||
if not isinstance(devices, list):
|
||||
devices = [devices]
|
||||
self._devices = devices
|
||||
self._message_callback = None
|
||||
self._ota_publish_callback = None
|
||||
self._subscribe_to_updates = subscribe_to_updates
|
||||
self._ota_mid = None
|
||||
|
||||
def on_connect(self, client: mqtt.Client, userdata, flags, rc):
|
||||
super().on_connect(client, userdata, flags, rc)
|
||||
|
||||
if self._subscribe_to_updates:
|
||||
for device in self._devices:
|
||||
topic = f'hk/{device.id}/{self.TOPIC_LEAF}/#'
|
||||
self._logger.debug(f"subscribing to {topic}")
|
||||
client.subscribe(topic, qos=1)
|
||||
|
||||
def on_publish(self, client: mqtt.Client, userdata, mid):
|
||||
if self._ota_mid is not None and mid == self._ota_mid and self._ota_publish_callback:
|
||||
self._ota_publish_callback()
|
||||
|
||||
def set_message_callback(self, callback: callable):
|
||||
self._message_callback = callback
|
||||
|
||||
def on_message(self, client: mqtt.Client, userdata, msg):
|
||||
try:
|
||||
match = re.match(self.get_mqtt_topics(), msg.topic)
|
||||
self._logger.debug(f'topic: {msg.topic}')
|
||||
if not match:
|
||||
return
|
||||
|
||||
device_id = match.group(1)
|
||||
subtopic = match.group(2)
|
||||
|
||||
# try:
|
||||
next(d for d in self._devices if d.id == device_id)
|
||||
# except StopIteration:h
|
||||
# return
|
||||
|
||||
message = None
|
||||
if subtopic == 'stat':
|
||||
message = DiagnosticsPayload.unpack(msg.payload)
|
||||
elif subtopic == 'stat1':
|
||||
message = InitialDiagnosticsPayload.unpack(msg.payload)
|
||||
elif subtopic == 'otares':
|
||||
message = OTAResultPayload.unpack(msg.payload)
|
||||
|
||||
if message and self._message_callback:
|
||||
self._message_callback(device_id, message)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self._logger.exception(str(e))
|
||||
|
||||
def push_ota(self,
|
||||
device_id,
|
||||
filename: str,
|
||||
publish_callback: callable,
|
||||
qos: int):
|
||||
device = next(d for d in self._devices if d.id == device_id)
|
||||
assert device.secret is not None, 'device secret not specified'
|
||||
|
||||
self._ota_publish_callback = publish_callback
|
||||
payload = OTAPayload(secret=device.secret, filename=filename)
|
||||
publish_result = self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/admin/ota',
|
||||
payload=payload.pack(),
|
||||
qos=qos)
|
||||
self._ota_mid = publish_result.mid
|
||||
self._client.loop_write()
|
||||
|
||||
@classmethod
|
||||
def get_mqtt_topics(cls, additional_topics: Optional[list[str]] = None):
|
||||
return rf'^hk/(.*?)/{cls.TOPIC_LEAF}/(stat|stat1|otares'+('|'+('|'.join(additional_topics)) if additional_topics else '')+')$'
|
@ -1,39 +1,7 @@
|
||||
import hashlib
|
||||
from ..mqtt import MqttPayload, MqttPayloadCustomField
|
||||
from .._node import MqttNode, MqttModule
|
||||
|
||||
from .base_payload import MqttPayload, MqttPayloadCustomField
|
||||
|
||||
|
||||
class OTAResultPayload(MqttPayload):
|
||||
FORMAT = '=BB'
|
||||
result: int
|
||||
error_code: int
|
||||
|
||||
|
||||
class OTAPayload(MqttPayload):
|
||||
secret: str
|
||||
filename: str
|
||||
|
||||
# structure of returned data:
|
||||
#
|
||||
# uint8_t[len(secret)] secret;
|
||||
# uint8_t[16] md5;
|
||||
# *uint8_t data
|
||||
|
||||
def pack(self):
|
||||
buf = bytearray(self.secret.encode())
|
||||
m = hashlib.md5()
|
||||
with open(self.filename, 'rb') as fd:
|
||||
content = fd.read()
|
||||
m.update(content)
|
||||
buf.extend(m.digest())
|
||||
buf.extend(content)
|
||||
return buf
|
||||
|
||||
def unpack(cls, buf: bytes):
|
||||
raise RuntimeError(f'{cls.__class__.__name__}.unpack: not implemented')
|
||||
# secret = buf[:12].decode()
|
||||
# filename = buf[12:].decode()
|
||||
# return OTAPayload(secret=secret, filename=filename)
|
||||
MODULE_NAME = 'MqttDiagnosticsModule'
|
||||
|
||||
|
||||
class DiagnosticsFlags(MqttPayloadCustomField):
|
||||
@ -76,3 +44,16 @@ class DiagnosticsPayload(MqttPayload):
|
||||
rssi: int
|
||||
free_heap: int
|
||||
flags: DiagnosticsFlags
|
||||
|
||||
|
||||
class MqttDiagnosticsModule(MqttModule):
|
||||
def init(self, mqtt: MqttNode):
|
||||
for topic in ('diag', 'd1ag', 'stat', 'stat1'):
|
||||
mqtt.subscribe_module(topic, self)
|
||||
|
||||
def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes):
|
||||
if topic in ('stat', 'diag'):
|
||||
message = DiagnosticsPayload.unpack(payload)
|
||||
elif topic in ('stat1', 'd1ag'):
|
||||
message = InitialDiagnosticsPayload.unpack(payload)
|
||||
self._logger.debug(message)
|
@ -1,7 +1,7 @@
|
||||
import struct
|
||||
|
||||
from .base_payload import MqttPayload, bit_field
|
||||
from typing import Tuple
|
||||
from .._node import MqttNode
|
||||
from .._payload import MqttPayload, bit_field
|
||||
|
||||
_mult_10 = lambda n: int(n*10)
|
||||
_div_10 = lambda n: n/10
|
||||
@ -71,3 +71,7 @@ class Generation(MqttPayload):
|
||||
|
||||
time: int
|
||||
wh: int
|
||||
|
||||
|
||||
class MqttInverterModule(MqttNode):
|
||||
pass
|
65
src/home/mqtt/module/ota.py
Normal file
65
src/home/mqtt/module/ota.py
Normal file
@ -0,0 +1,65 @@
|
||||
import hashlib
|
||||
|
||||
from ..mqtt import MqttPayload
|
||||
from .._node import MqttModule, MqttNode
|
||||
|
||||
MODULE_NAME = 'MqttOtaModule'
|
||||
|
||||
|
||||
class OtaResultPayload(MqttPayload):
|
||||
FORMAT = '=BB'
|
||||
result: int
|
||||
error_code: int
|
||||
|
||||
|
||||
class OtaPayload(MqttPayload):
|
||||
secret: str
|
||||
filename: str
|
||||
|
||||
# structure of returned data:
|
||||
#
|
||||
# uint8_t[len(secret)] secret;
|
||||
# uint8_t[16] md5;
|
||||
# *uint8_t data
|
||||
|
||||
def pack(self):
|
||||
buf = bytearray(self.secret.encode())
|
||||
m = hashlib.md5()
|
||||
with open(self.filename, 'rb') as fd:
|
||||
content = fd.read()
|
||||
m.update(content)
|
||||
buf.extend(m.digest())
|
||||
buf.extend(content)
|
||||
return buf
|
||||
|
||||
def unpack(cls, buf: bytes):
|
||||
raise RuntimeError(f'{cls.__class__.__name__}.unpack: not implemented')
|
||||
# secret = buf[:12].decode()
|
||||
# filename = buf[12:].decode()
|
||||
# return OTAPayload(secret=secret, filename=filename)
|
||||
|
||||
|
||||
class MqttOtaModule(MqttModule):
|
||||
def init(self, mqtt: MqttNode):
|
||||
mqtt.subscribe_module("otares", self)
|
||||
|
||||
def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes):
|
||||
if topic == 'otares':
|
||||
message = OtaResultPayload.unpack(payload)
|
||||
self._logger.debug(message)
|
||||
|
||||
# def push_ota(self,
|
||||
# node_id,
|
||||
# filename: str,
|
||||
# publish_callback: callable,
|
||||
# qos: int):
|
||||
# device = next(d for d in self._devices if d.id == device_id)
|
||||
# assert device.secret is not None, 'device secret not specified'
|
||||
#
|
||||
# self._ota_publish_callback = publish_callback
|
||||
# payload = OtaPayload(secret=device.secret, filename=filename)
|
||||
# publish_result = self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/admin/ota',
|
||||
# payload=payload.pack(),
|
||||
# qos=qos)
|
||||
# self._ota_mid = publish_result.mid
|
||||
# self._client.loop_write()
|
65
src/home/mqtt/module/relay.py
Normal file
65
src/home/mqtt/module/relay.py
Normal file
@ -0,0 +1,65 @@
|
||||
import paho.mqtt.client as mqtt
|
||||
import re
|
||||
import datetime
|
||||
|
||||
from .. import MqttModule, MqttPayload, MqttNode
|
||||
|
||||
MODULE_NAME = 'MqttRelayModule'
|
||||
|
||||
|
||||
class MqttPowerSwitchPayload(MqttPayload):
|
||||
FORMAT = '=12sB'
|
||||
PACKER = {
|
||||
'state': lambda n: int(n),
|
||||
'secret': lambda s: s.encode('utf-8')
|
||||
}
|
||||
UNPACKER = {
|
||||
'state': lambda n: bool(n),
|
||||
'secret': lambda s: s.decode('utf-8')
|
||||
}
|
||||
|
||||
secret: str
|
||||
state: bool
|
||||
|
||||
|
||||
class MqttRelayState:
|
||||
enabled: bool
|
||||
update_time: datetime.datetime
|
||||
rssi: int
|
||||
fw_version: int
|
||||
ever_updated: bool
|
||||
|
||||
def __init__(self):
|
||||
self.ever_updated = False
|
||||
self.enabled = False
|
||||
self.rssi = 0
|
||||
|
||||
def update(self,
|
||||
enabled: bool,
|
||||
rssi: int,
|
||||
fw_version=None):
|
||||
self.ever_updated = True
|
||||
self.enabled = enabled
|
||||
self.rssi = rssi
|
||||
self.update_time = datetime.datetime.now()
|
||||
if fw_version:
|
||||
self.fw_version = fw_version
|
||||
|
||||
|
||||
class MqttRelayModule(MqttModule):
|
||||
def init(self, mqtt: MqttNode):
|
||||
mqtt.subscribe_module('relay/switch', self)
|
||||
|
||||
@staticmethod
|
||||
def switchpower(mqtt: MqttNode,
|
||||
enable: bool,
|
||||
secret: str):
|
||||
payload = MqttPowerSwitchPayload(secret=secret, state=enable)
|
||||
mqtt.publish('relay/switch', payload=payload.pack())
|
||||
|
||||
def handle_payload(self, mqtt: MqttNode, topic: str, payload: bytes):
|
||||
if topic != 'relay/switch':
|
||||
return
|
||||
|
||||
message = MqttPowerSwitchPayload.unpack(payload)
|
||||
self._logger.debug(message)
|
55
src/home/mqtt/module/temphum.py
Normal file
55
src/home/mqtt/module/temphum.py
Normal file
@ -0,0 +1,55 @@
|
||||
from enum import auto
|
||||
from .._node import MqttNode
|
||||
from .._module import MqttModule
|
||||
from .._payload import MqttPayload
|
||||
from ...util import HashableEnum
|
||||
|
||||
two_digits_precision = lambda x: round(x, 2)
|
||||
|
||||
MODULE_NAME = 'MqttTempHumModule'
|
||||
|
||||
|
||||
class TempHumDataPayload(MqttPayload):
|
||||
FORMAT = '=ddb'
|
||||
UNPACKER = {
|
||||
'temp': two_digits_precision,
|
||||
'rh': two_digits_precision
|
||||
}
|
||||
|
||||
temp: float
|
||||
rh: float
|
||||
error: int
|
||||
|
||||
|
||||
class MqttTempHumNodes(HashableEnum):
|
||||
KBN_SH_HALL = auto()
|
||||
KBN_SH_BATHROOM = auto()
|
||||
KBN_SH_LIVINGROOM = auto()
|
||||
KBN_SH_BEDROOM = auto()
|
||||
|
||||
KBN_BH_2FL = auto()
|
||||
KBN_BH_2FL_STREET = auto()
|
||||
KBN_BH_1FL_LIVINGROOM = auto()
|
||||
KBN_BH_1FL_BEDROOM = auto()
|
||||
KBN_BH_1FL_BATHROOM = auto()
|
||||
|
||||
KBN_NH_1FL_INV = auto()
|
||||
KBN_NH_1FL_CENTER = auto()
|
||||
KBN_NH_1LF_KT = auto()
|
||||
KBN_NH_1FL_DS = auto()
|
||||
KBN_NH_1FS_EZ = auto()
|
||||
|
||||
SPB_FLAT120_CABINET = auto()
|
||||
|
||||
|
||||
class MqttTempHumModule(MqttModule):
|
||||
def init(self, mqtt: MqttNode):
|
||||
mqtt.subscribe_module('temphum/data', self)
|
||||
|
||||
def handle_payload(self,
|
||||
mqtt: MqttNode,
|
||||
topic: str,
|
||||
payload: bytes):
|
||||
if topic == 'temphum/data':
|
||||
message = TempHumDataPayload.unpack(payload)
|
||||
self._logger.debug(message)
|
@ -3,8 +3,8 @@ import paho.mqtt.client as mqtt
|
||||
import ssl
|
||||
import logging
|
||||
|
||||
from typing import Tuple
|
||||
from ..config import config
|
||||
from ._payload import *
|
||||
|
||||
|
||||
def username_and_password() -> Tuple[str, str]:
|
||||
@ -14,6 +14,8 @@ def username_and_password() -> Tuple[str, str]:
|
||||
|
||||
|
||||
class MqttBase:
|
||||
_connected: bool
|
||||
|
||||
def __init__(self, clean_session=True):
|
||||
self._client = mqtt.Client(client_id=config['mqtt']['client_id'],
|
||||
protocol=mqtt.MQTTv311,
|
||||
@ -24,6 +26,7 @@ class MqttBase:
|
||||
self._client.on_log = self.on_log
|
||||
self._client.on_publish = self.on_publish
|
||||
self._loop_started = False
|
||||
self._connected = False
|
||||
|
||||
self._logger = logging.getLogger(self.__class__.__name__)
|
||||
|
||||
@ -41,7 +44,9 @@ class MqttBase:
|
||||
'assets',
|
||||
'mqtt_ca.crt'
|
||||
))
|
||||
self._client.tls_set(ca_certs=ca_certs, cert_reqs=ssl.CERT_REQUIRED, tls_version=ssl.PROTOCOL_TLSv1_2)
|
||||
self._client.tls_set(ca_certs=ca_certs,
|
||||
cert_reqs=ssl.CERT_REQUIRED,
|
||||
tls_version=ssl.PROTOCOL_TLSv1_2)
|
||||
|
||||
def connect_and_loop(self, loop_forever=True):
|
||||
host = config['mqtt']['host']
|
||||
@ -61,9 +66,11 @@ class MqttBase:
|
||||
|
||||
def on_connect(self, client: mqtt.Client, userdata, flags, rc):
|
||||
self._logger.info("Connected with result code " + str(rc))
|
||||
self._connected = True
|
||||
|
||||
def on_disconnect(self, client: mqtt.Client, userdata, rc):
|
||||
self._logger.info("Disconnected with result code " + str(rc))
|
||||
self._connected = False
|
||||
|
||||
def on_log(self, client: mqtt.Client, userdata, level, buf):
|
||||
level = mqtt.LOGGING_LEVEL[level] if level in mqtt.LOGGING_LEVEL else logging.INFO
|
||||
@ -73,4 +80,15 @@ class MqttBase:
|
||||
self._logger.debug(msg.topic + ": " + str(msg.payload))
|
||||
|
||||
def on_publish(self, client: mqtt.Client, userdata, mid):
|
||||
self._logger.debug(f'publish done, mid={mid}')
|
||||
self._logger.debug(f'publish done, mid={mid}')
|
||||
|
||||
|
||||
class MqttEspDevice:
|
||||
id: str
|
||||
secret: Optional[str]
|
||||
|
||||
def __init__(self,
|
||||
node_id: str,
|
||||
secret: Optional[str] = None):
|
||||
self.id = node_id
|
||||
self.secret = secret
|
||||
|
@ -1 +0,0 @@
|
||||
from .base_payload import MqttPayload
|
@ -1,22 +0,0 @@
|
||||
from .base_payload import MqttPayload
|
||||
from .esp import (
|
||||
OTAResultPayload,
|
||||
OTAPayload,
|
||||
InitialDiagnosticsPayload,
|
||||
DiagnosticsPayload
|
||||
)
|
||||
|
||||
|
||||
class PowerPayload(MqttPayload):
|
||||
FORMAT = '=12sB'
|
||||
PACKER = {
|
||||
'state': lambda n: int(n),
|
||||
'secret': lambda s: s.encode('utf-8')
|
||||
}
|
||||
UNPACKER = {
|
||||
'state': lambda n: bool(n),
|
||||
'secret': lambda s: s.decode('utf-8')
|
||||
}
|
||||
|
||||
secret: str
|
||||
state: bool
|
@ -1,20 +0,0 @@
|
||||
from .base_payload import MqttPayload
|
||||
|
||||
_mult_100 = lambda n: int(n*100)
|
||||
_div_100 = lambda n: n/100
|
||||
|
||||
|
||||
class Temperature(MqttPayload):
|
||||
FORMAT = 'IhH'
|
||||
PACKER = {
|
||||
'temp': _mult_100,
|
||||
'rh': _mult_100,
|
||||
}
|
||||
UNPACKER = {
|
||||
'temp': _div_100,
|
||||
'rh': _div_100,
|
||||
}
|
||||
|
||||
time: int
|
||||
temp: float
|
||||
rh: float
|
@ -1,15 +0,0 @@
|
||||
from .base_payload import MqttPayload
|
||||
|
||||
two_digits_precision = lambda x: round(x, 2)
|
||||
|
||||
|
||||
class TempHumDataPayload(MqttPayload):
|
||||
FORMAT = '=ddb'
|
||||
UNPACKER = {
|
||||
'temp': two_digits_precision,
|
||||
'rh': two_digits_precision
|
||||
}
|
||||
|
||||
temp: float
|
||||
rh: float
|
||||
error: int
|
@ -1,71 +0,0 @@
|
||||
import paho.mqtt.client as mqtt
|
||||
import re
|
||||
import datetime
|
||||
|
||||
from .payload.relay import (
|
||||
PowerPayload,
|
||||
)
|
||||
from .esp import MqttEspBase
|
||||
|
||||
|
||||
class MqttRelay(MqttEspBase):
|
||||
TOPIC_LEAF = 'relay'
|
||||
|
||||
def set_power(self, device_id, enable: bool, secret=None):
|
||||
device = next(d for d in self._devices if d.id == device_id)
|
||||
secret = secret if secret else device.secret
|
||||
|
||||
assert secret is not None, 'device secret not specified'
|
||||
|
||||
payload = PowerPayload(secret=secret,
|
||||
state=enable)
|
||||
self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/power',
|
||||
payload=payload.pack(),
|
||||
qos=1)
|
||||
self._client.loop_write()
|
||||
|
||||
def on_message(self, client: mqtt.Client, userdata, msg):
|
||||
if super().on_message(client, userdata, msg):
|
||||
return
|
||||
|
||||
try:
|
||||
match = re.match(self.get_mqtt_topics(['power']), msg.topic)
|
||||
if not match:
|
||||
return
|
||||
|
||||
device_id = match.group(1)
|
||||
subtopic = match.group(2)
|
||||
|
||||
message = None
|
||||
if subtopic == 'power':
|
||||
message = PowerPayload.unpack(msg.payload)
|
||||
|
||||
if message and self._message_callback:
|
||||
self._message_callback(device_id, message)
|
||||
|
||||
except Exception as e:
|
||||
self._logger.exception(str(e))
|
||||
|
||||
|
||||
class MqttRelayState:
|
||||
enabled: bool
|
||||
update_time: datetime.datetime
|
||||
rssi: int
|
||||
fw_version: int
|
||||
ever_updated: bool
|
||||
|
||||
def __init__(self):
|
||||
self.ever_updated = False
|
||||
self.enabled = False
|
||||
self.rssi = 0
|
||||
|
||||
def update(self,
|
||||
enabled: bool,
|
||||
rssi: int,
|
||||
fw_version=None):
|
||||
self.ever_updated = True
|
||||
self.enabled = enabled
|
||||
self.rssi = rssi
|
||||
self.update_time = datetime.datetime.now()
|
||||
if fw_version:
|
||||
self.fw_version = fw_version
|
@ -1,54 +0,0 @@
|
||||
import paho.mqtt.client as mqtt
|
||||
import re
|
||||
|
||||
from enum import auto
|
||||
from .payload.temphum import TempHumDataPayload
|
||||
from .esp import MqttEspBase
|
||||
from ..util import HashableEnum
|
||||
|
||||
|
||||
class MqttTempHumNodes(HashableEnum):
|
||||
KBN_SH_HALL = auto()
|
||||
KBN_SH_BATHROOM = auto()
|
||||
KBN_SH_LIVINGROOM = auto()
|
||||
KBN_SH_BEDROOM = auto()
|
||||
|
||||
KBN_BH_2FL = auto()
|
||||
KBN_BH_2FL_STREET = auto()
|
||||
KBN_BH_1FL_LIVINGROOM = auto()
|
||||
KBN_BH_1FL_BEDROOM = auto()
|
||||
KBN_BH_1FL_BATHROOM = auto()
|
||||
|
||||
KBN_NH_1FL_INV = auto()
|
||||
KBN_NH_1FL_CENTER = auto()
|
||||
KBN_NH_1LF_KT = auto()
|
||||
KBN_NH_1FL_DS = auto()
|
||||
KBN_NH_1FS_EZ = auto()
|
||||
|
||||
SPB_FLAT120_CABINET = auto()
|
||||
|
||||
|
||||
class MqttTempHum(MqttEspBase):
|
||||
TOPIC_LEAF = 'temphum'
|
||||
|
||||
def on_message(self, client: mqtt.Client, userdata, msg):
|
||||
if super().on_message(client, userdata, msg):
|
||||
return
|
||||
|
||||
try:
|
||||
match = re.match(self.get_mqtt_topics(['data']), msg.topic)
|
||||
if not match:
|
||||
return
|
||||
|
||||
device_id = match.group(1)
|
||||
subtopic = match.group(2)
|
||||
|
||||
message = None
|
||||
if subtopic == 'data':
|
||||
message = TempHumDataPayload.unpack(msg.payload)
|
||||
|
||||
if message and self._message_callback:
|
||||
self._message_callback(device_id, message)
|
||||
|
||||
except Exception as e:
|
||||
self._logger.exception(str(e))
|
@ -1,4 +1,9 @@
|
||||
import time
|
||||
import os
|
||||
import re
|
||||
import importlib
|
||||
|
||||
from typing import List
|
||||
|
||||
|
||||
def poll_tick(freq):
|
||||
@ -6,3 +11,16 @@ def poll_tick(freq):
|
||||
while True:
|
||||
t += freq
|
||||
yield max(t - time.time(), 0)
|
||||
|
||||
|
||||
def get_modules() -> List[str]:
|
||||
modules = []
|
||||
for name in os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'module')):
|
||||
name = re.sub(r'\.py$', '', name)
|
||||
modules.append(name)
|
||||
return modules
|
||||
|
||||
|
||||
def import_module(module: str):
|
||||
return importlib.import_module(
|
||||
f'..module.{module}', __name__)
|
@ -16,10 +16,6 @@ _products_dir = os.path.join(
|
||||
def get_products():
|
||||
products = []
|
||||
for f in os.listdir(_products_dir):
|
||||
# temp hack
|
||||
if f.endswith('-esp01'):
|
||||
continue
|
||||
# skip the common dir
|
||||
if f in ('common',):
|
||||
continue
|
||||
|
||||
|
58
src/mqtt_node_util.py
Executable file
58
src/mqtt_node_util.py
Executable file
@ -0,0 +1,58 @@
|
||||
#!/usr/bin/env python3
|
||||
from typing import Optional
|
||||
from argparse import ArgumentParser, ArgumentError
|
||||
|
||||
from home.config import config
|
||||
from home.mqtt import MqttNode, get_mqtt_modules, import_mqtt_module, MqttModule
|
||||
|
||||
mqtt: Optional[MqttNode] = None
|
||||
|
||||
|
||||
def add_module(module: str) -> MqttModule:
|
||||
module = import_mqtt_module(module)
|
||||
if not hasattr(module, 'MODULE_NAME'):
|
||||
raise RuntimeError(f'MODULE_NAME not found in module {m}')
|
||||
cl = getattr(module, getattr(module, 'MODULE_NAME'))
|
||||
instance = cl()
|
||||
mqtt.add_module(instance)
|
||||
return instance
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('--node-id', type=str, required=True)
|
||||
parser.add_argument('--modules', type=str, choices=get_mqtt_modules(), nargs='*',
|
||||
help='mqtt modules to include')
|
||||
parser.add_argument('--switch-relay', choices=[0, 1], type=int,
|
||||
help='send relay state')
|
||||
parser.add_argument('--switch-relay-secret', type=str,
|
||||
help='secret password to switch relay')
|
||||
|
||||
config.load('mqtt_util', parser=parser)
|
||||
arg = parser.parse_args()
|
||||
|
||||
if (arg.switch_relay is not None or arg.switch_relay_secret is not None) and 'relay' not in arg.modules:
|
||||
raise ArgumentError(None, '--relay is only allowed when \'relay\' module included in --modules')
|
||||
|
||||
if (arg.switch_relay is not None and arg.switch_relay_secret is None) or (arg.switch_relay is None and arg.switch_relay_secret is not None):
|
||||
raise ArgumentError(None, 'both --switch-relay and --switch-relay-secret are required')
|
||||
|
||||
mqtt = MqttNode(node_id=arg.node_id)
|
||||
|
||||
# must-have modules
|
||||
add_module('ota')
|
||||
add_module('diagnostics')
|
||||
|
||||
if arg.modules:
|
||||
for m in arg.modules:
|
||||
module_instance = add_module(m)
|
||||
if m == 'relay' and arg.switch_relay is not None:
|
||||
module_instance.switchpower(mqtt,
|
||||
arg.switch_relay == 1,
|
||||
arg.switch_relay_secret)
|
||||
|
||||
mqtt.configure_tls()
|
||||
try:
|
||||
mqtt.connect_and_loop()
|
||||
except KeyboardInterrupt:
|
||||
mqtt.disconnect()
|
@ -8,10 +8,9 @@ from telegram import ReplyKeyboardMarkup, User
|
||||
from home.config import config
|
||||
from home.telegram import bot
|
||||
from home.telegram._botutil import user_any_name
|
||||
from home.mqtt.esp import MqttEspDevice
|
||||
from home.mqtt import MqttRelay, MqttRelayState
|
||||
from home.mqtt.payload import MqttPayload
|
||||
from home.mqtt.payload.relay import InitialDiagnosticsPayload, DiagnosticsPayload
|
||||
from home.mqtt import MqttEspDevice, MqttPayload
|
||||
from home.mqtt.module.relay import MqttRelayState
|
||||
from home.mqtt.module.diagnostics import InitialDiagnosticsPayload, DiagnosticsPayload
|
||||
|
||||
|
||||
config.load('pump_mqtt_bot')
|
||||
|
Loading…
x
Reference in New Issue
Block a user