106 lines
3.6 KiB
Python
106 lines
3.6 KiB
Python
import re
from typing import Callable, Optional, Union

import paho.mqtt.client as mqtt

from .mqtt import MqttBase
from .payload.esp import (
    OTAPayload,
    OTAResultPayload,
    DiagnosticsPayload,
    InitialDiagnosticsPayload
)
|
|
|
|
|
|
class MqttEspDevice:
    """Identity of a single ESP device reachable over MQTT.

    Attributes:
        id: device identifier; MqttEspBase uses it as the second segment of
            the device's topics (``hk/<id>/esp/...``).
        secret: optional secret; MqttEspBase.push_ota refuses to send an OTA
            command to a device whose secret is None.
    """

    id: str
    secret: Optional[str]

    def __init__(self, id: str, secret: Optional[str] = None):
        # Plain value holder: store both fields exactly as given.
        self.id, self.secret = id, secret
|
|
|
|
|
|
class MqttEspBase(MqttBase):
    """MQTT client for a set of ESP devices publishing under ``hk/<id>/esp/``.

    Decodes diagnostics and OTA-result messages from the configured devices
    and can push OTA commands to them.
    """

    _devices: list[MqttEspDevice]
    _message_callback: Optional[Callable]
    _ota_publish_callback: Optional[Callable]
    _ota_mid: Optional[int]
    _subscribe_to_updates: bool

    TOPIC_LEAF = 'esp'

    # Subtopic (last topic segment) -> payload class whose ``unpack`` decodes it.
    # Key order defines the order of the alternation in get_mqtt_topics().
    _PAYLOAD_TYPES = {
        'stat': DiagnosticsPayload,
        'stat1': InitialDiagnosticsPayload,
        'otares': OTAResultPayload,
    }

    def __init__(self,
                 devices: Union[MqttEspDevice, list[MqttEspDevice]],
                 subscribe_to_updates=True):
        """
        Args:
            devices: one device or a list of devices to talk to. A list is
                stored as-is (not copied).
            subscribe_to_updates: when True, subscribe to every device's
                ``hk/<id>/esp/#`` topics on connect.
        """
        super().__init__(clean_session=True)
        if not isinstance(devices, list):
            devices = [devices]
        self._devices = devices
        self._message_callback = None
        self._ota_publish_callback = None
        self._subscribe_to_updates = subscribe_to_updates
        # mid of the pending OTA publish; None when no OTA publish is pending.
        self._ota_mid = None

    def on_connect(self, client: mqtt.Client, userdata, flags, rc):
        """Run base connect handling, then subscribe to device topics if enabled."""
        super().on_connect(client, userdata, flags, rc)

        if self._subscribe_to_updates:
            for device in self._devices:
                topic = f'hk/{device.id}/{self.TOPIC_LEAF}/#'
                self._logger.debug(f"subscribing to {topic}")
                client.subscribe(topic, qos=1)

    def on_publish(self, client: mqtt.Client, userdata, mid):
        """Fire the pending OTA publish callback when the OTA message is published."""
        self._fire_ota_callback(mid)

    def set_message_callback(self, callback: Callable):
        """Register ``callback(device_id, payload)`` for decoded device messages."""
        self._message_callback = callback

    def on_message(self, client: mqtt.Client, userdata, msg):
        """Decode an incoming ESP message and pass it to the message callback.

        Returns:
            True when a decoded payload was delivered to the message callback.
            None otherwise: the topic does not match, the device is not
            configured, the subtopic has no decoder here, no callback is
            registered, or decoding raised. Errors are logged, not raised.
        """
        try:
            match = re.match(self.get_mqtt_topics(), msg.topic)
            self._logger.debug(f'topic: {msg.topic}')
            if not match:
                return

            device_id, subtopic = match.group(1), match.group(2)

            # Messages for devices this instance was not configured with are
            # ignored explicitly. Previously a bare next() raised
            # StopIteration here, which was logged as an exception.
            if self._find_device(device_id) is None:
                self._logger.debug(f'ignoring message from unknown device {device_id}')
                return

            payload_type = self._PAYLOAD_TYPES.get(subtopic)
            if payload_type is None:
                # Matched one of the extra subtopics a subclass may add through
                # get_mqtt_topics(additional_topics=...); not decoded here.
                return

            message = payload_type.unpack(msg.payload)
            if message is not None and self._message_callback:
                self._message_callback(device_id, message)
                return True

        except Exception as e:
            self._logger.exception(str(e))

    def push_ota(self,
                 device_id,
                 filename: str,
                 publish_callback: Callable,
                 qos: int):
        """Publish an OTA command to ``hk/<id>/esp/admin/ota``.

        Args:
            device_id: id of one of the configured devices.
            filename: passed to the OTAPayload sent to the device.
            publish_callback: called with no arguments, at most once, when the
                OTA message has been published.
            qos: MQTT QoS for the OTA message.

        Raises:
            ValueError: the device is not configured or has no secret.
        """
        device = self._find_device(device_id)
        if device is None:
            raise ValueError(f'unknown device: {device_id}')
        if device.secret is None:
            # Explicit check rather than `assert`, which is stripped under -O.
            raise ValueError('device secret not specified')

        payload = OTAPayload(secret=device.secret, filename=filename)

        self._ota_publish_callback = publish_callback
        # Drop any stale mid so an unrelated publish cannot trigger this callback.
        self._ota_mid = None
        info = self._client.publish(f'hk/{device.id}/{self.TOPIC_LEAF}/admin/ota',
                                    payload=payload.pack(),
                                    qos=qos)
        self._ota_mid = info.mid
        self._client.loop_write()

        # paho may complete a QoS 0 publish, and call on_publish, inside
        # publish() itself. That happens before _ota_mid is known, so
        # on_publish ignores it. Catch that case here. The rc check avoids
        # is_published() raising for a failed publish.
        # _fire_ota_callback is fire-once, so on_publish can't double-fire.
        if info.rc == mqtt.MQTT_ERR_SUCCESS and info.is_published():
            self._fire_ota_callback(info.mid)

    def _find_device(self, device_id) -> Optional[MqttEspDevice]:
        """Return the configured device with the given id, or None."""
        return next((d for d in self._devices if d.id == device_id), None)

    def _fire_ota_callback(self, mid) -> None:
        """Invoke and clear the pending OTA callback if *mid* is the OTA publish."""
        callback = self._ota_publish_callback
        if self._ota_mid is None or mid != self._ota_mid or not callback:
            return
        # Clear before calling so a wrapped-around mid cannot re-trigger it.
        # The same applies to a second notification for the same publish.
        self._ota_mid = None
        self._ota_publish_callback = None
        callback()

    @classmethod
    def get_mqtt_topics(cls, additional_topics: Optional[list[str]] = None):
        """Return the regex matching topics this client handles.

        Group 1 captures the device id and group 2 the subtopic. Entries of
        *additional_topics* are appended to the subtopic alternation verbatim,
        so they are treated as regex fragments (they are not escaped).
        """
        subtopics = list(cls._PAYLOAD_TYPES)  # 'stat', 'stat1', 'otares'
        if additional_topics:
            subtopics.extend(additional_topics)
        alternation = '|'.join(subtopics)
        return rf'^hk/(.*?)/{cls.TOPIC_LEAF}/({alternation})$'