Wifi connections work!

This commit is contained in:
Ryan McCahan 2024-03-31 00:05:04 -06:00
parent 88444d3454
commit bb8f57b4e9
No known key found for this signature in database
GPG Key ID: 4AD93D9FB6994DFA
4 changed files with 111 additions and 71 deletions

View File

@ -223,7 +223,7 @@ void Span::pollTask() {
readSerial(cBuf,64); readSerial(cBuf,64);
if(strncmp(cBuf, "IMPROV", 6) == 0) { if(strncmp(cBuf, "IMPROV", 6) == 0) {
processImprovCommand(cBuf); processImprovCommand(cBuf, this);
} else { } else {
processSerialCommand(cBuf); processSerialCommand(cBuf);
} }

View File

@ -271,7 +271,6 @@ class Span{
void checkConnect(); // check WiFi connection; connect if needed void checkConnect(); // check WiFi connection; connect if needed
void commandMode(); // allows user to control and reset HomeSpan settings with the control button void commandMode(); // allows user to control and reset HomeSpan settings with the control button
void resetStatus(); // resets statusLED and calls statusCallback based on current HomeSpan status void resetStatus(); // resets statusLED and calls statusCallback based on current HomeSpan status
void reboot(); // reboots device
void printfAttributes(int flags=GET_VALUE|GET_META|GET_PERMS|GET_TYPE|GET_DESC); // writes Attributes JSON database to hapOut stream void printfAttributes(int flags=GET_VALUE|GET_META|GET_PERMS|GET_TYPE|GET_DESC); // writes Attributes JSON database to hapOut stream
@ -295,11 +294,11 @@ class Span{
const char *displayName=DEFAULT_DISPLAY_NAME, const char *displayName=DEFAULT_DISPLAY_NAME,
const char *hostNameBase=DEFAULT_HOST_NAME, const char *hostNameBase=DEFAULT_HOST_NAME,
const char *modelName=DEFAULT_MODEL_NAME); const char *modelName=DEFAULT_MODEL_NAME);
void reboot(); // reboots device
void poll(); // calls pollTask() with some error checking void poll(); // calls pollTask() with some error checking
void processSerialCommand(const char *c); // process command 'c' (typically from readSerial, though can be called with any 'c') void processSerialCommand(const char *c); // process command 'c' (typically from readSerial, though can be called with any 'c')
void processImprovCommand(const char *c); // process Improv-Serial command 'c' void processImprovCommand(const char *c, Span* span); // process Improv-Serial command 'c'
void Span::handleImprovCommand(improv::ImprovCommand cmd);
boolean updateDatabase(boolean updateMDNS=true); // updates HAP Configuration Number and Loop vector; if updateMDNS=true and config number has changed, re-broadcasts MDNS 'c#' record; returns true if config number changed boolean updateDatabase(boolean updateMDNS=true); // updates HAP Configuration Number and Loop vector; if updateMDNS=true and config number has changed, re-broadcasts MDNS 'c#' record; returns true if config number changed
boolean deleteAccessory(uint32_t aid); // deletes Accessory with matching aid; returns true if found, else returns false boolean deleteAccessory(uint32_t aid); // deletes Accessory with matching aid; returns true if found, else returns false

View File

@ -4,7 +4,9 @@
using namespace Utils; using namespace Utils;
using namespace improv; using namespace improv;
void Span::processImprovCommand(const char *c){ #define MAX_ATTEMPTS_WIFI_CONNECTION 10
void Span::processImprovCommand(const char *c, Span* span){
uint8_t len = c[8] + 10; uint8_t len = c[8] + 10;
// Copy the const char *c from the argument into a new character array // Copy the const char *c from the argument into a new character array
@ -12,55 +14,47 @@ void Span::processImprovCommand(const char *c){
strcpy(buffer, c); strcpy(buffer, c);
buffer[len-1] = c[len-1]; // Copy the checksum bit since it falls after the null terminator buffer[len-1] = c[len-1]; // Copy the checksum bit since it falls after the null terminator
Serial.print("Command: ");
for (size_t i = 0; i < len; i++) {
Serial.print("0x");
Serial.print(c[i] < 16 ? "0" : "");
Serial.print(c[i], HEX);
Serial.print(" ");
}
Serial.println();
Serial.println("Processing Improv Command " + String(buffer) + " length " + strlen(buffer));
Serial.print("Command: ");
for (size_t i = 0; i < len; i++) {
Serial.print("0x");
Serial.print(buffer[i] < 16 ? "0" : "");
Serial.print(buffer[i], HEX);
Serial.print(" ");
}
Serial.println();
Serial.println("Length " + String(len) + " char ");
Serial.println(buffer[len - 1], HEX);
Serial.println("Onwards...");
improv::parse_improv_serial_byte(len - 1, buffer[len - 1], (uint8_t *)c, [&](ImprovCommand command) { improv::parse_improv_serial_byte(len - 1, buffer[len - 1], (uint8_t *)c, [&](ImprovCommand command) {
improv::handleImprovCommand(command); improv::handleImprovCommand(command, span);
return true; return true;
}, [&](Error error) { }, [&](Error error) {
Serial.println("Error parsing Improv command"); LOG0("Error parsing Improv command");
}); });
} // Span::processImprovCommand } // Span::processImprovCommand
namespace improv { namespace improv {
void handleImprovCommand(improv::ImprovCommand cmd) { void handleImprovCommand(improv::ImprovCommand cmd, Span* span) {
switch (cmd.command) { switch(cmd.command) {
case Command::WIFI_SETTINGS: case Command::WIFI_SETTINGS:
Serial.println("WiFi Settings: "); Serial.println("WiFi Settings: ");
Serial.print(cmd.ssid.c_str()); Serial.print(cmd.ssid.c_str());
Serial.print(" "); Serial.print(" ");
Serial.println(cmd.password.c_str()); Serial.println(cmd.password.c_str());
//Span::setWifiCredentials(cmd.ssid.c_str(), cmd.password.c_str());
break;
case Command::GET_CURRENT_STATE:
if ((WiFi.status() == WL_CONNECTED)) {
sendImprovState(improv::State::STATE_PROVISIONED);
//std::vector<uint8_t> data = improv::build_rpc_response(improv::GET_CURRENT_STATE, getLocalUrl(), false);
//send_response(data);
// Attempt to use our credentials so we can see if they work
if (connectWifi(cmd.ssid.c_str(), cmd.password.c_str())) {
sendImprovState(improv::State::STATE_PROVISIONED);
std::vector<std::string> infos; // Empty vector, we could put an HTTP URL here as the next step for the user if we had one
std::vector<uint8_t> data = improv::build_rpc_response(improv::WIFI_SETTINGS, infos, false);
improv::sendImprovResponse(data);
span->setWifiCredentials(cmd.ssid.c_str(), cmd.password.c_str()); // Save credentials and reboot
delay(1000); // Give the serial on the other end a moment to process
span->reboot();
} else {
sendImprovState(improv::State::STATE_STOPPED);
improv::sendImprovError(improv::Error::ERROR_UNABLE_TO_CONNECT);
}
break;
case Command::GET_CURRENT_STATE:
if((WiFi.status() == WL_CONNECTED)) {
sendImprovState(improv::State::STATE_PROVISIONED);
std::vector<std::string> infos; // Empty vector, we could put an HTTP URL here as the next step for the user if we had one
std::vector<uint8_t> data = improv::build_rpc_response(improv::GET_CURRENT_STATE, infos, false);
improv::sendImprovResponse(data);
} else { } else {
sendImprovState(improv::State::STATE_AUTHORIZED); sendImprovState(improv::State::STATE_AUTHORIZED);
} }
@ -94,10 +88,31 @@ void handleImprovCommand(improv::ImprovCommand cmd) {
} }
} }
bool connectWifi(const char *ssid, const char *pwd) {
uint8_t attempts = 0;
WiFi.begin(ssid, pwd);
LOG2("Attempting to connect to WiFi SSID %s", ssid);
while (WiFi.status() != WL_CONNECTED) {
delay(1000);
LOG2("Attempting to connect to WiFi SSID %s, attempt #%s", ssid, String(attempts));
if (attempts > MAX_ATTEMPTS_WIFI_CONNECTION) {
LOG0("Failed to connect to WiFi");
WiFi.disconnect();
return false;
}
attempts++;
}
LOG0("Successfully connected to WiFi SSID %s", ssid);
return true;
} // connectWifi
void getAvailableWifiNetworks() { void getAvailableWifiNetworks() {
int networkNum = WiFi.scanNetworks(); int networkNum = WiFi.scanNetworks();
for (int id = 0; id < networkNum; ++id) { for(int id = 0; id < networkNum; ++id) {
std::vector<uint8_t> data = improv::build_rpc_response( std::vector<uint8_t> data = improv::build_rpc_response(
improv::GET_WIFI_NETWORKS, {WiFi.SSID(id), String(WiFi.RSSI(id)), (WiFi.encryptionType(id) == WIFI_AUTH_OPEN ? "NO" : "YES")}, false); improv::GET_WIFI_NETWORKS, {WiFi.SSID(id), String(WiFi.RSSI(id)), (WiFi.encryptionType(id) == WIFI_AUTH_OPEN ? "NO" : "YES")}, false);
improv::sendImprovResponse(data); improv::sendImprovResponse(data);
@ -117,14 +132,38 @@ void sendImprovState(improv::State state) {
data[8] = 1; data[8] = 1;
data[9] = state; data[9] = state;
uint8_t checksum = 0x00;
for(uint8_t d : data)
checksum += d;
data[10] = checksum;
Serial.write(data.data(), data.size());
Serial.println("Wrote ");
for(size_t i = 0; i < data.size(); i++) {
Serial.print("0x");
Serial.print(data[i] < 16 ? "0" : "");
Serial.print(data[i], HEX);
Serial.print(" ");
}
Serial.println();
} // sendImprovState
void sendImprovError(improv::Error error) {
std::vector<uint8_t> data = {'I', 'M', 'P', 'R', 'O', 'V'};
data.resize(11);
data[6] = improv::IMPROV_SERIAL_VERSION;
data[7] = improv::TYPE_ERROR_STATE;
data[8] = 1;
data[9] = error;
uint8_t checksum = 0x00; uint8_t checksum = 0x00;
for (uint8_t d : data) for (uint8_t d : data)
checksum += d; checksum += d;
data[10] = checksum; data[10] = checksum;
Serial.println("Writing " + String(data.size()) + " bytes to Improv");
Serial.write(data.data(), data.size()); Serial.write(data.data(), data.size());
} // sendImprovState }
void sendImprovResponse(std::vector<uint8_t> &response) { void sendImprovResponse(std::vector<uint8_t> &response) {
std::vector<uint8_t> data = {'I', 'M', 'P', 'R', 'O', 'V'}; std::vector<uint8_t> data = {'I', 'M', 'P', 'R', 'O', 'V'};
@ -135,14 +174,14 @@ void sendImprovResponse(std::vector<uint8_t> &response) {
data.insert(data.end(), response.begin(), response.end()); data.insert(data.end(), response.begin(), response.end());
uint8_t checksum = 0x00; uint8_t checksum = 0x00;
for (uint8_t d : data) for(uint8_t d : data)
checksum += d; checksum += d;
data.push_back(checksum); data.push_back(checksum);
Serial.write(data.data(), data.size()); Serial.write(data.data(), data.size());
Serial.println("Wrote "); Serial.println("Wrote ");
for (size_t i = 0; i < data.size(); i++) { for(size_t i = 0; i < data.size(); i++) {
Serial.print("0x"); Serial.print("0x");
Serial.print(data[i] < 16 ? "0" : ""); Serial.print(data[i] < 16 ? "0" : "");
Serial.print(data[i], HEX); Serial.print(data[i], HEX);
@ -161,26 +200,26 @@ ImprovCommand parse_improv_data(const uint8_t *data, size_t length, bool check_c
Command command = (Command) data[0]; Command command = (Command) data[0];
uint8_t data_length = data[1]; uint8_t data_length = data[1];
if (data_length != length - 2 - check_checksum) { if(data_length != length - 2 - check_checksum) {
improv_command.command = UNKNOWN; improv_command.command = UNKNOWN;
return improv_command; return improv_command;
} }
if (check_checksum) { if(check_checksum) {
uint8_t checksum = data[length - 1]; uint8_t checksum = data[length - 1];
uint32_t calculated_checksum = 0; uint32_t calculated_checksum = 0;
for (uint8_t i = 0; i < length - 1; i++) { for(uint8_t i = 0; i < length - 1; i++) {
calculated_checksum += data[i]; calculated_checksum += data[i];
} }
if ((uint8_t) calculated_checksum != checksum) { if((uint8_t) calculated_checksum != checksum) {
improv_command.command = BAD_CHECKSUM; improv_command.command = BAD_CHECKSUM;
return improv_command; return improv_command;
} }
} }
if (command == WIFI_SETTINGS) { if(command == WIFI_SETTINGS) {
uint8_t ssid_length = data[2]; uint8_t ssid_length = data[2];
uint8_t ssid_start = 3; uint8_t ssid_start = 3;
size_t ssid_end = ssid_start + ssid_length; size_t ssid_end = ssid_start + ssid_length;
@ -200,43 +239,43 @@ ImprovCommand parse_improv_data(const uint8_t *data, size_t length, bool check_c
bool parse_improv_serial_byte(size_t position, uint8_t byte, const uint8_t *buffer, bool parse_improv_serial_byte(size_t position, uint8_t byte, const uint8_t *buffer,
std::function<bool(ImprovCommand)> &&callback, std::function<void(Error)> &&on_error) { std::function<bool(ImprovCommand)> &&callback, std::function<void(Error)> &&on_error) {
if (position == 0) if(position == 0)
return byte == 'I'; return byte == 'I';
if (position == 1) if(position == 1)
return byte == 'M'; return byte == 'M';
if (position == 2) if(position == 2)
return byte == 'P'; return byte == 'P';
if (position == 3) if(position == 3)
return byte == 'R'; return byte == 'R';
if (position == 4) if(position == 4)
return byte == 'O'; return byte == 'O';
if (position == 5) if(position == 5)
return byte == 'V'; return byte == 'V';
if (position == 6) if(position == 6)
return byte == IMPROV_SERIAL_VERSION; return byte == IMPROV_SERIAL_VERSION;
if (position <= 8) if(position <= 8)
return true; return true;
uint8_t type = buffer[7]; uint8_t type = buffer[7];
uint8_t data_len = buffer[8]; uint8_t data_len = buffer[8];
if (position <= 8 + data_len) if(position <= 8 + data_len)
return true; return true;
if (position == 8 + data_len + 1) { if(position == 8 + data_len + 1) {
uint8_t checksum = 0x00; uint8_t checksum = 0x00;
for (size_t i = 0; i < position; i++) for(size_t i = 0; i < position; i++)
checksum += buffer[i]; checksum += buffer[i];
Serial.println("Checksum: " + String(checksum) + " Byte: " + String(byte)); Serial.println("Checksum: " + String(checksum) + " Byte: " + String(byte));
if (checksum != byte) { if(checksum != byte) {
on_error(ERROR_INVALID_RPC); on_error(ERROR_INVALID_RPC);
return false; return false;
} }
if (type == TYPE_RPC) { if(type == TYPE_RPC) {
auto command = parse_improv_data(&buffer[9], data_len, false); auto command = parse_improv_data(&buffer[9], data_len, false);
return callback(command); return callback(command);
} }
@ -249,7 +288,7 @@ std::vector<uint8_t> build_rpc_response(Command command, const std::vector<std::
std::vector<uint8_t> out; std::vector<uint8_t> out;
uint32_t length = 0; uint32_t length = 0;
out.push_back(command); out.push_back(command);
for (const auto &str : datum) { for(const auto &str : datum) {
uint8_t len = str.length(); uint8_t len = str.length();
length += len + 1; length += len + 1;
out.push_back(len); out.push_back(len);
@ -257,10 +296,10 @@ std::vector<uint8_t> build_rpc_response(Command command, const std::vector<std::
} }
out.insert(out.begin() + 1, length); out.insert(out.begin() + 1, length);
if (add_checksum) { if(add_checksum) {
uint32_t calculated_checksum = 0; uint32_t calculated_checksum = 0;
for (uint8_t byte : out) { for(uint8_t byte : out) {
calculated_checksum += byte; calculated_checksum += byte;
} }
out.push_back(calculated_checksum); out.push_back(calculated_checksum);
@ -273,7 +312,7 @@ std::vector<uint8_t> build_rpc_response(Command command, const std::vector<Strin
std::vector<uint8_t> out; std::vector<uint8_t> out;
uint32_t length = 0; uint32_t length = 0;
out.push_back(command); out.push_back(command);
for (const auto &str : datum) { for(const auto &str : datum) {
uint8_t len = str.length(); uint8_t len = str.length();
length += len; length += len;
out.push_back(len); out.push_back(len);
@ -281,10 +320,10 @@ std::vector<uint8_t> build_rpc_response(Command command, const std::vector<Strin
} }
out.insert(out.begin() + 1, length); out.insert(out.begin() + 1, length);
if (add_checksum) { if(add_checksum) {
uint32_t calculated_checksum = 0; uint32_t calculated_checksum = 0;
for (uint8_t byte : out) { for(uint8_t byte : out) {
calculated_checksum += byte; calculated_checksum += byte;
} }
out.push_back(calculated_checksum); out.push_back(calculated_checksum);

View File

@ -8,6 +8,7 @@
#include <functional> #include <functional>
#include <string> #include <string>
#include <vector> #include <vector>
#include "HomeSpan.h"
namespace improv { namespace improv {
@ -73,10 +74,11 @@ std::vector<uint8_t> build_rpc_response(Command command, const std::vector<std::
std::vector<uint8_t> build_rpc_response(Command command, const std::vector<String> &datum, bool add_checksum = true); std::vector<uint8_t> build_rpc_response(Command command, const std::vector<String> &datum, bool add_checksum = true);
#endif // ARDUINO #endif // ARDUINO
void handleImprovCommand(improv::ImprovCommand cmd); void handleImprovCommand(improv::ImprovCommand cmd, Span* span);
void sendImprovState(improv::State state); void sendImprovState(improv::State state);
void sendImprovResponse(std::vector<uint8_t> &response); void sendImprovResponse(std::vector<uint8_t> &response);
void sendImprovRPCResponse(std::vector<uint8_t> &response); void sendImprovError(improv::Error error);
void getAvailableWifiNetworks(); void getAvailableWifiNetworks();
bool connectWifi(const char *ssid, const char *pwd);
} // namespace improv } // namespace improv