spi_flash: Make SD updates more tolerant of flaky boards

Validation passes if either of the following succeeds (in order):
1. Active firmware's raw dictionary changed after update
2. Checksum of firmware.cur matches expected

Additionally, deletes firmware.bin if found after update succeeds.

Signed-off-by: Justin Schuh <code@justinschuh.com>
This commit is contained in:
Justin Schuh 2021-06-11 17:41:15 -07:00 committed by KevinOConnor
parent 10e72c4b6f
commit bb801905be
2 changed files with 55 additions and 22 deletions

View File

@ -408,6 +408,8 @@ class MessageParser:
except Exception as e: except Exception as e:
logging.exception("process_identify error") logging.exception("process_identify error")
self._error("Error during identify: %s", str(e)) self._error("Error during identify: %s", str(e))
def get_raw_data_dictionary(self):
return self.raw_identify_data
def get_version_info(self): def get_version_info(self):
return self.version, self.build_versions return self.version, self.build_versions
def get_messages(self): def get_messages(self):

View File

@ -788,6 +788,7 @@ class MCUConnection:
self.connect_completion = None self.connect_completion = None
self.connected = False self.connected = False
self.enumerations = {} self.enumerations = {}
self.raw_dictionary = None
def connect(self): def connect(self):
output("Connecting to MCU..") output("Connecting to MCU..")
@ -814,6 +815,7 @@ class MCUConnection:
"MCU Type mismatch: Build MCU = %s, Connected MCU = %s" "MCU Type mismatch: Build MCU = %s, Connected MCU = %s"
% (build_mcu_type, mcu_type)) % (build_mcu_type, mcu_type))
self.enumerations = msgparser.get_enumerations() self.enumerations = msgparser.get_enumerations()
self.raw_dictionary = msgparser.get_raw_data_dictionary()
def _do_serial_connect(self, eventtime): def _do_serial_connect(self, eventtime):
endtime = eventtime + 60. endtime = eventtime + 60.
@ -949,13 +951,25 @@ class MCUConnection:
% (fw_path, sd_size, sd_chksm)) % (fw_path, sd_size, sd_chksm))
return sd_chksm return sd_chksm
def verify_flash(self, req_chksm): def verify_flash(self, req_chksm, old_dictionary):
output("Verifying Flash...") output("Verifying Flash...")
cur_fw_sha = hashlib.sha1() validation_passed = False
msgparser = self._serial.get_msgparser()
cur_dictionary = msgparser.get_raw_data_dictionary()
# Check that the version changed
if cur_dictionary != old_dictionary:
output("Version updated...")
validation_passed = True
else:
output("Version unchanged...")
# If the version didn't change, look for current firmware to checksum
cur_fw_sha = None
if not validation_passed:
cur_fw_path = self.board_config.get('current_firmware_path', cur_fw_path = self.board_config.get('current_firmware_path',
"FIRMWARE.CUR") "FIRMWARE.CUR")
try: try:
with self.fatfs.open_file(cur_fw_path, 'r') as sd_f: with self.fatfs.open_file(cur_fw_path, 'r') as sd_f:
cur_fw_sha = hashlib.sha1()
while True: while True:
buf = sd_f.read(4096) buf = sd_f.read(4096)
if not buf: if not buf:
@ -963,15 +977,30 @@ class MCUConnection:
cur_fw_sha.update(buf) cur_fw_sha.update(buf)
except Exception: except Exception:
msg = "Failed to read file %s" % (cur_fw_path,) msg = "Failed to read file %s" % (cur_fw_path,)
logging.exception(msg) logging.debug(msg)
raise SPIFlashError(msg) output("Checksum skipped...")
if cur_fw_sha is not None:
cur_fw_chksm = cur_fw_sha.hexdigest().upper() cur_fw_chksm = cur_fw_sha.hexdigest().upper()
if req_chksm == cur_fw_chksm: if req_chksm == cur_fw_chksm:
output_line("Done") validation_passed = True
output_line("Firmware Flash Successful") output("Checksum matched...")
else: else:
raise SPIFlashError("Checksum Mismatch: Got '%s', expected '%s'" raise SPIFlashError("Checksum Mismatch: Got '%s', "
"expected '%s'"
% (cur_fw_chksm, req_chksm)) % (cur_fw_chksm, req_chksm))
if not validation_passed:
raise SPIFlashError("Validation failure.")
output_line("Done")
# Remove firmware file if MCU bootloader failed to rename.
if cur_fw_sha is None:
try:
fw_path = self.board_config.get('firmware_path', "firmware.bin")
self.fatfs.remove_item(fw_path)
output_line("Found and deleted %s after reset" % (fw_path,))
except Exception:
pass
output_line("Firmware Flash Successful")
output_line("Current Firmware: %s" % (msgparser.get_version_info()[0],))
class SPIFlash: class SPIFlash:
def __init__(self, args): def __init__(self, args):
@ -988,6 +1017,7 @@ class SPIFlash:
self.firmware_checksum = None self.firmware_checksum = None
self.task_complete = False self.task_complete = False
self.need_upload = True self.need_upload = True
self.old_dictionary = None
def _wait_for_reconnect(self): def _wait_for_reconnect(self):
output("Waiting for device to reconnect...") output("Waiting for device to reconnect...")
@ -1022,6 +1052,7 @@ class SPIFlash:
# Reconnect and upload # Reconnect and upload
if not self.mcu_conn.connected: if not self.mcu_conn.connected:
self.mcu_conn.connect() self.mcu_conn.connect()
self.old_dictionary = self.mcu_conn.raw_dictionary
self.mcu_conn.configure_mcu(printfunc=output_line) self.mcu_conn.configure_mcu(printfunc=output_line)
self.firmware_checksum = self.mcu_conn.sdcard_upload() self.firmware_checksum = self.mcu_conn.sdcard_upload()
self.mcu_conn.reset() self.mcu_conn.reset()
@ -1031,7 +1062,7 @@ class SPIFlash:
# Reconnect and verify # Reconnect and verify
self.mcu_conn.connect() self.mcu_conn.connect()
self.mcu_conn.configure_mcu() self.mcu_conn.configure_mcu()
self.mcu_conn.verify_flash(self.firmware_checksum) self.mcu_conn.verify_flash(self.firmware_checksum, self.old_dictionary)
self.mcu_conn.reset() self.mcu_conn.reset()
self.task_complete = True self.task_complete = True