1170 lines
39 KiB
Python
1170 lines
39 KiB
Python
# Polaris PWK 1725CGLD "smart" kettle python library
|
|
# --------------------------------------------------
|
|
# Copyright (C) Evgeny Zinoviev, 2022
|
|
# License: BSD-3c
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import socket
|
|
import random
|
|
import struct
|
|
import threading
|
|
import time
|
|
|
|
from abc import abstractmethod, ABC
|
|
from enum import Enum, auto
|
|
from typing import Union, Optional, Dict, Tuple, List
|
|
from ipaddress import IPv4Address, IPv6Address
|
|
|
|
import cryptography.hazmat.primitives._serialization as srlz
|
|
|
|
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
|
|
from cryptography.hazmat.primitives import ciphers, padding, hashes
|
|
from cryptography.hazmat.primitives.ciphers import algorithms, modes
|
|
|
|
# Value type of the dicts returned by CmdMessage._repr_fields(), used to build __repr__ strings.
ReprDict = Dict[str, Union[str, int, float, bool]]

_logger = logging.getLogger(__name__)

# Durations below are in seconds (they are compared against differences of time.time() values).
# A PingMessage is sent when the last sent or last received datagram is older than this.
PING_FREQUENCY = 3
# How many copies of a message are sent when it is re-sent.
RESEND_ATTEMPTS = 5
# TODO: document; its use is not evident from the connection code in this module.
ERROR_TIMEOUT = 15
MESSAGE_QUEUE_REMOVE_DELAY = 13 # after what time to delete (and pass False to handlers, if needed) messages with phase=DONE from queue
# No incoming data for this long (or no reply at all this long after the first send) triggers a reconnect.
DISCONNECT_TIMEOUT = 15
|
|
|
|
|
|
def safe_callback_call(f: callable,
                       *args,
                       logger: logging.Logger = None,
                       error_message: str = None):
    """Call ``f(*args)`` and return its result; log instead of raising if it fails.

    :param f: the callable to invoke.
    :param args: positional arguments forwarded to ``f``.
    :param logger: logger used to report a failure; defaults to this module's logger.
    :param error_message: text logged in front of the traceback.
    :return: whatever ``f`` returns, or ``None`` if it raised an ``Exception``.
    """
    try:
        return f(*args)
    except Exception as exc:
        # `logger` is optional: fall back to the module logger so that reporting the
        # failure cannot itself fail with AttributeError on None.
        log = logger if logger is not None else logging.getLogger(__name__)
        log.error(f'{error_message}, see exception below:')
        log.exception(exc)
        return None
|
|
|
|
|
|
# drop-in replacement for java.lang.System.arraycopy
def arraycopy(src, src_pos, dest, dest_pos, length):
    """Copy ``length`` items from ``src[src_pos:]`` into ``dest[dest_pos:]``, in place.

    Mirrors ``java.lang.System.arraycopy``. ``dest`` must be a mutable sequence such as a
    ``bytearray`` or ``list``. Overlapping ranges of the same buffer are copied as if the
    source range were first copied into a temporary array.

    :raises IndexError: if a position or the length is negative, or if either range runs past
                        the end of its sequence.
    """
    if src_pos < 0 or dest_pos < 0 or length < 0:
        raise IndexError(f'arraycopy: negative argument '
                         f'(src_pos={src_pos}, dest_pos={dest_pos}, length={length})')
    if src_pos + length > len(src) or dest_pos + length > len(dest):
        raise IndexError(f'arraycopy: range out of bounds (len(src)={len(src)}, len(dest)={len(dest)}, '
                         f'src_pos={src_pos}, dest_pos={dest_pos}, length={length})')
    # The bounds check above matters: slice assignment would otherwise silently resize `dest`.
    # Slicing `src` first makes a copy, which gives the Java semantics for overlapping ranges.
    dest[dest_pos:dest_pos + length] = src[src_pos:src_pos + length]
|
|
|
|
|
|
# reinterpret an unsigned byte as a signed one
def u8_to_s8(b: int) -> int:
    """Return the two's-complement signed value (-128..127) of the byte value ``b`` (0..255).

    Raises ValueError when ``b`` is outside 0..255.
    """
    single_byte = bytes((b,))
    return int.from_bytes(single_byte, byteorder='little', signed=True)
|
|
|
|
|
|
class PowerType(Enum):
    """Kettle power mode, as carried by ModeMessage."""

    OFF = 0     # turn off
    ON = 1      # turn on, target temperature is set to 100
    CUSTOM = 3  # turn on, allows a custom target temperature

    # Value 2 is not understood: setting the mode to 2 just makes it reset to 0.
|
|
|
|
|
|
# low-level protocol structures
|
|
# -----------------------------
|
|
|
|
class FrameType(Enum):
    """Frame kind, stored in the second byte of every FrameHead."""

    ACK = 0  # acknowledgement of a frame with the same seq
    CMD = 1  # command frame
    AUX = 2  # not implemented by this library
    NAK = 3  # negative acknowledgement
|
|
|
|
|
|
class FrameHead:
    """Fixed 4-byte frame header: u8 seq, u8 frame type, u16 little-endian payload length."""

    seq: Optional[int]  # u8
    type: FrameType  # u8
    length: int  # u16, length of the FrameItem payload that follows the header

    @staticmethod
    def from_bytes(buf: bytes) -> FrameHead:
        """Parse a packed header; ``buf`` must be exactly 4 bytes (otherwise struct.error)."""
        fields = struct.unpack('<BBH', buf)
        return FrameHead(fields[0], FrameType(fields[1]), fields[2])

    def __init__(self,
                 seq: Optional[int],
                 frame_type: FrameType,
                 length: Optional[int] = None):
        self.seq = seq
        self.type = frame_type
        # A missing length is stored as 0; FrameItem.setpayload() fills in the real value.
        self.length = length if length else 0

    def pack(self) -> bytes:
        """Serialize the header; seq and a non-zero length must have been set by now."""
        assert self.length != 0, "FrameHead.length has not been set"
        assert self.seq is not None, "FrameHead.seq has not been set"
        layout = '<BBH'
        return struct.pack(layout, self.seq, self.type.value, self.length)
|
|
|
|
|
|
class FrameItem:
    """One frame as sent on the wire: a FrameHead followed by its payload bytes."""

    head: FrameHead
    payload: bytes

    def __init__(self, head: FrameHead, payload: Optional[bytes] = None):
        self.head = head
        self.payload = payload

    def setpayload(self, payload: Union[bytes, bytearray]):
        """Store ``payload`` (bytearrays are frozen into bytes) and sync ``head.length`` with it."""
        self.payload = bytes(payload) if isinstance(payload, bytearray) else payload
        self.head.length = len(payload)

    def pack(self) -> bytes:
        """Return the packed header immediately followed by the payload."""
        return b''.join((self.head.pack(), self.payload))
|
|
|
|
|
|
# high-level wrappers around FrameItem
|
|
# ------------------------------------
|
|
|
|
class MessagePhase(Enum):
    """Delivery state of a queued outgoing message."""

    WAITING = 0  # queued, not sent yet
    SENT = 1     # sent once; the next send re-sends it
    DONE = 2     # finished; kept in the queue only until it is cleaned up
|
|
|
|
|
|
class Message:
    """Base class of all protocol messages.

    A message owns one FrameItem (``self.frame``), which concrete subclasses create.
    Outgoing messages are serialized with encrypt(); received datagrams are parsed
    with from_encrypted().
    """

    frame: Optional[FrameItem]
    id: int

    _global_id = 0

    def __init__(self):
        self.frame = None

        # global internal message id, only useful for debugging purposes
        self.id = self.next_id()

    def __repr__(self):
        # Go through the `seq` property so repr() also works while `frame` is still None.
        return f'<{self.__class__.__name__} id={self.id} seq={self.seq}>'

    @staticmethod
    def next_id():
        """Return the next debugging id (a counter that wraps around at 100000)."""
        _id = Message._global_id
        Message._global_id = (Message._global_id + 1) % 100000
        return _id

    @staticmethod
    def _rotate_key(key: bytes, n: int) -> bytes:
        """Return ``key`` rotated left by ``n`` bytes, i.e. ``key[n:] + key[:n]``.

        Per-frame AES keys and IVs are derived this way, with ``n`` taken from one
        nibble of the frame's seq number.
        """
        key = bytes(key)
        return key[n:] + key[:n]

    @staticmethod
    def _parse_cmd(cmd_type: int, data: bytes, seq: int) -> Message:
        """Build the CmdIncomingMessage subclass whose TYPE equals ``cmd_type``.

        Falls back to UnknownMessage, tagged with ``cmd_type``, when nothing matches.
        """
        # Direct subclasses of CmdIncomingMessage, with SimpleBooleanMessage replaced by its
        # own subclasses. The first TYPE match wins.
        candidates = [c for c in CmdIncomingMessage.__subclasses__() if c is not SimpleBooleanMessage]
        candidates.extend(SimpleBooleanMessage.__subclasses__())

        # UnknownMessage holds commands we don't recognize; it is the fallback, never a match.
        msg_cls = next((c for c in candidates if c is not UnknownMessage and c.TYPE == cmd_type),
                       UnknownMessage)

        m = msg_cls.from_packed_data(data, seq=seq)
        if isinstance(m, UnknownMessage):
            m.set_type(cmd_type)
        return m

    @staticmethod
    def from_encrypted(buf: bytes, inkey: bytes, outkey: bytes) -> Message:
        """Decrypt and parse one received datagram.

        :param buf: raw datagram: a 4-byte FrameHead followed by the AES-CBC encrypted payload.
        :param inkey: key material for the AES key (rotated by the low nibble of seq).
        :param outkey: key material for the IV (rotated by the high nibble of seq).
        :return: AckMessage, NakMessage or a CmdIncomingMessage subclass instance
                 (UnknownMessage for unrecognized command types).
        :raises ValueError: on a malformed frame, bad padding or a seq mismatch. Explicit
                 ValueErrors (instead of assert) keep these checks active under ``python -O``
                 and share the exception type that decryption failures already raise.
        :raises NotImplementedError: for AUX frames.
        """
        _logger.debug(f'Message:from_encrypted: buf={buf.hex()}')

        if len(buf) < 4:
            raise ValueError('invalid size')
        head = FrameHead.from_bytes(buf[:4])

        if len(buf) != head.length + 4:
            raise ValueError(f'invalid buf size ({len(buf)} != {head.length + 4})')
        payload = buf[4:]

        key = Message._rotate_key(inkey, head.seq & 0xF)
        iv = Message._rotate_key(outkey, head.seq >> 4 & 0xF)

        decryptor = ciphers.Cipher(algorithms.AES(key), modes.CBC(iv)).decryptor()
        decrypted_data = decryptor.update(payload) + decryptor.finalize()

        unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
        decrypted_data = unpadder.update(decrypted_data) + unpadder.finalize()

        if len(decrypted_data) == 0:
            raise ValueError('decrypted data is null')
        # the plaintext starts with a copy of the header's seq
        if head.seq != decrypted_data[0]:
            raise ValueError(f'decrypted seq mismatch {head.seq} != {decrypted_data[0]}')

        if head.type == FrameType.ACK:
            return AckMessage(head.seq)

        if head.type == FrameType.NAK:
            return NakMessage(head.seq)

        if head.type == FrameType.AUX:
            # TODO implement AUX
            raise NotImplementedError('FrameType AUX is not yet implemented')

        if head.type == FrameType.CMD:
            # CMD plaintext: seq byte, command type byte, command data
            if len(decrypted_data) < 2:
                raise ValueError('CMD frame carries no command type')
            return Message._parse_cmd(decrypted_data[1], decrypted_data[2:], head.seq)

        raise NotImplementedError(f'Unexpected frame type: {head.type}')

    def pack_data(self) -> bytes:
        """Return the message-specific data to encrypt (none for the base class)."""
        return b''

    @property
    def seq(self) -> Union[int, None]:
        """Frame sequence number, or None while no frame is attached."""
        try:
            return self.frame.head.seq
        except AttributeError:
            return None

    @seq.setter
    def seq(self, seq: int):
        self.frame.head.seq = seq

    def encrypt(self, outkey: bytes, inkey: bytes, token: bytes, pubkey: bytes):
        """Encrypt this message's data and store it as the frame payload.

        The plaintext is the seq byte followed by _get_data_to_encrypt(), PKCS7-padded and
        AES-CBC encrypted. The key comes from ``outkey`` (rotated by the low nibble of seq)
        and the IV from ``inkey`` (rotated by the high nibble). ``token`` and ``pubkey`` are
        not used by this implementation.
        """
        assert self.frame is not None

        data = self._get_data_to_encrypt()
        assert data is not None

        seq = self.frame.head.seq
        key = Message._rotate_key(outkey, seq & 0xf)
        iv = Message._rotate_key(inkey, seq >> 4 & 0xf)

        plaintext = bytes([seq]) + bytes(data)
        _logger.debug('frame payload to be encrypted: ' + plaintext.hex())

        padder = padding.PKCS7(algorithms.AES.block_size).padder()
        padded = padder.update(plaintext) + padder.finalize()

        encryptor = ciphers.Cipher(algorithms.AES(key), modes.CBC(iv)).encryptor()
        self.frame.setpayload(encryptor.update(padded) + encryptor.finalize())

    def _get_data_to_encrypt(self) -> bytes:
        return self.pack_data()
|
|
|
|
|
|
class AckMessage(Message, ABC):
    """ACK frame; it carries only the seq of the frame it acknowledges."""

    def __init__(self, seq: Optional[int] = None):
        super().__init__()
        head = FrameHead(seq, FrameType.ACK, None)
        self.frame = FrameItem(head)
|
|
|
|
|
|
class NakMessage(Message, ABC):
    """NAK frame; it carries only the seq of the frame it rejects."""

    def __init__(self, seq: Optional[int] = None):
        super().__init__()
        head = FrameHead(seq, FrameType.NAK, None)
        self.frame = FrameItem(head)
|
|
|
|
|
|
class CmdMessage(Message):
    """Base class of CMD frames: a command type byte (TYPE) followed by command data."""

    type: Optional[int]
    data: bytes

    TYPE = None

    def __init__(self, seq: Optional[int] = None):
        super().__init__()
        self.frame = FrameItem(FrameHead(seq, FrameType.CMD))
        self.data = b''

    def get_type(self) -> int:
        """Return the command type byte (the class-level TYPE)."""
        return self.__class__.TYPE

    def _get_data_to_encrypt(self) -> bytes:
        # plaintext body = command type byte + command-specific data
        return bytes([self.get_type()]) + self.pack_data()

    def _repr_fields(self) -> ReprDict:
        return dict(cmd=self.get_type())

    def __repr__(self):
        parts = [
            f'{__name__}.{self.__class__.__name__}',
            f'id={self.id}',
            f'seq={self.seq}',
        ]
        fields = self._repr_fields()
        if fields:
            parts.extend(f'{k}={v}' for k, v in fields.items())
        elif self.data:
            # no structured fields: fall back to the raw command data
            parts.append(f'data={self.data.hex()}')
        return '<' + ' '.join(parts) + '>'
|
|
|
|
|
|
class CmdIncomingMessage(CmdMessage):
    """Base class for commands that can be parsed from received CMD frames.

    Message.from_encrypted looks up the subclass whose TYPE matches the received command
    type byte and builds it through from_packed_data().
    """

    # Declared as a classmethod (it was a staticmethod taking `cls`), matching how every
    # subclass implements and how Message.from_encrypted calls it: cls.from_packed_data(data, seq=...).
    @classmethod
    @abstractmethod
    def from_packed_data(cls, data: bytes, seq: Optional[int] = None):
        """Build an instance from the command data (the bytes after the command type byte)."""
        pass

    @abstractmethod
    def _repr_fields(self) -> ReprDict:
        pass
|
|
|
|
|
|
class CmdOutgoingMessage(CmdMessage):
    """Base class for commands we send; subclasses serialize their data in pack_data()."""

    @abstractmethod
    def pack_data(self) -> bytes:
        """Return the command data that follows the type byte (empty by default)."""
        return bytes()
|
|
|
|
|
|
class ModeMessage(CmdOutgoingMessage, CmdIncomingMessage):
    """Power mode command (see PowerType); both sent by us and parsed from incoming frames."""

    TYPE = 1

    pt: PowerType

    def __init__(self, power_type: PowerType, seq: Optional[int] = None):
        super().__init__(seq)
        self.pt = power_type

    @classmethod
    def from_packed_data(cls, data: bytes, seq=0) -> ModeMessage:
        assert len(data) == 1, 'data size expected to be 1'
        # a single unsigned byte holding the PowerType value
        return ModeMessage(PowerType(data[0]), seq=seq)

    def pack_data(self) -> bytes:
        return bytes([self.pt.value])

    def _repr_fields(self) -> ReprDict:
        return dict(mode=self.pt.name)
|
|
|
|
|
|
class TargetTemperatureMessage(CmdOutgoingMessage, CmdIncomingMessage):
    """Target temperature command; both sent by us and parsed from incoming frames."""

    temperature: int

    TYPE = 2

    def __init__(self, temp: int, seq: Optional[int] = None):
        super().__init__(seq)
        self.temperature = temp

    @classmethod
    def from_packed_data(cls, data: bytes, seq=0) -> TargetTemperatureMessage:
        assert len(data) == 2, 'data size expected to be 2'
        # two unsigned bytes: integer part and hundredths; the result is truncated to int
        nat, frac = data
        return TargetTemperatureMessage(int(nat + (frac / 100)), seq=seq)

    def pack_data(self) -> bytes:
        # the fractional byte is always sent as 0
        return bytes((self.temperature, 0))

    def _repr_fields(self) -> ReprDict:
        return dict(temperature=self.temperature)
|
|
|
|
|
|
class PingMessage(CmdIncomingMessage, CmdOutgoingMessage):
    """Keep-alive command without data; UDPConnection sends it when the link is idle."""

    TYPE = 255

    @classmethod
    def from_packed_data(cls, data: bytes, seq=0) -> PingMessage:
        assert len(data) == 0, 'no data expected'
        return PingMessage(seq=seq)

    def pack_data(self) -> bytes:
        return bytes()

    def _repr_fields(self) -> ReprDict:
        return dict()
|
|
|
|
|
|
# The first message of the protocol, sent by the client.
# The kettle usually ACKs it, but sometimes no ACK arrives and the very next message is a HandshakeResponseMessage.
class HandshakeMessage(CmdMessage):
    TYPE = 0

    def encrypt(self,
                outkey: bytes,
                inkey: bytes,
                token: bytes,
                pubkey: bytes):
        """Build the handshake payload: a 0x00 byte, our public key, then the encrypted token.

        The token is AES-CBC encrypted with ``outkey`` as the key and ``inkey`` as the IV.
        No padding is applied, so the token length must be a multiple of the AES block size.
        """
        encryptor = ciphers.Cipher(algorithms.AES(outkey), modes.CBC(inkey)).encryptor()
        encrypted_token = encryptor.update(token) + encryptor.finalize()

        # NOTE(review): the leading 0x00 equals TYPE; presumably the command type byte — confirm.
        self.frame.setpayload(b''.join((b'\x00', pubkey, encrypted_token)))
|
|
|
|
|
|
# The kettle sends this either right after the handshake, or first ACKs the handshake and then sends this.
class HandshakeResponseMessage(CmdIncomingMessage):
    """Kettle's reply to HandshakeMessage: protocol version, firmware version, mode and a token."""

    TYPE = 0

    protocol: int
    fw_major: int
    fw_minor: int
    mode: int
    token: bytes

    def __init__(self,
                 protocol: int,
                 fw_major: int,
                 fw_minor: int,
                 mode: int,
                 token: bytes,
                 seq: Optional[int] = None):
        super().__init__(seq)
        self.protocol, self.fw_major, self.fw_minor = protocol, fw_major, fw_minor
        self.mode, self.token = mode, token

    @classmethod
    def from_packed_data(cls, data: bytes, seq=0) -> HandshakeResponseMessage:
        # 5-byte header (u16 LE protocol, u8 fw_major, u8 fw_minor, u8 mode); the rest is the token
        header, token = data[:5], data[5:]
        protocol, fw_major, fw_minor, mode = struct.unpack('<HBBB', header)
        return HandshakeResponseMessage(protocol, fw_major, fw_minor, mode, token=token, seq=seq)

    def _repr_fields(self) -> ReprDict:
        return dict(protocol=self.protocol,
                    fw=f'{self.fw_major}.{self.fw_minor}',
                    mode=self.mode,
                    token=self.token.hex())
|
|
|
|
|
|
# Apparently, some hardware info.
# On the other hand, if you look at com.syncleiot.iottransport.commands.CmdHardware, its mqtt topic says "mcu_firmware".
# My device returns 1.1.1. The kettle uses an ESP8266 (ESP-12F module) under the hood (or, more precisely, under a piece
# of cheap plastic), so maybe 1.1.1 is some MCU ROM version.
class DeviceHardwareMessage(CmdIncomingMessage):
    TYPE = 143  # -113 as a signed byte

    hw: List[int]

    def __init__(self, hw: List[int], seq: Optional[int] = None):
        super().__init__(seq)
        self.hw = hw

    @classmethod
    def from_packed_data(cls, data: bytes, seq=0) -> DeviceHardwareMessage:
        assert len(data) == 3, 'invalid data size, expected 3'
        # three unsigned version components
        return DeviceHardwareMessage(list(data), seq=seq)

    def _repr_fields(self) -> ReprDict:
        return dict(device_hardware='.'.join(str(part) for part in self.hw))
|
|
|
|
|
|
# The kettle sends this message right after the HandshakeResponseMessage.
# The diagnostic data is meant to be sent to the vendor, which we are obviously not going to do.
# So just ACK and skip it.
class DeviceDiagnosticMessage(CmdIncomingMessage):
    TYPE = 145  # -111 as a signed byte

    diag_data: bytes

    def __init__(self, diag_data: bytes, seq: Optional[int] = None):
        super().__init__(seq)
        self.diag_data = diag_data

    @classmethod
    def from_packed_data(cls, data: bytes, seq=0) -> DeviceDiagnosticMessage:
        # the whole command data is kept as an opaque blob
        return DeviceDiagnosticMessage(diag_data=data, seq=seq)

    def _repr_fields(self) -> ReprDict:
        return dict(diag_data=self.diag_data.hex())
|
|
|
|
|
|
class SimpleBooleanMessage(ABC, CmdIncomingMessage):
    """Base class for incoming commands whose data is a single on/off byte."""

    value: bool

    def __init__(self, value: bool, seq: Optional[int] = None):
        super().__init__(seq)
        self.value = value

    @classmethod
    def from_packed_data(cls, data: bytes, seq: Optional[int] = None):
        assert len(data) == 1, 'invalid data size, expected 1'
        # only the byte value 1 means "on"; any other value is "off"
        return cls(value=(data[0] == 1), seq=seq)

    @abstractmethod
    def _repr_fields(self) -> ReprDict:
        pass
|
|
|
|
|
|
class AccessControlMessage(SimpleBooleanMessage):
    """On/off 'access control enabled' flag received from the kettle."""

    TYPE = 133  # -123 as a signed byte

    def _repr_fields(self) -> ReprDict:
        return dict(acl_enabled=self.value)
|
|
|
|
|
|
class ErrorMessage(SimpleBooleanMessage):
    """On/off error flag received from the kettle."""

    TYPE = 7

    def _repr_fields(self) -> ReprDict:
        return dict(error=self.value)
|
|
|
|
|
|
class ChildLockMessage(SimpleBooleanMessage):
    """On/off child lock flag received from the kettle."""

    TYPE = 30

    def _repr_fields(self) -> ReprDict:
        return dict(child_lock=self.value)
|
|
|
|
|
|
class VolumeMessage(SimpleBooleanMessage):
    """On/off 'volume' flag received from the kettle (presumably sound on/off — confirm)."""

    TYPE = 9

    def _repr_fields(self) -> ReprDict:
        return dict(volume=self.value)
|
|
|
|
|
|
class BacklightMessage(SimpleBooleanMessage):
    """On/off backlight flag received from the kettle."""

    TYPE = 28

    def _repr_fields(self) -> ReprDict:
        return dict(backlight=self.value)
|
|
|
|
|
|
class CurrentTemperatureMessage(CmdIncomingMessage):
    """Current water temperature reported by the kettle."""

    TYPE = 20

    current_temperature: int

    def __init__(self, temp: int, seq: Optional[int] = None):
        super().__init__(seq)
        self.current_temperature = temp

    @classmethod
    def from_packed_data(cls, data: bytes, seq=0) -> CurrentTemperatureMessage:
        assert len(data) == 2, 'data size expected to be 2'
        # two unsigned bytes: integer part and hundredths; the result is truncated to int
        nat, frac = data
        return CurrentTemperatureMessage(int(nat + (frac / 100)), seq=seq)

    def pack_data(self) -> bytes:
        return bytes((self.current_temperature, 0))

    def _repr_fields(self) -> ReprDict:
        return dict(current_temperature=self.current_temperature)
|
|
|
|
|
|
class UnknownMessage(CmdIncomingMessage):
    """Fallback for incoming commands with an unrecognized type byte.

    Keeps the raw command data; the observed type is attached afterwards with set_type().
    """

    type: Optional[int]
    data: bytes

    def __init__(self, data: bytes, **kwargs):
        super().__init__(**kwargs)
        self.type = None
        self.data = data

    @classmethod
    def from_packed_data(cls, data: bytes, seq=0) -> UnknownMessage:
        return UnknownMessage(data, seq=seq)

    def set_type(self, type: int):
        self.type = type

    def get_type(self) -> int:
        # unlike other commands, the type is stored per instance rather than in the class TYPE
        return self.type

    def _repr_fields(self) -> ReprDict:
        return dict(type=self.type, data=self.data.hex())
|
|
|
|
|
|
class WrappedMessage:
    """Entry of the outgoing queue: a Message plus its response handling and delivery phase.

    Holds an optional handler (invoked through call()), an optional validator that decides
    whether an incoming message is an acceptable response, and the current MessagePhase
    together with the time of its last change.
    """

    _message: Message
    _handler: Optional[callable]
    _validator: Optional[callable]
    _logger: Optional[logging.Logger]
    _phase: MessagePhase
    _phase_update_time: float

    def __init__(self,
                 message: Message,
                 handler: Optional[callable] = None,
                 validator: Optional[callable] = None,
                 ack=False):
        """
        :param message: the message to send.
        :param handler: callback invoked via call().
        :param validator: predicate applied by validate() to incoming messages.
        :param ack: when True and no validator is given, accept only an AckMessage.
        """
        self._message = message
        self._handler = handler
        self._validator = validator
        if ack and not validator:
            self._validator = lambda m: isinstance(m, AckMessage)
        self._logger = None
        self._phase = MessagePhase.WAITING
        self._phase_update_time = 0

    def setlogger(self, logger: logging.Logger):
        self._logger = logger

    def validate(self, message: Message):
        """Return the validator's verdict on ``message``; anything is valid without a validator."""
        return self._validator(message) if self._validator else True

    def call(self, *args, error_message: str = None) -> None:
        """Invoke the handler with ``args``; exceptions are logged, never propagated."""
        if not self._handler:
            return
        try:
            self._handler(*args)
        except Exception as exc:
            log = self._logger or logging.getLogger(self.__class__.__name__)
            log.error(f'{error_message}, see exception below:')
            log.exception(exc)

    @property
    def phase(self) -> MessagePhase:
        return self._phase

    @phase.setter
    def phase(self, phase: MessagePhase):
        self._phase = phase
        # WAITING has no timestamp; any other phase records when it was entered
        self._phase_update_time = time.time() if phase != MessagePhase.WAITING else 0

    @property
    def phase_update_time(self) -> float:
        return self._phase_update_time

    @property
    def message(self) -> Message:
        return self._message

    @property
    def id(self) -> int:
        return self._message.id

    @property
    def seq(self) -> int:
        return self._message.seq

    @seq.setter
    def seq(self, seq: int):
        self._message.seq = seq

    def __repr__(self):
        return f'<{__name__}.{self.__class__.__name__} message={self._message!r}>'
|
|
|
|
|
|
# Connection stuff
|
|
# Well, strictly speaking, as it's UDP, there's no connection, but who cares.
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class IncomingMessageListener:
    """Observer of received messages, registered via UDPConnection.add_incoming_message_listener().

    Not an ABC, so @abstractmethod is informational only.
    """

    @abstractmethod
    def incoming_message(self, message: Message) -> Optional[Message]:
        """Handle one received message; the meaning of the returned Message is up to the caller."""
        return None
|
|
|
|
|
|
class ConnectionStatus(Enum):
    """State of a UDPConnection."""

    NOT_CONNECTED = 1
    CONNECTING = 2
    CONNECTED = 3
    RECONNECTING = 4
    DISCONNECTED = 5
|
|
|
|
|
|
class ConnectionStatusListener:
    """Observer of ConnectionStatus changes, registered via UDPConnection.add_connection_status_listener().

    Not an ABC, so @abstractmethod is informational only.
    """

    @abstractmethod
    def connection_status_updated(self, status: ConnectionStatus):
        """Called with each new connection status."""
        return None
|
|
|
|
|
|
class UDPConnection(threading.Thread, ConnectionStatusListener):
|
|
inseq: int
|
|
outseq: int
|
|
source_port: int
|
|
device_addr: str
|
|
device_port: int
|
|
device_token: bytes
|
|
device_pubkey: bytes
|
|
interrupted: bool
|
|
response_handlers: Dict[int, WrappedMessage]
|
|
outgoing_queue: List[WrappedMessage]
|
|
pubkey: Optional[bytes]
|
|
encinkey: Optional[bytes]
|
|
encoutkey: Optional[bytes]
|
|
inc_listeners: List[IncomingMessageListener]
|
|
conn_listeners: List[ConnectionStatusListener]
|
|
outgoing_time: float
|
|
outgoing_time_1st: float
|
|
incoming_time: float
|
|
status: ConnectionStatus
|
|
reconnect_tries: int
|
|
read_timeout: int
|
|
|
|
_addr_lock: threading.Lock
|
|
_iml_lock: threading.Lock
|
|
_csl_lock: threading.Lock
|
|
_st_lock: threading.Lock
|
|
|
|
def __init__(self,
             addr: Union[IPv4Address, IPv6Address],
             port: int,
             device_pubkey: bytes,
             device_token: bytes,
             read_timeout: int = 1):
    """Initialize connection state; the socket itself is created in run().

    :param addr: kettle IP address (stored as a string).
    :param port: kettle UDP port.
    :param device_pubkey: kettle's X25519 public key, in the byte order _prepare_keys() expects.
    :param device_token: token sent (encrypted) in the HandshakeMessage.
    :param read_timeout: socket receive timeout, in seconds, used by run().
    """
    super().__init__()
    self._logger = logging.getLogger(f'{__name__}.{self.__class__.__name__} <{hex(id(self))}>')
    # Thread.setName() is deprecated since Python 3.10; assigning the `name` property is equivalent.
    self.name = self.__class__.__name__

    self.inseq = 0
    self.outseq = 0
    # random local port to bind to in run()
    self.source_port = random.randint(1024, 65535)
    self.device_addr = str(addr)
    self.device_port = port
    self.device_token = device_token
    self.device_pubkey = device_pubkey
    self.outgoing_queue = []
    self.response_handlers = {}
    self.interrupted = False
    self.outgoing_time = 0
    self.outgoing_time_1st = 0
    self.incoming_time = 0
    self.inc_listeners = []
    # the connection listens to its own status updates (see connection_status_updated)
    self.conn_listeners = [self]
    self.status = ConnectionStatus.NOT_CONNECTED
    self.reconnect_tries = 0
    self.read_timeout = read_timeout

    self._iml_lock = threading.Lock()
    self._csl_lock = threading.Lock()
    self._addr_lock = threading.Lock()
    self._st_lock = threading.Lock()

    # session keys, generated by _prepare_keys() on every (re)connect
    self.pubkey = None
    self.encinkey = None
    self.encoutkey = None
|
|
|
|
def connection_status_updated(self, status: ConnectionStatus):
    """Record a new status and maintain reconnect_tries (the connection is its own listener).

    RECONNECTING increments the counter; CONNECTED, NOT_CONNECTED and DISCONNECTED reset it.
    """
    with self._st_lock:
        self.status = status
        if status == ConnectionStatus.RECONNECTING:
            self.reconnect_tries += 1
        elif status in (ConnectionStatus.CONNECTED, ConnectionStatus.NOT_CONNECTED, ConnectionStatus.DISCONNECTED):
            self.reconnect_tries = 0
|
|
|
|
def _cleanup(self):
    """Drop every queued message and response handler and reset the activity timestamps.

    Each queued message's handler is called with False first (its errors are only logged).
    """
    for queued in self.outgoing_queue:
        queued.call(False,
                    error_message=f'_cleanup: exception while calling cb(False) on message {queued.message}')
    self.outgoing_queue = []
    self.response_handlers = {}

    self.incoming_time = self.outgoing_time = self.outgoing_time_1st = 0

    self._logger.debug('_cleanup: done')
|
|
|
|
def set_address(self, addr: Union[IPv4Address, IPv6Address], port: int):
    """Update the kettle's address/port (under _addr_lock), logging only real changes."""
    new_addr = str(addr)
    with self._addr_lock:
        if self.device_addr == new_addr and self.device_port == port:
            return
        self.device_addr = new_addr
        self.device_port = port
        self._logger.info(f'updated device network address: {self.device_addr}:{self.device_port}')
|
|
|
|
def set_device_pubkey(self, pubkey: bytes):
    """Replace the kettle's public key; a changed key triggers a reconnect."""
    old_hex, new_hex = self.device_pubkey.hex(), pubkey.hex()
    if old_hex == new_hex:
        return
    self._logger.info(f'device pubkey has changed (old={old_hex}, new={new_hex})')
    self.device_pubkey = pubkey
    self._notify_cs(ConnectionStatus.RECONNECTING)
|
|
|
|
def get_address(self) -> Tuple[str, int]:
    """Return a consistent (address, port) snapshot taken under _addr_lock."""
    with self._addr_lock:
        address = (self.device_addr, self.device_port)
    return address
|
|
|
|
def add_incoming_message_listener(self, listener: IncomingMessageListener):
    """Register ``listener`` for incoming messages; duplicates are ignored."""
    with self._iml_lock:
        if listener in self.inc_listeners:
            return
        self.inc_listeners.append(listener)
|
|
|
|
def add_connection_status_listener(self, listener: ConnectionStatusListener):
    """Register ``listener`` for connection status changes; duplicates are ignored."""
    with self._csl_lock:
        if listener in self.conn_listeners:
            return
        self.conn_listeners.append(listener)
|
|
|
|
def _notify_cs(self, status: ConnectionStatus):
    """Deliver ``status`` to every connection status listener while holding _csl_lock."""
    with self._csl_lock:
        listeners = self.conn_listeners
        for listener in listeners:
            listener.connection_status_updated(status)
|
|
|
|
def _prepare_keys(self):
    """Generate a fresh X25519 key pair and derive this session's AES keys.

    Sets ``self.pubkey`` (our public key, sent in the handshake). It also sets ``encinkey``
    and ``encoutkey``: the first and second 16 bytes of SHA-256 over the X25519 shared secret.
    All X25519 values are byte-reversed relative to the raw encoding used by the
    ``cryptography`` package (presumably the device's byte order — confirm).
    """
    # Public module: the file-level `srlz` alias points at cryptography's private
    # `_serialization` module, which is not part of its supported API.
    from cryptography.hazmat.primitives import serialization

    # generate key pair
    privkey = X25519PrivateKey.generate()
    raw_pubkey = privkey.public_key().public_bytes(encoding=serialization.Encoding.Raw,
                                                   format=serialization.PublicFormat.Raw)
    self.pubkey = bytes(reversed(raw_pubkey))

    # generate shared key
    device_pubkey = X25519PublicKey.from_public_bytes(bytes(reversed(self.device_pubkey)))
    shared_key = bytes(reversed(privkey.exchange(device_pubkey)))

    # in/out encryption keys
    digest = hashes.Hash(hashes.SHA256())
    digest.update(shared_key)
    shared_sha256 = digest.finalize()

    self.encinkey = shared_sha256[:16]
    self.encoutkey = shared_sha256[16:]

    self._logger.info('encryption keys have been created')
|
|
|
|
def _handshake_callback(self, r: MessageResponse):
    """Response handler of the HandshakeMessage queued by run().

    ``False`` (no valid response) starts a reconnect; any other value means we are connected.
    """
    new_status = ConnectionStatus.RECONNECTING if r is False else ConnectionStatus.CONNECTED
    self._notify_cs(new_status)
|
|
|
|
def run(self):
    """Thread body: owns the UDP socket and drives the connection state machine.

    Each loop iteration:
      1. snapshots ``self.status``; DISCONNECTED/RECONNECTING drop all pending messages via
         _cleanup(), and DISCONNECTED ends the loop;
      2. switches to RECONNECTING when nothing was received for DISCONNECT_TIMEOUT seconds
         (or nothing at all that long after the first send);
      3. when (re)connecting, regenerates the session keys and queues a HandshakeMessage at the
         front of the queue; once reconnect_tries reaches 3 it gives up (DISCONNECTED);
      4. sends at most one message chosen by _get_next_message();
      5. waits up to ``read_timeout`` seconds for one datagram and passes it to _handle_incoming().

    The loop also stops when ``self.interrupted`` becomes true. The socket is closed on every
    exit path, including exceptions.
    """
    self._logger.info('starting server loop')

    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.bind(('0.0.0.0', self.source_port))
        sock.settimeout(self.read_timeout)

        while not self.interrupted:
            with self._st_lock:
                status = self.status

            if status in (ConnectionStatus.DISCONNECTED, ConnectionStatus.RECONNECTING):
                self._cleanup()
                if status == ConnectionStatus.DISCONNECTED:
                    break

            # no activity for some time means connection is broken
            fail_path = 0
            if self.incoming_time > 0 and time.time() - self.incoming_time >= DISCONNECT_TIMEOUT:
                fail_path = 1  # the kettle went silent
            elif (self.outgoing_time_1st > 0 and self.incoming_time == 0
                  and time.time() - self.outgoing_time_1st >= DISCONNECT_TIMEOUT):
                fail_path = 2  # the kettle never answered at all

            if fail_path:
                self._logger.debug(f'run: setting status=RECONNECTING because of long inactivity, fail_path={fail_path}')
                self._notify_cs(ConnectionStatus.RECONNECTING)

            # establishing a connection
            if status in (ConnectionStatus.RECONNECTING, ConnectionStatus.NOT_CONNECTED):
                if status == ConnectionStatus.RECONNECTING and self.reconnect_tries >= 3:
                    self._notify_cs(ConnectionStatus.DISCONNECTED)
                    continue

                self._reset_outseq()
                self._prepare_keys()

                # shake the imaginary kettle's hand
                wrapped = WrappedMessage(HandshakeMessage(),
                                         handler=self._handshake_callback,
                                         validator=lambda m: isinstance(m, (AckMessage, HandshakeResponseMessage)))
                self.enqueue_message(wrapped, prepend=True)
                self._notify_cs(ConnectionStatus.CONNECTING)

            # pick next (wrapped) message to send
            wm = self._get_next_message()  # wm means "wrapped message"
            if wm:
                # ACK/NAK are sent once and are never assigned a new seq
                one_shot = isinstance(wm.message, (AckMessage, NakMessage))

                if not one_shot:
                    old_seq = wm.seq
                    wm.seq = self.outseq
                    self._set_response_handler(wm, old_seq=old_seq)
                elif wm.seq is None:
                    # ack/nak is a response to some incoming message (and it must have the same seqno that incoming
                    # message had)
                    raise RuntimeError(f'run: seq must be set for {wm.__class__.__name__}')

                self._logger.debug(f'run: sending message: {wm.message}, one_shot={one_shot}, phase={wm.phase}')
                encrypted = False
                try:
                    wm.message.encrypt(outkey=self.encoutkey, inkey=self.encinkey,
                                       token=self.device_token, pubkey=self.pubkey)
                    encrypted = True
                except ValueError as exc:
                    # handle "ValueError: Invalid padding bytes."
                    self._logger.error('run: failed to encrypt the message.')
                    self._logger.exception(exc)

                if encrypted:
                    buf = wm.message.frame.pack()

                    if wm.phase == MessagePhase.WAITING:
                        # sending the first time
                        sock.sendto(buf, self.get_address())
                    elif wm.phase == MessagePhase.SENT:
                        # resending: RESEND_ATTEMPTS copies, 50 ms apart
                        for attempt in range(RESEND_ATTEMPTS):
                            if attempt:
                                time.sleep(0.05)
                            sock.sendto(buf, self.get_address())

                    # one-shot and re-sent messages are finished; others now wait for a response
                    if one_shot or wm.phase == MessagePhase.SENT:
                        wm.phase = MessagePhase.DONE
                    else:
                        wm.phase = MessagePhase.SENT

                    now = time.time()
                    self.outgoing_time = now
                    if not self.outgoing_time_1st:
                        self.outgoing_time_1st = now

            # receiving data
            try:
                data = sock.recv(4096)
            except (TimeoutError, socket.timeout):
                continue

            try:
                self._handle_incoming(data)
            except Exception as exc:
                # One unexpected datagram (e.g. an AUX frame, which raises NotImplementedError,
                # or data rejected by a parser assertion) must not kill the connection thread.
                self._logger.error('run: failed to handle an incoming datagram:')
                self._logger.exception(exc)

    self._logger.info('bye...')
|
|
|
|
def _get_next_message(self) -> Optional[WrappedMessage]:
    """Return the next queued message to send, or None.

    Walks the outgoing queue from the front up to the first entry that is not DONE, which is
    the one returned. DONE entries before it are removed when they are ACK/NAK/ping messages,
    or when MESSAGE_QUEUE_REMOVE_DELAY seconds have passed since they became DONE. A response
    handler still registered under the same seq is first called with False.

    If nothing is pending while CONNECTED, and the last send or receive is older than
    PING_FREQUENCY seconds, a PingMessage is created. It is put at the front of the queue and
    returned.
    """
    lpfx = '_get_next_message:'

    picked = None
    expired = []
    for entry in self.outgoing_queue:
        if entry.phase != MessagePhase.DONE:
            picked = entry
            break
        # finished entries are never sent again; decide whether they can be dropped now
        if isinstance(entry.message, (AckMessage, NakMessage, PingMessage)) \
                or time.time() - entry.phase_update_time >= MESSAGE_QUEUE_REMOVE_DELAY:
            expired.append(entry)

    for entry in expired:
        self._logger.debug(f'{lpfx} rm path: removing id={entry.id} seq={entry.seq}')

        # clear message handler
        if entry.seq in self.response_handlers:
            self.response_handlers[entry.seq].call(
                False, error_message=f'{lpfx} rm path: error while calling callback for seq={entry.seq}')
            del self.response_handlers[entry.seq]

        # remove from queue
        try:
            self.outgoing_queue.remove(entry)
        except ValueError as exc:
            self._logger.error(f'{lpfx} rm path: removing from outgoing_queue raised an exception: {str(exc)}')

    # ping pong
    if picked is None and self.outgoing_time_1st != 0 and self.status == ConnectionStatus.CONNECTED:
        now = time.time()
        out_delta = now - self.outgoing_time
        in_delta = now - self.incoming_time
        if max(out_delta, in_delta) > PING_FREQUENCY:
            self._logger.debug(f'{lpfx} no activity: in for {in_delta:.2f}s, out for {out_delta:.2f}s, time to ping the damn thing')
            picked = WrappedMessage(PingMessage(), ack=True)
            # add it to outgoing_queue in order to be aggressively resent in future (if needed)
            self.outgoing_queue.insert(0, picked)

    return picked
|
|
|
|
def _handle_incoming(self, buf: bytes):
    """Decrypt one datagram received from the kettle and route it.

    1. ``buf`` is decrypted with ``Message.from_encrypted``. A ``ValueError``
       (e.g. "Invalid padding bytes.") is logged and the datagram dropped.
    2. ``self.incoming_time`` is set to the current time.
    3. Stale-frame filter: ack/nak frames are compared against
       ``self.outseq``, every other frame against ``self.inseq``. For the
       latter, ``self.inseq`` is overwritten with the frame's seq before the
       comparison. A frame whose seq is below the reference is dropped unless
       the reference is >= 0xfd, which tolerates the one-byte seq wrapping.
    4. A frame whose seq has no registered response handler goes to
       ``_handle_incoming_cmd``.
    5. Otherwise the handler validates it. On success the handler is marked
       DONE, ``outseq`` is advanced and the callback receives the frame. On
       failure the callback receives False only if the handler is already
       DONE or belongs to a handshake. Ping/cmd frames are still passed to
       ``_handle_incoming_cmd`` so they get acknowledged. A handler whose
       callback fired is unregistered.
    """
    try:
        frame = Message.from_encrypted(buf, inkey=self.encinkey, outkey=self.encoutkey)
    except ValueError as exc:
        self._logger.error('_handle_incoming: failed to decrypt incoming frame:')
        self._logger.exception(exc)
        return

    self.incoming_time = time.time()
    seq = frame.seq
    lpfx = f'handle_incoming({frame.id}):'
    self._logger.debug(f'{lpfx} received: {frame}')

    # Stale-frame filter.
    if isinstance(frame, (AckMessage, NakMessage)):
        ref_name, ref_seq = 'outseq', self.outseq
    else:
        ref_name, ref_seq = 'inseq', self.inseq
        # NOTE(review): inseq is updated even when the frame is dropped just
        # below; this lets the next frame resync after a peer-side seq reset.
        # Confirm that this is intended.
        self.inseq = seq

    if seq < ref_seq < 0xfd:
        self._logger.debug(f'{lpfx} dropping: seq={seq}, {ref_name}={ref_seq}')
        return

    if seq not in self.response_handlers:
        # Nobody is waiting for a reply with this seq: treat it as a request.
        self._handle_incoming_cmd(frame)
        return

    handler = self.response_handlers[seq]
    outcome = None  # None => leave the handler registered and don't call it

    if handler.validate(frame):
        self._logger.debug(f'{lpfx} response OK')
        handler.phase = MessagePhase.DONE
        outcome = frame
        self._incr_outseq()
    else:
        self._logger.warning(f'{lpfx} response is INVALID')
        # The frame reuses the seq of one of our pending requests but is not a
        # valid reply to it. Inbound and outbound seqs are counted separately,
        # but each is a single byte wrapping at 0xff, so a request from the
        # kettle can occasionally collide with one of ours.
        if handler.phase == MessagePhase.DONE or isinstance(handler.message, HandshakeMessage):
            # A DONE handler has no attempts left, and a handshake is failed
            # right away as well: report failure to the caller.
            outcome = False
        if isinstance(frame, (PingMessage, CmdIncomingMessage)):
            # It is still a request from the kettle and must be acked/naked.
            self._handle_incoming_cmd(frame)

    if outcome is not None:
        handler.call(outcome,
                     error_message=f'{lpfx} error while calling callback for msg id={handler.message.id} seq={seq}')
        del self.response_handlers[seq]
|
|
|
|
def _handle_incoming_cmd(self, incoming_message: Message):
    """Offer an incoming request frame to the listeners and queue the reply.

    Ack/Nak frames are replies rather than requests, so they are only logged.

    For any other frame, each listener in ``self.inc_listeners`` has its
    ``incoming_message`` called in turn, under ``self._iml_lock`` and via
    ``safe_callback_call``. The first listener that returns an
    ``AckMessage``/``NakMessage`` supplies the reply. That reply gets the
    incoming frame's seq and is put at the head of the outgoing queue. If no
    listener supplies a reply, a plain ``AckMessage`` is queued instead.

    :raises RuntimeError: if a listener returns a Message that is neither an
        ack nor a nak.
    """
    tag = f'_handle_incoming_cmd({incoming_message.id}, seq={incoming_message.seq})'

    if isinstance(incoming_message, (AckMessage, NakMessage)):
        self._logger.debug(f"{tag}: it's {type(incoming_message).__name__}, ignoring")
        return

    answered = False
    with self._iml_lock:
        for listener in self.inc_listeners:
            reply = safe_callback_call(listener.incoming_message, incoming_message,
                                       logger=self._logger,
                                       error_message=f'{tag}: error while calling message listener')
            if not isinstance(reply, Message):
                continue  # this listener has no reply to offer; ask the next one
            if not isinstance(reply, (AckMessage, NakMessage)):
                raise RuntimeError('are you sure your response is correct? only ack/nak are allowed')
            reply.seq = incoming_message.seq
            self.enqueue_message(WrappedMessage(reply), prepend=True)
            answered = True
            break

    if not answered:
        self.enqueue_message(WrappedMessage(AckMessage(incoming_message.seq)), prepend=True)
|
|
|
|
def enqueue_message(self, wrapped: WrappedMessage, prepend=False):
    """Put a message on the outgoing queue.

    :param wrapped: the message to send.
    :param prepend: if true, insert at the head of the queue so it goes out
        before everything already queued; otherwise append at the tail.
    """
    self._logger.debug(f'enqueue_message: {wrapped.message}')
    if prepend:
        self.outgoing_queue.insert(0, wrapped)
    else:
        self.outgoing_queue.append(wrapped)
|
|
|
|
def _set_response_handler(self, wm: WrappedMessage, old_seq=None):
|
|
if old_seq in self.response_handlers:
|
|
del self.response_handlers[old_seq]
|
|
|
|
seq = wm.seq
|
|
assert seq is not None, 'seq is not set'
|
|
|
|
if seq in self.response_handlers:
|
|
self._logger.debug(f'_set_response_handler(seq={seq}): handler is already set, cancelling it')
|
|
self.response_handlers[seq].call(False,
|
|
error_message=f'_set_response_handler({seq}): error while calling old callback')
|
|
self.response_handlers[seq] = wm
|
|
|
|
def _incr_outseq(self) -> None:
|
|
self.outseq = (self.outseq + 1) % 256
|
|
|
|
def _reset_outseq(self):
|
|
self.outseq = 0
|
|
self._logger.debug(f'_reset_outseq: set 0')
|
|
|
|
|
|
# Union of the values that response callbacks receive in this module: the
# reply Message when a response validates, or False when the request failed
# or was cancelled.
# NOTE(review): presumably used to annotate those callbacks elsewhere; confirm.
MessageResponse = Union[Message, bool]
|