diff --git a/mindinsight/mindconverter/graph_based_converter/generator/generator.py b/mindinsight/mindconverter/graph_based_converter/generator/generator.py index f39f8bbb..8f39e0d1 100644 --- a/mindinsight/mindconverter/graph_based_converter/generator/generator.py +++ b/mindinsight/mindconverter/graph_based_converter/generator/generator.py @@ -137,7 +137,7 @@ class CodeStruct: self.new_line = f"{NEW_LINE * 2}" # define header of construct block - inputs = ['self'] + list(md_struct.construct_header_x.keys()) + inputs = ['self'] + list(md_struct.inputs_register.values()) self.new_line = f"{FIRST_LEVEL_INDENT}def construct({', '.join(inputs)}):" # add construct code lines to code line list. self.code_line_list += cons_lines diff --git a/mindinsight/mindconverter/graph_based_converter/generator/matcher.py b/mindinsight/mindconverter/graph_based_converter/generator/matcher.py index 7a0a7981..e29ef0a1 100644 --- a/mindinsight/mindconverter/graph_based_converter/generator/matcher.py +++ b/mindinsight/mindconverter/graph_based_converter/generator/matcher.py @@ -26,15 +26,27 @@ class MatcherHelper: """Call in preprocess""" # allocate main model construct x prec_edges = main_model.external_precursor_nodes_names - default_x_str = "x" + graph_inputs = GlobalContext().onnx_graph_info.get('graph_inputs') inputs = dict() - for idx, edge in enumerate(prec_edges): - if not edge in inputs: - # idx-1 here as we have a x without index and another x0 in module inputs - # so the idx 0 position is the second input, not the first x. - inputs[edge] = "".join([default_x_str, str(idx-1)]) if idx > 0 else default_x_str + for edge in graph_inputs: + if not edge in inputs and edge in prec_edges: + regular_edge = MatcherHelper.regular_edge_name(edge) + inputs[edge] = regular_edge main_model.inputs_register = inputs + @staticmethod + def regular_edge_name(name: str) -> str: + """Regular the edge name to adapt the python grammar.""" + regular = "" + for char in name: + if char.isalpha() or char.isdigit(): + regular = f"{regular}{char}" + else: + regular = f"{regular}_" + if not regular[0].isalpha(): + regular = f"input_{regular}" + return regular + @staticmethod def get_public_parent_module(node_a: NodeStruct, node_b: NodeStruct): """Return the public parent module of both Node A and Node B.""" diff --git a/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py b/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py index 690a1a9a..d12a7b3d 100644 --- a/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py +++ b/mindinsight/mindconverter/graph_based_converter/generator/module_struct.py @@ -79,7 +79,7 @@ class ModuleStruct: self.rapid_reference = dict() # new vars for matcher - self.inputs_register = dict() # reg by sub + self.inputs_register = OrderedDict() # reg by sub self.outputs_register = OrderedDict() # reg by sub self.internal_outputs_collection = dict() # reg by sub diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py index 9982870e..4bb70f44 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py @@ -496,6 +496,9 @@ class OnnxDataLoader: """Parse ONNX Graph Info For usage in generator.""" graph_inputs = [inp.name for inp in self.graph.input] graph_outputs = [out.name for out in self.graph.output] + for output_node in self.output_nodes: + if output_node not in graph_outputs: + raise ValueError(f"Unexpected Node {output_node} detected which should not be a graph output.") self._global_context.onnx_graph_info['graph_inputs'] = graph_inputs self._global_context.onnx_graph_info['graph_outputs'] = graph_outputs