diff --git a/mindinsight/mindconverter/graph_based_converter/common/utils.py b/mindinsight/mindconverter/graph_based_converter/common/utils.py index 4efcf31d..91825732 100644 --- a/mindinsight/mindconverter/graph_based_converter/common/utils.py +++ b/mindinsight/mindconverter/graph_based_converter/common/utils.py @@ -81,7 +81,7 @@ def build_feed_dict(onnx_model, input_nodes: dict): for node in onnx_model.graph.input } feed_dict = { - name: np.random.rand(*shape).astype(input_nodes_types[name.split(":")[0]]) + name: np.random.rand(*shape).astype(input_nodes_types[name]) for name, shape in input_nodes.items() } return feed_dict diff --git a/mindinsight/mindconverter/graph_based_converter/framework.py b/mindinsight/mindconverter/graph_based_converter/framework.py index 35c1052d..e1d822f8 100644 --- a/mindinsight/mindconverter/graph_based_converter/framework.py +++ b/mindinsight/mindconverter/graph_based_converter/framework.py @@ -83,7 +83,7 @@ def torch_installation_validation(func): type, inner function. """ - def _f(graph_path: str, sample_shape: tuple, input_nodes: str, output_nodes: str, + def _f(graph_path: str, input_nodes: dict, output_nodes: List[str], output_folder: str, report_folder: str = None): # Check whether pytorch is installed. error_info = None @@ -119,7 +119,7 @@ def torch_installation_validation(func): _print_error(error) sys.exit(0) - func(graph_path=graph_path, sample_shape=sample_shape, + func(graph_path=graph_path, input_nodes=input_nodes, output_nodes=output_nodes, output_folder=output_folder, report_folder=report_folder) @@ -265,11 +265,12 @@ def main_graph_base_converter(file_config): if not file_config.get("shape"): raise ParamMissingError("Param missing, `--shape` is required when using graph mode.") - if graph_path.endswith("pth") and not file_config['input_nodes'] and \ - file_config.get("shape") and len(file_config.get("shape")) == 1: + if graph_path.endswith("pth") and not file_config.get("input_nodes", []) and \ + file_config.get("shape") and len(file_config.get("shape", ())) == 1: file_config['input_nodes'] = ["input.1"] - if len(file_config['shape']) != len(file_config['input_nodes']) != len(set(file_config['input_nodes'])): + if len(file_config['shape']) != len(file_config.get("input_nodes", [])) != len( + set(file_config.get("input_nodes", []))): raise BadParamError("`--shape` and `--input_nodes` must have the same length, " "and no redundant node in `--input_nodes`.") diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py index 8551afc3..2c4eff66 100644 --- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py @@ -17,12 +17,13 @@ __all__ = ["context", "gen_hash_key", "DagGraph", - "MAX_OUT_DEGREE", + "MAX_DEGREE", "cal_matching_score", "ACCEPTABLE_RESULT_COUNT", "MINI_FREQUENCY", "SATISFIED_SCORE", - "MAX_ITERATION_DEPTH"] + "MAX_ITERATION_DEPTH_OF_MULTI_IPT", + "MAX_ITERATION_DEPTH_OF_SINGLE_IPT"] import math import copy @@ -32,9 +33,10 @@ from typing import List from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import BaseNode -MAX_OUT_DEGREE = 1 +MAX_DEGREE = 1 MINI_FREQUENCY = 0.07 -MAX_ITERATION_DEPTH = 16 +MAX_ITERATION_DEPTH_OF_MULTI_IPT = 16 +MAX_ITERATION_DEPTH_OF_SINGLE_IPT = 8 SATISFIED_SCORE = 0.74 ACCEPTABLE_RESULT_COUNT = 32 PTN_COVERAGE_THRESHOLD = 0.65 @@ -127,6 +129,7 @@ class AlgorithmContext: precursor_table = {} successor_table = {} outputs_table = {} + has_multi_inputs = False def set_init_node_collection(self, nd_col): """Init node_collection.""" diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py index 7a1f7018..2322402b 100644 --- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py @@ -21,7 +21,6 @@ class Pattern: def __init__(self, pattern, pattern_length, in_degree, out_degree, ptn_items: list = None): self.pattern = pattern - self.count = 0 self.start_index = [] self.end_index = [] self.module_name = None @@ -37,6 +36,11 @@ class Pattern: self.additional_score = 0 self.known_module_name = None + @property + def count(self): + """Count of the pattern.""" + return len(self.start_index) + def insert(self, idx, seq_len): """ Insert a new position. @@ -49,7 +53,6 @@ class Pattern: return self.start_index.append(idx) self.end_index.append(idx + seq_len) - self.count += 1 def __str__(self): """Override `str()` method.""" diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py index 9278b464..6c197d96 100644 --- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py @@ -21,7 +21,7 @@ from typing import Dict, List, Callable, Union from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import BUILT_IN_PATTERN, \ is_built_in_pattern from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, gen_hash_key, DagGraph, \ - MAX_OUT_DEGREE, cal_matching_score + MAX_DEGREE, cal_matching_score from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.known_module_name import BUILT_IN_MODULE_NAME from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern, scope_name_mapping from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern_fuzzy_matching import \ @@ -390,7 +390,7 @@ def generate_pattern(topo_order: List[BaseNode], dag: DagGraph, dag=dag) in_degree, out_degree, _, _ = _get_pattern_degree(found_sequence, dag) - if out_degree > MAX_OUT_DEGREE: + if out_degree > MAX_DEGREE or (not context.has_multi_inputs and in_degree > MAX_DEGREE): cur_idx += 1 continue @@ -419,6 +419,7 @@ def _post_process_overlap(patterns) -> Dict: patterns[name].start_index.pop(idx) patterns[name].end_index.pop(idx) continue + prev_end = patterns[name].end_index[idx] idx += 1 return patterns diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py index cde2c8fd..8ecefb71 100644 --- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py @@ -17,9 +17,9 @@ from queue import PriorityQueue from typing import Dict, List from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, DagGraph, gen_hash_key, \ - ACCEPTABLE_RESULT_COUNT + ACCEPTABLE_RESULT_COUNT, MAX_ITERATION_DEPTH_OF_SINGLE_IPT from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import MINI_FREQUENCY, \ - MAX_ITERATION_DEPTH, SATISFIED_SCORE + MAX_ITERATION_DEPTH_OF_MULTI_IPT, SATISFIED_SCORE from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import BaseNode from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.search_path import SearchPath, Pattern, \ @@ -37,7 +37,9 @@ def _is_satisfied(path): Returns: bool, True or False. """ - if len(path.recursion_path) > MAX_ITERATION_DEPTH: + recursion_depth = MAX_ITERATION_DEPTH_OF_MULTI_IPT if context.has_multi_inputs \ + else MAX_ITERATION_DEPTH_OF_SINGLE_IPT + if len(path.recursion_path) > recursion_depth: return True candidate_eval = any([is_pattern_satisfied(p, path) for p in path.new_pattern.values()]) if not path.new_pattern or not candidate_eval: @@ -262,6 +264,8 @@ def _build_connection(loader): context.successor_table[node_name] = list(node.get_successor_dict().keys()) context.outputs_table[node_name] = node.output_name_list + # Record the model inputs count, use it to control the search algorithm. + context.has_multi_inputs = len(loader.input_nodes) > 1 dag = DagGraph(nodes=context.node_collection.copy(), precursor=context.precursor_table.copy(), successor=context.successor_table.copy()) diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py index c36f8ee1..6bf1f36b 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py @@ -202,6 +202,8 @@ class OnnxGraph(Graph): else: onnx_model = PyTorchGraphParser.parse(graph_path, **kwargs) onnx_inputs = [onnx_input.name for onnx_input in onnx_model.graph.input] - if input_nodes not in onnx_inputs: - raise ModelNotSupportError(f"input nodes({input_nodes}) is not in model inputs ({onnx_inputs}).") + for ipt in input_nodes: + if ipt not in onnx_inputs: + raise ModelNotSupportError(f"input nodes({input_nodes}) is not " + f"in model inputs ({onnx_inputs}).") return onnx_model diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py index a6451e37..8b74ff3d 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py @@ -93,7 +93,8 @@ class OnnxSimplify: self._constant_nodes = copy.deepcopy(const_nodes) @ModelNotSupportError.check_except( - "Error occurs in loading model, please check your model or runtime environment integrity." + "Error occurs when loading model with given params, please check `--shape`, " + "`--input_nodes`, `--output_nodes`, `--model_file` or runtime environment integrity." ) def _onnx_infer(self, infer_inputs_shape): """ diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py index eb8aa8de..f2763ed1 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py @@ -27,7 +27,8 @@ class PyTorchGraphParser(GraphParser): @classmethod @ModelNotSupportError.check_except( - "Error occurs in loading model, please check your model or runtime environment integrity." + "Error occurs when loading model with given params, please check `--shape`, " + "`--input_nodes`, `--output_nodes`, `--model_file` or runtime environment integrity." ) def parse(cls, model_path: str, **kwargs): """ @@ -47,8 +48,9 @@ class PyTorchGraphParser(GraphParser): raise error try: + sample_shape = list(kwargs.get("input_nodes").values())[0] onnx_model_sim = cls._convert_pytorch_graph_to_onnx( - model_path, kwargs['sample_shape'], opset_version=11) + model_path, sample_shape, opset_version=11) return onnx_model_sim except ModuleNotFoundError: diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py index 007b9e8c..0e152775 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py @@ -27,7 +27,8 @@ class TFGraphParser(GraphParser): @classmethod @ModelNotSupportError.check_except( - "Error occurs in loading model, please check your model or runtime environment integrity." + "Error occurs when loading model with given params, please check `--shape`, " + "`--input_nodes`, `--output_nodes`, `--model_file` or runtime environment integrity." ) def parse(cls, model_path: str, **kwargs): """