port relay_mqtt_http_proxy to new config scheme; config: support addr types & normalization

This commit is contained in:
Evgeny Zinoviev 2023-06-10 21:54:56 +03:00
parent f29e139cbb
commit 327a529835
11 changed files with 158 additions and 80 deletions

View File

@ -5,8 +5,8 @@ from typing import Optional
class ServicesListConfig(ConfigUnit):
NAME = 'services_list'
@staticmethod
def schema() -> Optional[dict]:
@classmethod
def schema(cls) -> Optional[dict]:
return {
'type': 'list',
'empty': False,
@ -19,8 +19,8 @@ class ServicesListConfig(ConfigUnit):
class LinuxBoardsConfig(ConfigUnit):
NAME = 'linux_boards'
@staticmethod
def schema() -> Optional[dict]:
@classmethod
def schema(cls) -> Optional[dict]:
return {
'type': 'dict',
'schema': {

View File

@ -1,10 +1,10 @@
import yaml
import logging
import os
import pprint
import cerberus
import cerberus.errors
from abc import ABC
from cerberus import Validator, DocumentError
from typing import Optional, Any, MutableMapping, Union
from argparse import ArgumentParser
from enum import Enum, auto
@ -12,11 +12,20 @@ from os.path import join, isdir, isfile
from ..util import Addr
class MyValidator(cerberus.Validator):
def _normalize_coerce_addr(self, value):
return Addr.fromstring(value)
MyValidator.types_mapping['addr'] = cerberus.TypeDefinition('Addr', (Addr,), ())
CONFIG_DIRECTORIES = (
join(os.environ['HOME'], '.config', 'homekit'),
'/etc/homekit'
)
class RootSchemaType(Enum):
DEFAULT = auto()
DICT = auto()
@ -95,10 +104,19 @@ class ConfigUnit(BaseConfigUnit):
raise IOError(f'\'{name}.yaml\' not found')
@staticmethod
def schema() -> Optional[dict]:
@classmethod
def schema(cls) -> Optional[dict]:
return None
@classmethod
def _addr_schema(cls, required=False, **kwargs):
return {
'type': 'addr',
'coerce': Addr.fromstring,
'required': required,
**kwargs
}
def validate(self):
schema = self.schema()
if not schema:
@ -109,7 +127,7 @@ class ConfigUnit(BaseConfigUnit):
schema['logging'] = {
'type': 'dict',
'schema': {
'logging': {'type': 'bool'}
'logging': {'type': 'boolean'}
}
}
@ -125,27 +143,27 @@ class ConfigUnit(BaseConfigUnit):
except KeyError:
pass
v = MyValidator()
if rst == RootSchemaType.DICT:
v = Validator({'document': {
normalized = v.validated({'document': self._data},
{'document': {
'type': 'dict',
'keysrules': {'type': 'string'},
'valuesrules': schema
}})
result = v.validate({'document': self._data})
}})['document']
elif rst == RootSchemaType.LIST:
v = Validator({'document': schema})
result = v.validate({'document': self._data})
v = MyValidator()
normalized = v.validated({'document': self._data}, {'document': schema})['document']
else:
v = Validator(schema)
result = v.validate(self._data)
# pprint.pprint(self._data)
if not result:
# pprint.pprint(v.errors)
raise DocumentError(f'{self.__class__.__name__}: failed to validate data:\n{pprint.pformat(v.errors)}')
normalized = v.validated(self._data, schema)
self._data = normalized
try:
self.custom_validator(self._data)
except Exception as e:
raise DocumentError(f'{self.__class__.__name__}: {str(e)}')
raise cerberus.DocumentError(f'{self.__class__.__name__}: {str(e)}')
@staticmethod
def custom_validator(data):
@ -238,7 +256,7 @@ class Config:
no_config=False):
global app_config
if issubclass(name, AppConfigUnit) or name == AppConfigUnit:
if not isinstance(name, str) and not isinstance(name, bool) and issubclass(name, AppConfigUnit) or name == AppConfigUnit:
self.app_name = name.NAME
self.app_config = name()
app_config = self.app_config
@ -278,6 +296,7 @@ class Config:
if not no_config:
self.app_config.load_from(path)
self.app_config.validate()
setup_logging(self.app_config.logging_is_verbose(),
self.app_config.logging_get_file(),

View File

@ -5,8 +5,8 @@ from typing import Optional
class InverterdConfig(ConfigUnit):
NAME = 'inverterd'
@staticmethod
def schema() -> Optional[dict]:
@classmethod
def schema(cls) -> Optional[dict]:
return {
'remote_addr': {'type': 'string'},
'local_addr': {'type': 'string'},

View File

@ -9,8 +9,8 @@ MqttCreds = namedtuple('MqttCreds', 'username, password')
class MqttConfig(ConfigUnit):
NAME = 'mqtt'
@staticmethod
def schema() -> Optional[dict]:
@classmethod
def schema(cls) -> Optional[dict]:
addr_schema = {
'type': 'dict',
'required': True,
@ -64,8 +64,8 @@ class MqttConfig(ConfigUnit):
class MqttNodesConfig(ConfigUnit):
NAME = 'mqtt_nodes'
@staticmethod
def schema() -> Optional[dict]:
@classmethod
def schema(cls) -> Optional[dict]:
return {
'common': {
'type': 'dict',

View File

@ -2,7 +2,6 @@ import paho.mqtt.client as mqtt
from ._mqtt import Mqtt
from ._node import MqttNode
from ..config import config
from ..util import strgen
@ -34,7 +33,9 @@ class MqttWrapper(Mqtt):
def on_message(self, client: mqtt.Client, userdata, msg):
try:
topic = msg.topic
topic_node = topic[len(self._topic_prefix)+1:topic.find('/', len(self._topic_prefix)+1)]
for node in self._nodes:
if node.id in ('+', topic_node):
node.on_message(topic[len(f'{self._topic_prefix}/{node.id}/'):], msg.payload)
except Exception as e:
self._logger.exception(str(e))

View File

@ -12,8 +12,8 @@ class TelegramUserListType(Enum):
class TelegramUserIdsConfig(ConfigUnit):
NAME = 'telegram_user_ids'
@staticmethod
def schema() -> Optional[dict]:
@classmethod
def schema(cls) -> Optional[dict]:
return {
'roottype': 'dict',
'type': 'integer'
@ -32,8 +32,8 @@ def _user_id_mapper(user: Union[str, int]) -> int:
class TelegramChatsConfig(ConfigUnit):
NAME = 'telegram_chats'
@staticmethod
def schema() -> Optional[dict]:
@classmethod
def schema(cls) -> Optional[dict]:
return {
'type': 'dict',
'schema': {
@ -44,8 +44,8 @@ class TelegramChatsConfig(ConfigUnit):
class TelegramBotConfig(ConfigUnit, ABC):
@staticmethod
def schema() -> Optional[dict]:
@classmethod
def schema(cls) -> Optional[dict]:
return {
'bot': {
'type': 'dict',

View File

@ -12,7 +12,7 @@ import re
from enum import Enum
from datetime import datetime
from typing import Tuple, Optional, List
from typing import Optional, List
from zlib import adler32
logger = logging.getLogger(__name__)
@ -38,26 +38,43 @@ def validate_ipv4_or_hostname(address: str, raise_exception: bool = False) -> bo
class Addr:
host: str
port: int
port: Optional[int]
def __init__(self, host: str, port: int):
def __init__(self, host: str, port: Optional[int] = None):
self.host = host
self.port = port
@staticmethod
def fromstring(addr: str) -> Addr:
if addr.count(':') != 1:
colons = addr.count(':')
if colons != 1:
raise ValueError('invalid host:port format')
if not colons:
host = addr
port= None
else:
host, port = addr.split(':')
validate_ipv4_or_hostname(host, raise_exception=True)
if port is not None:
port = int(port)
if not 0 <= port <= 65535:
raise ValueError(f'invalid port {port}')
return Addr(host, port)
def __str__(self):
buf = self.host
if self.port is not None:
buf += ':'+str(self.port)
return buf
def __iter__(self):
yield self.host
yield self.port
# https://stackoverflow.com/questions/312443/how-do-you-split-a-list-into-evenly-sized-chunks
def chunks(lst, n):

View File

@ -55,8 +55,8 @@ logger = logging.getLogger(__name__)
class InverterBotConfig(AppConfigUnit, TelegramBotConfig):
NAME = 'inverter_bot'
@staticmethod
def schema() -> Optional[dict]:
@classmethod
def schema(cls) -> Optional[dict]:
acmode_item_schema = {
'thresholds': {
'type': 'list',

View File

@ -32,8 +32,8 @@ class RelayMqttBotConfig(AppConfigUnit, TelegramBotConfig):
super().__init__()
self._strings = Translation('mqtt_nodes')
@staticmethod
def schema() -> Optional[dict]:
@classmethod
def schema(cls) -> Optional[dict]:
return {
**super(TelegramBotConfig).schema(),
'relay_nodes': {

View File

@ -1,24 +1,69 @@
#!/usr/bin/env python3
import logging
from home import http
from home.config import config
from home.mqtt import MqttPayload, MqttWrapper, MqttNode, MqttModule
from home.mqtt.module.relay import MqttRelayState, MqttRelayModule
from home.config import config, AppConfigUnit
from home.mqtt import MqttPayload, MqttWrapper, MqttNode, MqttModule, MqttNodesConfig
from home.mqtt.module.relay import MqttRelayState, MqttRelayModule, MqttPowerStatusPayload
from home.mqtt.module.diagnostics import InitialDiagnosticsPayload, DiagnosticsPayload
from typing import Optional, Union
logger = logging.getLogger(__name__)
mqtt: Optional[MqttWrapper] = None
mqtt_nodes: dict[str, MqttNode] = {}
relay_modules: dict[str, Union[MqttRelayModule, MqttModule]] = {}
relay_states: dict[str, MqttRelayState] = {}
mqtt_nodes_config = MqttNodesConfig()
class RelayMqttHttpProxyConfig(AppConfigUnit):
NAME = 'relay_mqtt_http_proxy'
@classmethod
def schema(cls) -> Optional[dict]:
return {
'relay_nodes': {
'type': 'list',
'required': True,
'schema': {
'type': 'string'
}
},
'listen_addr': cls._addr_schema(required=True)
}
@staticmethod
def custom_validator(data):
relay_node_names = mqtt_nodes_config.get_nodes(filters=('relay',), only_names=True)
for node in data['relay_nodes']:
if node not in relay_node_names:
raise ValueError(f'unknown relay node "{node}"')
def on_mqtt_message(node: MqttNode,
message: MqttPayload):
try:
is_legacy = mqtt_nodes_config[node.id]['relay']['legacy_topics']
logger.debug(f'on_mqtt_message: relay {node.id} uses legacy topic names')
except KeyError:
is_legacy = False
kwargs = {}
if isinstance(message, InitialDiagnosticsPayload) or isinstance(message, DiagnosticsPayload):
kwargs = dict(rssi=message.rssi, enabled=message.flags.state)
if device_id not in relay_states:
relay_states[device_id] = MqttRelayState()
relay_states[device_id].update(**kwargs)
kwargs['rssi'] = message.rssi
if is_legacy:
kwargs['enabled'] = message.flags.state
if not is_legacy and isinstance(message, MqttPowerStatusPayload):
kwargs['enabled'] = message.opened
if len(kwargs):
logger.debug(f'on_mqtt_message: {node.id}: going to update relay state: {str(kwargs)}')
if node.id not in relay_states:
relay_states[node.id] = MqttRelayState()
relay_states[node.id].update(**kwargs)
class RelayMqttHttpProxy(http.HTTPServer):
@ -44,7 +89,6 @@ class RelayMqttHttpProxy(http.HTTPServer):
cur_state = False
enable = not cur_state
if not node.secret:
node.secret = node_secret
relay_module.switchpower(enable)
return self.ok()
@ -60,20 +104,29 @@ class RelayMqttHttpProxy(http.HTTPServer):
if __name__ == '__main__':
config.load_app('relay_mqtt_http_proxy')
config.load_app(RelayMqttHttpProxyConfig)
mqtt = MqttWrapper()
for device_id, data in config['relays'].items():
mqtt_node = MqttNode(node_id=device_id)
relay_modules[device_id] = mqtt_node.load_module('relay')
mqtt_nodes[device_id] = mqtt_node
mqtt = MqttWrapper(client_id='relay_mqtt_http_proxy',
randomize_client_id=True)
for node_id in config.app_config['relay_nodes']:
node_data = mqtt_nodes_config.get_node(node_id)
mqtt_node = MqttNode(node_id=node_id)
module_kwargs = {}
try:
if node_data['relay']['legacy_topics']:
module_kwargs['legacy_topics'] = True
except KeyError:
pass
relay_modules[node_id] = mqtt_node.load_module('relay', **module_kwargs)
if 'legacy_topics' in module_kwargs:
mqtt_node.load_module('diagnostics')
mqtt_node.add_payload_callback(on_mqtt_message)
mqtt.add_node(mqtt_node)
mqtt_node.add_payload_callback(on_mqtt_message)
mqtt_nodes[node_id] = mqtt_node
mqtt.connect_and_loop(loop_forever=False)
proxy = RelayMqttHttpProxy(config.get_addr('server.listen'))
proxy = RelayMqttHttpProxy(config.app_config['listen_addr'])
try:
proxy.run()
except KeyboardInterrupt:

View File

@ -1,12 +0,0 @@
#!/usr/bin/env python3
from home.config import config
from home.mqtt import MqttNodesConfig
from home.telegram.config import TelegramUserIdsConfig
from pprint import pprint
if __name__ == '__main__':
config.load_app(name=False)
c = TelegramUserIdsConfig()
pprint(c.get())