Browse Source

add graph outputs check with user provided and customized model inputs name with user provided

tags/v1.2.0-rc1
liangtianshu 4 years ago
parent
commit
a06d08d3d8
4 changed files with 23 additions and 8 deletions
  1. +1
    -1
      mindinsight/mindconverter/graph_based_converter/generator/generator.py
  2. +18
    -6
      mindinsight/mindconverter/graph_based_converter/generator/matcher.py
  3. +1
    -1
      mindinsight/mindconverter/graph_based_converter/generator/module_struct.py
  4. +3
    -0
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py

+ 1
- 1
mindinsight/mindconverter/graph_based_converter/generator/generator.py View File

@@ -137,7 +137,7 @@ class CodeStruct:
self.new_line = f"{NEW_LINE * 2}"

# define header of construct block
inputs = ['self'] + list(md_struct.construct_header_x.keys())
inputs = ['self'] + list(md_struct.inputs_register.values())
self.new_line = f"{FIRST_LEVEL_INDENT}def construct({', '.join(inputs)}):"
# add construct code lines to code line list.
self.code_line_list += cons_lines


+ 18
- 6
mindinsight/mindconverter/graph_based_converter/generator/matcher.py View File

@@ -26,15 +26,27 @@ class MatcherHelper:
"""Call in preprocess"""
# allocate main model construct x
prec_edges = main_model.external_precursor_nodes_names
default_x_str = "x"
graph_inputs = GlobalContext().onnx_graph_info.get('graph_inputs')
inputs = dict()
for idx, edge in enumerate(prec_edges):
if not edge in inputs:
# idx-1 here as we have a x without index and another x0 in module inputs
# so the idx 0 position is the second input, not the first x.
inputs[edge] = "".join([default_x_str, str(idx-1)]) if idx > 0 else default_x_str
for edge in graph_inputs:
if not edge in inputs and edge in prec_edges:
regular_edge = MatcherHelper.regular_edge_name(edge)
inputs[edge] = regular_edge
main_model.inputs_register = inputs

@staticmethod
def regular_edge_name(name: str) -> str:
"""Regular the edge name to adapt the python grammar."""
regular = ""
for char in name:
if char.isalpha() or char.isdigit():
regular = f"{regular}{char}"
else:
regular = f"{regular}_"
if not regular[0].isalpha():
regular = f"input_{regular}"
return regular

@staticmethod
def get_public_parent_module(node_a: NodeStruct, node_b: NodeStruct):
"""Return the public parent module of both Node A and Node B."""


+ 1
- 1
mindinsight/mindconverter/graph_based_converter/generator/module_struct.py View File

@@ -79,7 +79,7 @@ class ModuleStruct:
self.rapid_reference = dict()

# new vars for matcher
self.inputs_register = dict() # reg by sub
self.inputs_register = OrderedDict() # reg by sub
self.outputs_register = OrderedDict() # reg by sub
self.internal_outputs_collection = dict() # reg by sub



+ 3
- 0
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py View File

@@ -496,6 +496,9 @@ class OnnxDataLoader:
"""Parse ONNX Graph Info For usage in generator."""
graph_inputs = [inp.name for inp in self.graph.input]
graph_outputs = [out.name for out in self.graph.output]
for output_node in self.output_nodes:
if output_node not in graph_outputs:
raise ValueError(f"Unexpected Node {output_node} detected which should not be a graph output.")
self._global_context.onnx_graph_info['graph_inputs'] = graph_inputs
self._global_context.onnx_graph_info['graph_outputs'] = graph_outputs



Loading…
Cancel
Save