2023-05-17 04:06:18 +03:00

345 lines
10 KiB
C++

#include <ESP8266httpUpdate.h>
#include <homekit/logging.h>
#include <homekit/wifi.h>
#include <homekit/util.h>
#include <homekit/mqtt.h>
#include "relay.h"
#include "mqtt.h"
#include "leds.h"
namespace homekit::mqtt {
static const char TOPIC_DIAGNOSTICS[] = "stat";
static const char TOPIC_INITIAL_DIAGNOSTICS[] = "stat1";
static const char TOPIC_OTA_RESPONSE[] = "otares";
static const char TOPIC_RELAY_POWER[] = "power";
static const char TOPIC_ADMIN_OTA[] = "admin/ota";
static const uint16_t MQTT_KEEPALIVE = 30;
enum class IncomingMessage {
UNKNOWN,
RELAY_POWER,
OTA
};
using namespace espMqttClientTypes;
#define MD5_SIZE 16
MQTT::MQTT() {
auto cfg = config::read();
homeId = String(cfg.flags.node_configured ? cfg.node_id : wifi::NODE_ID);
randomSeed(micros());
client.onConnect([&](bool sessionPresent) {
PRINTLN("mqtt: connected");
sendInitialDiagnostics();
subscribe(TOPIC_RELAY_POWER, 1);
subscribe(TOPIC_ADMIN_OTA);
});
client.onDisconnect([&](DisconnectReason reason) {
PRINTF("mqtt: disconnected, reason=%d\n", static_cast<int>(reason));
#ifdef DEBUG
if (reason == DisconnectReason::TLS_BAD_FINGERPRINT)
PRINTLN("reason: bad fingerprint");
#endif
if (ota.started()) {
PRINTLN("mqtt: update was in progress, canceling..");
ota.clean();
Update.end();
Update.clearError();
}
if (ota.readyToRestart) {
restartTimer.once(1, restart);
} else {
reconnectTimer.once(2, [&]() {
reconnect();
});
}
});
client.onSubscribe([&](uint16_t packetId, const SubscribeReturncode* returncodes, size_t len) {
PRINTF("mqtt: subscribe ack, packet_id=%d\n", packetId);
for (size_t i = 0; i < len; i++) {
PRINTF(" return code: %u\n", static_cast<unsigned int>(*(returncodes+i)));
}
});
client.onUnsubscribe([&](uint16_t packetId) {
PRINTF("mqtt: unsubscribe ack, packet_id=%d\n", packetId);
});
client.onMessage([&](const MessageProperties& properties, const char* topic, const uint8_t* payload, size_t len, size_t index, size_t total) {
PRINTF("mqtt: message received, topic=%s, qos=%d, dup=%d, retain=%d, len=%ul, index=%ul, total=%ul\n",
topic, properties.qos, (int)properties.dup, (int)properties.retain, len, index, total);
IncomingMessage msgType = IncomingMessage::UNKNOWN;
const char *ptr = topic + homeId.length() + 10;
String relevantTopic(ptr);
if (relevantTopic == TOPIC_RELAY_POWER)
msgType = IncomingMessage::RELAY_POWER;
else if (relevantTopic == TOPIC_ADMIN_OTA)
msgType = IncomingMessage::OTA;
if (len != total && msgType != IncomingMessage::OTA) {
PRINTLN("mqtt: received partial message, not supported");
return;
}
switch (msgType) {
case IncomingMessage::RELAY_POWER:
handleRelayPowerPayload(payload, total);
break;
case IncomingMessage::OTA:
if (ota.finished)
break;
handleAdminOtaPayload(properties.packetId, payload, len, index, total);
break;
case IncomingMessage::UNKNOWN:
PRINTF("error: invalid topic %s\n", topic);
break;
}
});
client.onPublish([&](uint16_t packetId) {
PRINTF("mqtt: publish ack, packet_id=%d\n", packetId);
if (ota.finished && packetId == ota.publishResultPacketId) {
ota.readyToRestart = true;
}
});
client.setServer(MQTT_SERVER, MQTT_PORT);
client.setClientId(MQTT_CLIENT_ID);
client.setCredentials(MQTT_USERNAME, MQTT_PASSWORD);
client.setCleanSession(true);
client.setFingerprint(MQTT_CA_FINGERPRINT);
client.setKeepAlive(MQTT_KEEPALIVE);
}
void MQTT::connect() {
reconnect();
}
void MQTT::reconnect() {
if (client.connected()) {
PRINTLN("warning: already connected");
return;
}
client.connect();
}
void MQTT::disconnect() {
// TODO test how this works???
reconnectTimer.detach();
client.disconnect();
}
uint16_t MQTT::publish(const String &topic, uint8_t *payload, size_t length) {
String fullTopic = "hk/" + homeId + "/relay/" + topic;
return client.publish(fullTopic.c_str(), 1, false, payload, length);
}
void MQTT::loop() {
client.loop();
}
uint16_t MQTT::subscribe(const String &topic, uint8_t qos) {
String fullTopic = "hk/" + homeId + "/relay/" + topic;
PRINTF("mqtt: subscribing to %s...\n", fullTopic.c_str());
uint16_t packetId = client.subscribe(fullTopic.c_str(), qos);
if (!packetId)
PRINTF("error: failed to subscribe to %s\n", fullTopic.c_str());
return packetId;
}
void MQTT::sendInitialDiagnostics() {
auto cfg = config::read();
InitialDiagnosticsPayload stat{
.ip = wifi::getIPAsInteger(),
.fw_version = CONFIG_FW_VERSION,
.rssi = wifi::getRSSI(),
.free_heap = ESP.getFreeHeap(),
.flags = DiagnosticsFlags{
.state = static_cast<uint8_t>(relay::getState() ? 1 : 0),
.config_changed_value_present = 1,
.config_changed = static_cast<uint8_t>(cfg.flags.node_configured ||
cfg.flags.wifi_configured ? 1 : 0)
}
};
publish(TOPIC_INITIAL_DIAGNOSTICS, reinterpret_cast<uint8_t*>(&stat), sizeof(stat));
diagnosticsStopWatch.save();
}
void MQTT::sendDiagnostics() {
DiagnosticsPayload stat{
.rssi = wifi::getRSSI(),
.free_heap = ESP.getFreeHeap(),
.flags = DiagnosticsFlags{
.state = static_cast<uint8_t>(relay::getState() ? 1 : 0),
.config_changed_value_present = 0,
.config_changed = 0
}
};
publish(TOPIC_DIAGNOSTICS, reinterpret_cast<uint8_t*>(&stat), sizeof(stat));
diagnosticsStopWatch.save();
}
uint16_t MQTT::sendOtaResponse(OTAResult status, uint8_t error_code) {
OTAResponse resp{
.status = status,
.error_code = error_code
};
return publish(TOPIC_OTA_RESPONSE, reinterpret_cast<uint8_t*>(&resp), sizeof(resp));
}
void MQTT::handleRelayPowerPayload(const uint8_t *payload, uint32_t length) {
if (length != sizeof(PowerPayload)) {
PRINTF("error: size of payload (%ul) does not match expected (%ul)\n",
length, sizeof(PowerPayload));
return;
}
auto pd = reinterpret_cast<const struct PowerPayload*>(payload);
if (strncmp(pd->secret, MQTT_SECRET, sizeof(pd->secret)) != 0) {
PRINTLN("error: invalid secret");
return;
}
if (pd->state == 1) {
PRINTLN("mqtt: turning relay on");
relay::setOn();
} else if (pd->state == 0) {
PRINTLN("mqtt: turning relay off");
relay::setOff();
} else {
PRINTLN("error: unexpected state value");
}
sendDiagnostics();
}
void MQTT::handleAdminOtaPayload(uint16_t packetId, const uint8_t *payload, size_t length, size_t index, size_t total) {
char md5[33];
char* md5Ptr = md5;
if (index != 0 && ota.dataPacketId != packetId) {
PRINTLN("mqtt/ota: non-matching packet id");
return;
}
Update.runAsync(true);
if (index == 0) {
if (length < CONFIG_NODE_SECRET_SIZE + MD5_SIZE) {
PRINTLN("mqtt/ota: failed to check secret, first packet size is too small");
return;
}
if (memcmp((const char*)payload, CONFIG_NODE_SECRET, CONFIG_NODE_SECRET_SIZE) != 0) {
PRINTLN("mqtt/ota: invalid secret");
return;
}
PRINTF("mqtt/ota: starting update, total=%ul\n", total-NODE_SECRET_SIZE);
for (int i = 0; i < MD5_SIZE; i++) {
md5Ptr += sprintf(md5Ptr, "%02x", *((unsigned char*)(payload+CONFIG_NODE_SECRET_SIZE+i)));
}
md5[32] = '\0';
PRINTF("mqtt/ota: md5 is %s\n", md5);
PRINTF("mqtt/ota: first packet is %ul bytes length\n", length);
md5[32] = '\0';
if (Update.isRunning()) {
Update.end();
Update.clearError();
}
if (!Update.setMD5(md5)) {
PRINTLN("mqtt/ota: setMD5 failed");
return;
}
ota.dataPacketId = packetId;
if (!Update.begin(total - CONFIG_NODE_SECRET_SIZE - MD5_SIZE)) {
ota.clean();
#ifdef DEBUG
Update.printError(Serial);
#endif
sendOtaResponse(OTAResult::UPDATE_ERROR, Update.getError());
}
ota.written = Update.write(const_cast<uint8_t*>(payload)+CONFIG_NODE_SECRET_SIZE + MD5_SIZE, length-CONFIG_NODE_SECRET_SIZE - MD5_SIZE);
ota.written += CONFIG_NODE_SECRET_SIZE + MD5_SIZE;
mcu_led->blink(1, 1);
PRINTF("mqtt/ota: updating %u/%u\n", ota.written, Update.size());
} else {
if (!Update.isRunning()) {
PRINTLN("mqtt/ota: update is not running");
return;
}
if (index == ota.written) {
size_t written;
if ((written = Update.write(const_cast<uint8_t*>(payload), length)) != length) {
PRINTF("mqtt/ota: error: tried to write %ul bytes, write() returned %ul\n",
length, written);
ota.clean();
Update.end();
Update.clearError();
sendOtaResponse(OTAResult::WRITE_ERROR);
return;
}
ota.written += length;
mcu_led->blink(1, 1);
PRINTF("mqtt/ota: updating %u/%u\n",
ota.written - CONFIG_NODE_SECRET_SIZE - MD5_SIZE,
Update.size());
} else {
PRINTF("mqtt/ota: position is invalid, expected %ul, got %ul\n", ota.written, index);
ota.clean();
Update.end();
Update.clearError();
}
}
if (Update.isFinished()) {
ota.dataPacketId = 0;
if (Update.end()) {
ota.finished = true;
ota.publishResultPacketId = sendOtaResponse(OTAResult::OK);
PRINTF("mqtt/ota: ok, otares packet_id=%d\n", ota.publishResultPacketId);
} else {
ota.clean();
PRINTF("mqtt/ota: error: %u\n", Update.getError());
#ifdef DEBUG
Update.printError(Serial);
#endif
Update.clearError();
sendOtaResponse(OTAResult::UPDATE_ERROR, Update.getError());
}
}
}
}