From bb801905be7b4cc0fe12b350cd2c04230e1f7062 Mon Sep 17 00:00:00 2001 From: Justin Schuh Date: Fri, 11 Jun 2021 17:41:15 -0700 Subject: [PATCH] spi_flash: Make SD updates more tolerant of flaky boards Validation passes if either of the following succeeds (in order): 1. Active firmware's raw dictionary changed after update 2. Checksum of firmware.cur matches expected Additionally, deletes firmware.bin if found after update succeeds. Signed-off-by: Justin Schuh --- klippy/msgproto.py | 2 + scripts/spi_flash/spi_flash.py | 75 ++++++++++++++++++++++++---------- 2 files changed, 55 insertions(+), 22 deletions(-) diff --git a/klippy/msgproto.py b/klippy/msgproto.py index 7391a6f4..5a4233e7 100644 --- a/klippy/msgproto.py +++ b/klippy/msgproto.py @@ -408,6 +408,8 @@ class MessageParser: except Exception as e: logging.exception("process_identify error") self._error("Error during identify: %s", str(e)) + def get_raw_data_dictionary(self): + return self.raw_identify_data def get_version_info(self): return self.version, self.build_versions def get_messages(self): diff --git a/scripts/spi_flash/spi_flash.py b/scripts/spi_flash/spi_flash.py index 137adf5a..e6c91b2b 100644 --- a/scripts/spi_flash/spi_flash.py +++ b/scripts/spi_flash/spi_flash.py @@ -788,6 +788,7 @@ class MCUConnection: self.connect_completion = None self.connected = False self.enumerations = {} + self.raw_dictionary = None def connect(self): output("Connecting to MCU..") @@ -814,6 +815,7 @@ class MCUConnection: "MCU Type mismatch: Build MCU = %s, Connected MCU = %s" % (build_mcu_type, mcu_type)) self.enumerations = msgparser.get_enumerations() + self.raw_dictionary = msgparser.get_raw_data_dictionary() def _do_serial_connect(self, eventtime): endtime = eventtime + 60. @@ -949,29 +951,56 @@ class MCUConnection: % (fw_path, sd_size, sd_chksm)) return sd_chksm - def verify_flash(self, req_chksm): + def verify_flash(self, req_chksm, old_dictionary): output("Verifying Flash...") - cur_fw_sha = hashlib.sha1() - cur_fw_path = self.board_config.get('current_firmware_path', - "FIRMWARE.CUR") - try: - with self.fatfs.open_file(cur_fw_path, 'r') as sd_f: - while True: - buf = sd_f.read(4096) - if not buf: - break - cur_fw_sha.update(buf) - except Exception: - msg = "Failed to read file %s" % (cur_fw_path,) - logging.exception(msg) - raise SPIFlashError(msg) - cur_fw_chksm = cur_fw_sha.hexdigest().upper() - if req_chksm == cur_fw_chksm: - output_line("Done") - output_line("Firmware Flash Successful") + validation_passed = False + msgparser = self._serial.get_msgparser() + cur_dictionary = msgparser.get_raw_data_dictionary() + # Check that the version changed + if cur_dictionary != old_dictionary: + output("Version updated...") + validation_passed = True else: - raise SPIFlashError("Checksum Mismatch: Got '%s', expected '%s'" - % (cur_fw_chksm, req_chksm)) + output("Version unchanged...") + # If the version didn't change, look for current firmware to checksum + cur_fw_sha = None + if not validation_passed: + cur_fw_path = self.board_config.get('current_firmware_path', + "FIRMWARE.CUR") + try: + with self.fatfs.open_file(cur_fw_path, 'r') as sd_f: + cur_fw_sha = hashlib.sha1() + while True: + buf = sd_f.read(4096) + if not buf: + break + cur_fw_sha.update(buf) + except Exception: + msg = "Failed to read file %s" % (cur_fw_path,) + logging.debug(msg) + output("Checksum skipped...") + if cur_fw_sha is not None: + cur_fw_chksm = cur_fw_sha.hexdigest().upper() + if req_chksm == cur_fw_chksm: + validation_passed = True + output("Checksum matched...") + else: + raise SPIFlashError("Checksum Mismatch: Got '%s', " + "expected '%s'" + % (cur_fw_chksm, req_chksm)) + if not validation_passed: + raise SPIFlashError("Validation failure.") + output_line("Done") + # Remove firmware file if MCU bootloader failed to rename. + if cur_fw_sha is None: + try: + fw_path = self.board_config.get('firmware_path', "firmware.bin") + self.fatfs.remove_item(fw_path) + output_line("Found and deleted %s after reset" % (fw_path,)) + except Exception: + pass + output_line("Firmware Flash Successful") + output_line("Current Firmware: %s" % (msgparser.get_version_info()[0],)) class SPIFlash: def __init__(self, args): @@ -988,6 +1017,7 @@ class SPIFlash: self.firmware_checksum = None self.task_complete = False self.need_upload = True + self.old_dictionary = None def _wait_for_reconnect(self): output("Waiting for device to reconnect...") @@ -1022,6 +1052,7 @@ class SPIFlash: # Reconnect and upload if not self.mcu_conn.connected: self.mcu_conn.connect() + self.old_dictionary = self.mcu_conn.raw_dictionary self.mcu_conn.configure_mcu(printfunc=output_line) self.firmware_checksum = self.mcu_conn.sdcard_upload() self.mcu_conn.reset() @@ -1031,7 +1062,7 @@ class SPIFlash: # Reconnect and verify self.mcu_conn.connect() self.mcu_conn.configure_mcu() - self.mcu_conn.verify_flash(self.firmware_checksum) + self.mcu_conn.verify_flash(self.firmware_checksum, self.old_dictionary) self.mcu_conn.reset() self.task_complete = True