diff --git a/mindinsight/mindconverter/__init__.py b/mindinsight/mindconverter/__init__.py index 7bd90e4c..a95edf9f 100644 --- a/mindinsight/mindconverter/__init__.py +++ b/mindinsight/mindconverter/__init__.py @@ -18,3 +18,8 @@ MindConverter. MindConverter is a migration tool to transform the model scripts from PyTorch to Mindspore. Users can migrate their PyTorch models to Mindspore rapidly with minor changes according to the conversion report. """ + +__all__ = ["user_defined_pattern", "main_entry"] + +from mindinsight.mindconverter.cli import run as main_entry +from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import user_defined_pattern diff --git a/mindinsight/mindconverter/cli.py b/mindinsight/mindconverter/cli.py index 2b193d48..7566d26b 100644 --- a/mindinsight/mindconverter/cli.py +++ b/mindinsight/mindconverter/cli.py @@ -387,13 +387,13 @@ def cli_entry(): if args.report is None: args.report = args.output os.makedirs(args.report, mode=mode, exist_ok=True) - _run(args.in_file, args.model_file, - args.shape, - args.input_nodes, args.output_nodes, - args.output, args.report) + run(args.in_file, args.model_file, + args.shape, + args.input_nodes, args.output_nodes, + args.output, args.report) -def _run(in_files, model_file, shape, input_nodes, output_nodes, out_dir, report): +def run(in_files, model_file, shape, input_nodes, output_nodes, out_dir, report): """ Run converter command. diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py index 7d11fe8d..d90b9298 100644 --- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py @@ -710,12 +710,28 @@ class ReplacePath(BasePath): Args: increment_idx (int): To deduplicate module name. """ - src = ",".join(self.topo_order_bef_repl) - tgt = self.pattern.pattern md_name = f"Module{increment_idx}" - src_aft_repl = src.replace(tgt, md_name) - if src != src_aft_repl: + new_order = [] + idx = 0 + merged = False + while idx < len(self.topo_order_bef_repl): + op_node = self.topo_order_bef_repl[idx] + if op_node.op_type == self.pattern.head: + windowed_nodes = self.topo_order_bef_repl[idx:idx + self.pattern.ptn_length] + cur_window = ",".join([nd.op_type for nd in windowed_nodes]) + if cur_window == self.pattern.pattern: + merged_node = MergedONNXNode(name="", module_name=md_name, + ori_nodes=windowed_nodes, inputs="", outputs="", + known_module_name=self.pattern.known_module_name) + new_order.append(merged_node) + idx += self.pattern.ptn_length + merged = True + continue + new_order.append(op_node) + idx += 1 + + if merged: self.pattern.module_name = md_name - self.topo_order_aft_repl = src_aft_repl.split(",") + self.topo_order_aft_repl = new_order return md_name return None diff --git a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py index edc86247..084439b6 100644 --- a/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py +++ b/mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py @@ -438,8 +438,8 @@ def greedy_match(topo_order, user_defined_ptn): """ increment_idx = 0 prev_path = None - for md_name, ptn in user_defined_ptn: - ptn = Pattern(",".join(ptn), len(ptn), -1, -1, ptn) + for md_name, ptn_items in user_defined_ptn.items(): + ptn = Pattern(",".join(ptn_items), len(ptn_items), -1, -1, ptn_items) ptn.known_module_name = md_name topo_order_aft_rpl = topo_order[:] if prev_path is None else prev_path.topo_order_aft_repl repl_path = ReplacePath(ptn, topo_order_aft_rpl, prev_path=prev_path)