msgproto: Avoid peeking into the msgproto class members

Update callers to only use exported methods of the msgproto objects.
This makes it easier to make internal changes to the code.

Signed-off-by: Kevin O'Connor <kevin@koconnor.net>
This commit is contained in:
Kevin O'Connor 2021-02-18 14:01:40 -05:00
parent 319c36df52
commit efa497dfd8
4 changed files with 78 additions and 60 deletions

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python2 #!/usr/bin/env python2
# Script to implement a test console with firmware over serial port # Script to implement a test console with firmware over serial port
# #
# Copyright (C) 2016,2017 Kevin O'Connor <kevin@koconnor.net> # Copyright (C) 2016-2021 Kevin O'Connor <kevin@koconnor.net>
# #
# This file may be distributed under the terms of the GNU GPLv3 license. # This file may be distributed under the terms of the GNU GPLv3 license.
import sys, optparse, os, re, logging import sys, optparse, os, re, logging
@ -54,11 +54,12 @@ class KeyboardReader:
self.output("="*20 + " attempting to connect " + "="*20) self.output("="*20 + " attempting to connect " + "="*20)
self.ser.connect() self.ser.connect()
msgparser = self.ser.get_msgparser() msgparser = self.ser.get_msgparser()
self.output("Loaded %d commands (%s / %s)" % ( message_count = len(msgparser.get_messages())
len(msgparser.messages_by_id), version, build_versions = msgparser.get_version_info()
msgparser.version, msgparser.build_versions)) self.output("Loaded %d commands (%s / %s)"
% (message_count, version, build_versions))
self.output("MCU config: %s" % (" ".join( self.output("MCU config: %s" % (" ".join(
["%s=%s" % (k, v) for k, v in msgparser.config.items()]))) ["%s=%s" % (k, v) for k, v in msgparser.get_constants().items()])))
self.clocksync.connect(self.ser) self.clocksync.connect(self.ser)
self.ser.handle_default = self.handle_default self.ser.handle_default = self.handle_default
self.ser.register_response(self.handle_output, '#output') self.ser.register_response(self.handle_output, '#output')
@ -137,9 +138,10 @@ class KeyboardReader:
def command_LIST(self, parts): def command_LIST(self, parts):
self.update_evals(self.reactor.monotonic()) self.update_evals(self.reactor.monotonic())
mp = self.ser.get_msgparser() mp = self.ser.get_msgparser()
cmds = [msgformat for msgid, msgtype, msgformat in mp.get_messages()
if msgtype == 'command']
out = "Available mcu commands:" out = "Available mcu commands:"
out += "\n ".join([""] + sorted([ out += "\n ".join([""] + sorted(cmds))
mp.messages_by_id[i].msgformat for i in mp.command_ids]))
out += "\nAvailable artificial commands:" out += "\nAvailable artificial commands:"
out += "\n ".join([""] + [n for n in sorted(self.local_commands)]) out += "\n ".join([""] + [n for n in sorted(self.local_commands)])
out += "\nAvailable local variables:" out += "\nAvailable local variables:"

View File

@ -1,6 +1,6 @@
# Interface to Klipper micro-controller code # Interface to Klipper micro-controller code
# #
# Copyright (C) 2016-2020 Kevin O'Connor <kevin@koconnor.net> # Copyright (C) 2016-2021 Kevin O'Connor <kevin@koconnor.net>
# #
# This file may be distributed under the terms of the GNU GPLv3 license. # This file may be distributed under the terms of the GNU GPLv3 license.
import sys, os, zlib, logging, math import sys, os, zlib, logging, math
@ -564,10 +564,11 @@ class MCU:
return config_params return config_params
def _log_info(self): def _log_info(self):
msgparser = self._serial.get_msgparser() msgparser = self._serial.get_msgparser()
message_count = len(msgparser.get_messages())
version, build_versions = msgparser.get_version_info()
log_info = [ log_info = [
"Loaded MCU '%s' %d commands (%s / %s)" % ( "Loaded MCU '%s' %d commands (%s / %s)"
self._name, len(msgparser.messages_by_id), % (self._name, message_count, version, build_versions),
msgparser.version, msgparser.build_versions),
"MCU '%s' config: %s" % (self._name, " ".join( "MCU '%s' config: %s" % (self._name, " ".join(
["%s=%s" % (k, v) for k, v in self.get_constants().items()]))] ["%s=%s" % (k, v) for k, v in self.get_constants().items()]))]
return "\n".join(log_info) return "\n".join(log_info)
@ -635,8 +636,9 @@ class MCU:
mbaud = msgparser.get_constant('SERIAL_BAUD', None) mbaud = msgparser.get_constant('SERIAL_BAUD', None)
if self._restart_method is None and mbaud is None and not ext_only: if self._restart_method is None and mbaud is None and not ext_only:
self._restart_method = 'command' self._restart_method = 'command'
self._get_status_info['mcu_version'] = msgparser.version version, build_versions = msgparser.get_version_info()
self._get_status_info['mcu_build_versions'] = msgparser.build_versions self._get_status_info['mcu_version'] = version
self._get_status_info['mcu_build_versions'] = build_versions
self._get_status_info['mcu_constants'] = msgparser.get_constants() self._get_status_info['mcu_constants'] = msgparser.get_constants()
self.register_response(self._handle_shutdown, 'shutdown') self.register_response(self._handle_shutdown, 'shutdown')
self.register_response(self._handle_shutdown, 'is_shutdown') self.register_response(self._handle_shutdown, 'is_shutdown')
@ -693,7 +695,8 @@ class MCU:
except self._serial.get_msgparser().error as e: except self._serial.get_msgparser().error as e:
return None return None
def lookup_command_id(self, msgformat): def lookup_command_id(self, msgformat):
return self._serial.get_msgparser().lookup_command(msgformat).msgid all_msgs = self._serial.get_msgparser().get_messages()
return {msgfmt: msgid for msgid, msgtype, msgfmt in all_msgs}[msgformat]
def get_enumerations(self): def get_enumerations(self):
return self._serial.get_msgparser().get_enumerations() return self._serial.get_msgparser().get_enumerations()
def get_constants(self): def get_constants(self):

View File

@ -1,6 +1,6 @@
# Protocol definitions for firmware communication # Protocol definitions for firmware communication
# #
# Copyright (C) 2016-2019 Kevin O'Connor <kevin@koconnor.net> # Copyright (C) 2016-2021 Kevin O'Connor <kevin@koconnor.net>
# #
# This file may be distributed under the terms of the GNU GPLv3 license. # This file may be distributed under the terms of the GNU GPLv3 license.
import json, zlib, logging import json, zlib, logging
@ -128,6 +128,25 @@ def lookup_params(msgformat, enumerations={}):
out.append((name, pt)) out.append((name, pt))
return out return out
# Lookup the message types for a debugging "output()" format string
def lookup_output_params(msgformat):
param_types = []
args = msgformat
while 1:
pos = args.find('%')
if pos < 0:
break
if pos+1 >= len(args) or args[pos+1] != '%':
for i in range(4):
t = MessageTypes.get(args[pos:pos+1+i])
if t is not None:
param_types.append(t)
break
else:
raise error("Invalid output format for '%s'" % (msgformat,))
args = args[pos+1:]
return param_types
# Update the message format to be compatible with python's % operator # Update the message format to be compatible with python's % operator
def convert_msg_format(msgformat): def convert_msg_format(msgformat):
for c in ['%u', '%i', '%hu', '%hi', '%c', '%.*s', '%*s']: for c in ['%u', '%i', '%hu', '%hi', '%c', '%.*s', '%*s']:
@ -177,21 +196,7 @@ class OutputFormat:
self.msgid = msgid self.msgid = msgid
self.msgformat = msgformat self.msgformat = msgformat
self.debugformat = convert_msg_format(msgformat) self.debugformat = convert_msg_format(msgformat)
self.param_types = [] self.param_types = lookup_output_params(msgformat)
args = msgformat
while 1:
pos = args.find('%')
if pos < 0:
break
if pos+1 >= len(args) or args[pos+1] != '%':
for i in range(4):
t = MessageTypes.get(args[pos:pos+1+i])
if t is not None:
self.param_types.append(t)
break
else:
raise error("Invalid output format for '%s'" % (msgformat,))
args = args[pos+1:]
def parse(self, s, pos): def parse(self, s, pos):
pos += 1 pos += 1
out = [] out = []
@ -219,7 +224,7 @@ class MessageParser:
def __init__(self): def __init__(self):
self.unknown = UnknownFormat() self.unknown = UnknownFormat()
self.enumerations = {} self.enumerations = {}
self.command_ids = [] self.messages = []
self.messages_by_id = {} self.messages_by_id = {}
self.messages_by_name = {} self.messages_by_name = {}
self.config = {} self.config = {}
@ -334,7 +339,7 @@ class MessageParser:
#logging.exception("Unable to encode") #logging.exception("Unable to encode")
raise error("Unable to encode: %s" % (msgname,)) raise error("Unable to encode: %s" % (msgname,))
return cmd return cmd
def _fill_enumerations(self, enumerations): def fill_enumerations(self, enumerations):
for add_name, add_enums in enumerations.items(): for add_name, add_enums in enumerations.items():
enums = self.enumerations.setdefault(add_name, {}) enums = self.enumerations.setdefault(add_name, {})
for enum, value in add_enums.items(): for enum, value in add_enums.items():
@ -352,12 +357,17 @@ class MessageParser:
start_value, count = value start_value, count = value
for i in range(count): for i in range(count):
enums[enum_root + str(start_enum + i)] = start_value + i enums[enum_root + str(start_enum + i)] = start_value + i
def _init_messages(self, messages, output_ids=[]): def _init_messages(self, messages, command_ids=[], output_ids=[]):
for msgformat, msgid in messages.items(): for msgformat, msgid in messages.items():
msgid = int(msgid) msgtype = 'response'
if msgid in output_ids: if msgid in command_ids:
msgtype = 'command'
elif msgid in output_ids:
msgtype = 'output'
self.messages.append((msgid, msgtype, msgformat))
if msgtype == 'output':
self.messages_by_id[msgid] = OutputFormat(msgid, msgformat) self.messages_by_id[msgid] = OutputFormat(msgid, msgformat)
continue else:
msg = MessageFormat(msgid, msgformat, self.enumerations) msg = MessageFormat(msgid, msgformat, self.enumerations)
self.messages_by_id[msgid] = msg self.messages_by_id[msgid] = msg
self.messages_by_name[msg.name] = msg self.messages_by_name[msg.name] = msg
@ -367,15 +377,15 @@ class MessageParser:
data = zlib.decompress(data) data = zlib.decompress(data)
self.raw_identify_data = data self.raw_identify_data = data
data = json.loads(data) data = json.loads(data)
self._fill_enumerations(data.get('enumerations', {})) self.fill_enumerations(data.get('enumerations', {}))
commands = data.get('commands') commands = data.get('commands')
responses = data.get('responses') responses = data.get('responses')
output = data.get('output', {}) output = data.get('output', {})
all_messages = dict(commands) all_messages = dict(commands)
all_messages.update(responses) all_messages.update(responses)
all_messages.update(output) all_messages.update(output)
self.command_ids = sorted(commands.values()) self._init_messages(all_messages, commands.values(),
self._init_messages(all_messages, output.values()) output.values())
self.config.update(data.get('config', {})) self.config.update(data.get('config', {}))
self.version = data.get('version', '') self.version = data.get('version', '')
self.build_versions = data.get('build_versions', '') self.build_versions = data.get('build_versions', '')
@ -384,6 +394,10 @@ class MessageParser:
except Exception as e: except Exception as e:
logging.exception("process_identify error") logging.exception("process_identify error")
raise error("Error during identify: %s" % (str(e),)) raise error("Error during identify: %s" % (str(e),))
def get_version_info(self):
return self.version, self.build_versions
def get_messages(self):
return list(self.messages)
def get_enumerations(self): def get_enumerations(self):
return dict(self.enumerations) return dict(self.enumerations)
def get_constants(self): def get_constants(self):

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python2 #!/usr/bin/env python2
# Script to handle build time requests embedded in C code. # Script to handle build time requests embedded in C code.
# #
# Copyright (C) 2016-2018 Kevin O'Connor <kevin@koconnor.net> # Copyright (C) 2016-2021 Kevin O'Connor <kevin@koconnor.net>
# #
# This file may be distributed under the terms of the GNU GPLv3 license. # This file may be distributed under the terms of the GNU GPLv3 license.
import sys, os, subprocess, optparse, logging, shlex, socket, time, traceback import sys, os, subprocess, optparse, logging, shlex, socket, time, traceback
@ -172,8 +172,8 @@ class HandleInitialPins:
if not self.initial_pins: if not self.initial_pins:
return [] return []
mp = msgproto.MessageParser() mp = msgproto.MessageParser()
mp._fill_enumerations(HandlerEnumerations.enumerations) mp.fill_enumerations(HandlerEnumerations.enumerations)
pinmap = mp.enumerations.get('pin', {}) pinmap = mp.get_enumerations().get('pin', {})
out = [] out = []
for p in self.initial_pins: for p in self.initial_pins:
flag = "IP_OUT_HIGH" flag = "IP_OUT_HIGH"
@ -304,13 +304,15 @@ class HandleCommandGeneration:
if msgid not in command_ids and msgid not in response_ids } if msgid not in command_ids and msgid not in response_ids }
if output: if output:
data['output'] = output data['output'] = output
def build_parser(self, parser, iscmd): def build_parser(self, msgid, msgformat, msgtype):
if parser.name == "#output": if msgtype == "output":
comment = "Output: " + parser.msgformat param_types = msgproto.lookup_output_params(msgformat)
comment = "Output: " + msgformat
else: else:
comment = parser.msgformat param_types = [t for name, t in msgproto.lookup_params(msgformat)]
comment = msgformat
params = '0' params = '0'
types = tuple([t.__class__.__name__ for t in parser.param_types]) types = tuple([t.__class__.__name__ for t in param_types])
if types: if types:
paramid = self.all_param_types.get(types) paramid = self.all_param_types.get(types)
if paramid is None: if paramid is None:
@ -322,15 +324,15 @@ class HandleCommandGeneration:
.msg_id=%d, .msg_id=%d,
.num_params=%d, .num_params=%d,
.param_types = %s, .param_types = %s,
""" % (comment, parser.msgid, len(types), params) """ % (comment, msgid, len(types), params)
if iscmd: if msgtype == 'response':
num_args = (len(types) + types.count('PT_progmem_buffer') num_args = (len(types) + types.count('PT_progmem_buffer')
+ types.count('PT_buffer')) + types.count('PT_buffer'))
out += " .num_args=%d," % (num_args,) out += " .num_args=%d," % (num_args,)
else: else:
max_size = min(msgproto.MESSAGE_MAX, max_size = min(msgproto.MESSAGE_MAX,
(msgproto.MESSAGE_MIN + 1 (msgproto.MESSAGE_MIN + 1
+ sum([t.max_length for t in parser.param_types]))) + sum([t.max_length for t in param_types])))
out += " .max_size=%d," % (max_size,) out += " .max_size=%d," % (max_size,)
return out return out
def generate_responses_code(self): def generate_responses_code(self):
@ -342,17 +344,15 @@ class HandleCommandGeneration:
msgid = self.msg_to_id[msg] msgid = self.msg_to_id[msg]
if msgid in did_output: if msgid in did_output:
continue continue
s = msg
did_output[msgid] = True did_output[msgid] = True
code = (' if (__builtin_strcmp(str, "%s") == 0)\n' code = (' if (__builtin_strcmp(str, "%s") == 0)\n'
' return &command_encoder_%s;\n' % (s, msgid)) ' return &command_encoder_%s;\n' % (msg, msgid))
if msgname is None: if msgname is None:
parser = msgproto.OutputFormat(msgid, msg) parsercode = self.build_parser(msgid, msg, 'output')
output_code.append(code) output_code.append(code)
else: else:
parser = msgproto.MessageFormat(msgid, msg) parsercode = self.build_parser(msgid, msg, 'command')
encoder_code.append(code) encoder_code.append(code)
parsercode = self.build_parser(parser, 0)
encoder_defs.append( encoder_defs.append(
"const struct command_encoder command_encoder_%s PROGMEM = {" "const struct command_encoder command_encoder_%s PROGMEM = {"
" %s\n};\n" % ( " %s\n};\n" % (
@ -392,8 +392,7 @@ ctr_lookup_output(const char *str)
funcname, flags, msgname = cmd_by_id[msgid] funcname, flags, msgname = cmd_by_id[msgid]
msg = self.messages_by_name[msgname] msg = self.messages_by_name[msgname]
externs[funcname] = 1 externs[funcname] = 1
parser = msgproto.MessageFormat(msgid, msg) parsercode = self.build_parser(msgid, msg, 'response')
parsercode = self.build_parser(parser, 1)
index.append(" {%s\n .flags=%s,\n .func=%s\n}," % ( index.append(" {%s\n .flags=%s,\n .func=%s\n}," % (
parsercode, flags, funcname)) parsercode, flags, funcname))
index = "".join(index).strip() index = "".join(index).strip()