| @@ -137,7 +137,7 @@ class CodeStruct: | |||||
| self.new_line = f"{NEW_LINE * 2}" | self.new_line = f"{NEW_LINE * 2}" | ||||
| # define header of construct block | # define header of construct block | ||||
| inputs = ['self'] + list(md_struct.construct_header_x.keys()) | |||||
| inputs = ['self'] + list(md_struct.inputs_register.values()) | |||||
| self.new_line = f"{FIRST_LEVEL_INDENT}def construct({', '.join(inputs)}):" | self.new_line = f"{FIRST_LEVEL_INDENT}def construct({', '.join(inputs)}):" | ||||
| # add construct code lines to code line list. | # add construct code lines to code line list. | ||||
| self.code_line_list += cons_lines | self.code_line_list += cons_lines | ||||
| @@ -26,15 +26,27 @@ class MatcherHelper: | |||||
| """Call in preprocess""" | """Call in preprocess""" | ||||
| # allocate main model construct x | # allocate main model construct x | ||||
| prec_edges = main_model.external_precursor_nodes_names | prec_edges = main_model.external_precursor_nodes_names | ||||
| default_x_str = "x" | |||||
| graph_inputs = GlobalContext().onnx_graph_info.get('graph_inputs') | |||||
| inputs = dict() | inputs = dict() | ||||
| for idx, edge in enumerate(prec_edges): | |||||
| if not edge in inputs: | |||||
| # idx-1 here as we have a x without index and another x0 in module inputs | |||||
| # so the idx 0 position is the second input, not the first x. | |||||
| inputs[edge] = "".join([default_x_str, str(idx-1)]) if idx > 0 else default_x_str | |||||
| for edge in graph_inputs: | |||||
| if not edge in inputs and edge in prec_edges: | |||||
| regular_edge = MatcherHelper.regular_edge_name(edge) | |||||
| inputs[edge] = regular_edge | |||||
| main_model.inputs_register = inputs | main_model.inputs_register = inputs | ||||
| @staticmethod | |||||
| def regular_edge_name(name: str) -> str: | |||||
| """Regular the edge name to adapt the python grammar.""" | |||||
| regular = "" | |||||
| for char in name: | |||||
| if char.isalpha() or char.isdigit(): | |||||
| regular = f"{regular}{char}" | |||||
| else: | |||||
| regular = f"{regular}_" | |||||
| if not regular[0].isalpha(): | |||||
| regular = f"input_{regular}" | |||||
| return regular | |||||
| @staticmethod | @staticmethod | ||||
| def get_public_parent_module(node_a: NodeStruct, node_b: NodeStruct): | def get_public_parent_module(node_a: NodeStruct, node_b: NodeStruct): | ||||
| """Return the public parent module of both Node A and Node B.""" | """Return the public parent module of both Node A and Node B.""" | ||||
| @@ -79,7 +79,7 @@ class ModuleStruct: | |||||
| self.rapid_reference = dict() | self.rapid_reference = dict() | ||||
| # new vars for matcher | # new vars for matcher | ||||
| self.inputs_register = dict() # reg by sub | |||||
| self.inputs_register = OrderedDict() # reg by sub | |||||
| self.outputs_register = OrderedDict() # reg by sub | self.outputs_register = OrderedDict() # reg by sub | ||||
| self.internal_outputs_collection = dict() # reg by sub | self.internal_outputs_collection = dict() # reg by sub | ||||
| @@ -496,6 +496,9 @@ class OnnxDataLoader: | |||||
| """Parse ONNX Graph Info For usage in generator.""" | """Parse ONNX Graph Info For usage in generator.""" | ||||
| graph_inputs = [inp.name for inp in self.graph.input] | graph_inputs = [inp.name for inp in self.graph.input] | ||||
| graph_outputs = [out.name for out in self.graph.output] | graph_outputs = [out.name for out in self.graph.output] | ||||
| for output_node in self.output_nodes: | |||||
| if output_node not in graph_outputs: | |||||
| raise ValueError(f"Unexpected Node {output_node} detected which should not be a graph output.") | |||||
| self._global_context.onnx_graph_info['graph_inputs'] = graph_inputs | self._global_context.onnx_graph_info['graph_inputs'] = graph_inputs | ||||
| self._global_context.onnx_graph_info['graph_outputs'] = graph_outputs | self._global_context.onnx_graph_info['graph_outputs'] = graph_outputs | ||||