diff --git a/mindinsight/mindconverter/graph_based_converter/constant.py b/mindinsight/mindconverter/graph_based_converter/constant.py index bf9a83f9..15b13d9d 100644 --- a/mindinsight/mindconverter/graph_based_converter/constant.py +++ b/mindinsight/mindconverter/graph_based_converter/constant.py @@ -50,6 +50,8 @@ ARGUMENT_LENGTH_LIMIT = 512 EXPECTED_NUMBER = 1 +MIN_SCOPE_LENGTH = 2 + @unique class CodeFormatConfig(Enum): diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py index 72049646..d3ea4375 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py @@ -520,7 +520,7 @@ class GraphNode(abc.ABC): if input_type == InputType.TENSOR.value: ipt_args_settings_in_construct = ipt_args_in_construct elif input_type == InputType.LIST.value: - ipt_args_settings_in_construct = f"({ipt_args_in_construct})" + ipt_args_settings_in_construct = f"({ipt_args_in_construct},)" else: raise NodeInputTypeNotSupport(f"Input type[{input_type}] is not supported now.") else: diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py index e47519a1..4e68289a 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py @@ -14,6 +14,7 @@ # ============================================================================== """Define PyTorch graph.""" import re +from copy import deepcopy from typing import Dict, NoReturn from mindinsight.mindconverter.common.log import logger as log @@ -22,7 +23,8 @@ from .input_node import InputNode from .pytorch_graph_node import PyTorchGraphNode from .pytorch_graph_parser import PyTorchGraphParser -from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID +from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, SCALAR_WITHOUT_SHAPE, \ + MIN_SCOPE_LENGTH from ..constant import LEFT_BUCKET, RIGHT_BUCKET NONE_SCOPE_OP = { @@ -206,6 +208,8 @@ class PyTorchGraph(Graph): ) self.build_connection(node_input_name, node_name) + self._unmerge_multi_ipt_opt_script() + super(PyTorchGraph, self).build(input_shape=input_shape) self._collect_ipt_shape_of_each_node(feed_forward_ipt_shape) @@ -227,13 +231,84 @@ class PyTorchGraph(Graph): input_node.set_successor_nodes(node_name) self._shape_dict[ipt_nd_name] = input_node.output_shape + if not self._shape_dict[node_name]: + self._shape_dict[node_name] = SCALAR_WITHOUT_SHAPE + ipt_shape = [] for p_nd in node.precursor_nodes: shp = self._shape_dict.get(p_nd) - ipt_shape.append(tuple(shp)) + ipt_shape.append(tuple(shp) if isinstance(shp, list) else shp) self._input_shape[node_name] = ipt_shape[0] if len(ipt_shape) == 1 else ipt_shape + def _generate_module(self): + """Generate modules.""" + module_dict = dict() + for node_key, _ in self._nodes_collection.items(): + node_key_in_scope = node_key.split(SEPARATOR_IN_SCOPE) + if len(node_key_in_scope) < MIN_SCOPE_LENGTH: + continue + + for idx in range(1, len(node_key_in_scope)): + node_key_module = SEPARATOR_IN_SCOPE.join(node_key_in_scope[:idx]) + node_name = SEPARATOR_IN_SCOPE.join(node_key_in_scope[:idx+1]) + if not module_dict.get(node_key_module, None): + module_dict[node_key_module] = {node_name} + else: + module_dict[node_key_module].add(node_name) + + return module_dict + + def _check_multi_ipt(self): + """Check whether multi-input exists.""" + module_dict = self._generate_module() + for _, nodes_per_module in module_dict.items(): + prcs_nodes_out_from_module = set() + for node_name in nodes_per_module: + node = self._nodes_collection.get(node_name, None) + if node: + prcs_nodes = node.precursor_nodes + else: + continue + + for prcs_node in prcs_nodes: + if prcs_node not in nodes_per_module: + prcs_node_module = SEPARATOR_IN_SCOPE.join(prcs_node.split(SEPARATOR_IN_SCOPE)[:-1]) + if prcs_node_module not in nodes_per_module: + prcs_nodes_out_from_module.add(prcs_node) + + if len(prcs_nodes_out_from_module) > 1: + return True + + return False + + def _unmerge_multi_ipt_opt_script(self): + """Unmerge all submodule.""" + if self._check_multi_ipt(): + for node_key, node_inst in deepcopy(self._nodes_collection).items(): + prsc_nodes = node_inst.precursor_nodes + scsr_nodes = node_inst.successor_nodes + + node_inst.precursor_nodes = [SEPARATOR_IN_SCOPE.join((prsc_node.split(SEPARATOR_IN_SCOPE)[0], + prsc_node.split(SEPARATOR_IN_SCOPE)[-1])) + for prsc_node in deepcopy(prsc_nodes)] + node_inst.successor_nodes = [SEPARATOR_IN_SCOPE.join((scsr_node.split(SEPARATOR_IN_SCOPE)[0], + scsr_node.split(SEPARATOR_IN_SCOPE)[-1])) + for scsr_node in deepcopy(scsr_nodes)] + + reduce_node_key = SEPARATOR_IN_SCOPE.join((node_key.split(SEPARATOR_IN_SCOPE)[0], + node_key.split(SEPARATOR_IN_SCOPE)[-1])) + + del self._nodes_collection[node_key] + self._nodes_collection[reduce_node_key] = node_inst + + for node_key, shape in deepcopy(self._shape_dict).items(): + reduce_node_key = SEPARATOR_IN_SCOPE.join((node_key.split(SEPARATOR_IN_SCOPE)[0], + node_key.split(SEPARATOR_IN_SCOPE)[-1])) + + del self._shape_dict[node_key] + self._shape_dict[reduce_node_key] = shape + def sub_graph_merging(self): """ Merge split operation into one. diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py index 26935a75..bfa3fff3 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py @@ -48,18 +48,8 @@ class TFGraphParser(GraphParser): log.error(str(error)) raise error - try: - model = convert_tf_graph_to_onnx(model_path, - model_inputs=tf_input_nodes, - model_outputs=tf_output_nodes, - ) # need pass more args - - except ModuleNotFoundError: - error_msg = \ - "Cannot find model scripts in system path, " \ - "set `--project_path` to the path of model scripts folder correctly." - error = ModuleNotFoundError(error_msg) - log.error(str(error)) - raise error from None - + model = convert_tf_graph_to_onnx(model_path, + model_inputs=tf_input_nodes, + model_outputs=tf_output_nodes, + ) return model