From: @liuchongming74 Reviewed-by: @yelihua,@ouwenchang Signed-off-by: @ouwenchangpull/1331/MERGE
| @@ -18,3 +18,8 @@ MindConverter. | |||||
| MindConverter is a migration tool to transform the model scripts from PyTorch to Mindspore. | MindConverter is a migration tool to transform the model scripts from PyTorch to Mindspore. | ||||
| Users can migrate their PyTorch models to Mindspore rapidly with minor changes according to the conversion report. | Users can migrate their PyTorch models to Mindspore rapidly with minor changes according to the conversion report. | ||||
| """ | """ | ||||
| __all__ = ["user_defined_pattern", "main_entry"] | |||||
| from mindinsight.mindconverter.cli import run as main_entry | |||||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import user_defined_pattern | |||||
| @@ -387,13 +387,13 @@ def cli_entry(): | |||||
| if args.report is None: | if args.report is None: | ||||
| args.report = args.output | args.report = args.output | ||||
| os.makedirs(args.report, mode=mode, exist_ok=True) | os.makedirs(args.report, mode=mode, exist_ok=True) | ||||
| _run(args.in_file, args.model_file, | |||||
| args.shape, | |||||
| args.input_nodes, args.output_nodes, | |||||
| args.output, args.report) | |||||
| run(args.in_file, args.model_file, | |||||
| args.shape, | |||||
| args.input_nodes, args.output_nodes, | |||||
| args.output, args.report) | |||||
| def _run(in_files, model_file, shape, input_nodes, output_nodes, out_dir, report): | |||||
| def run(in_files, model_file, shape, input_nodes, output_nodes, out_dir, report): | |||||
| """ | """ | ||||
| Run converter command. | Run converter command. | ||||
| @@ -710,12 +710,28 @@ class ReplacePath(BasePath): | |||||
| Args: | Args: | ||||
| increment_idx (int): To deduplicate module name. | increment_idx (int): To deduplicate module name. | ||||
| """ | """ | ||||
| src = ",".join(self.topo_order_bef_repl) | |||||
| tgt = self.pattern.pattern | |||||
| md_name = f"Module{increment_idx}" | md_name = f"Module{increment_idx}" | ||||
| src_aft_repl = src.replace(tgt, md_name) | |||||
| if src != src_aft_repl: | |||||
| new_order = [] | |||||
| idx = 0 | |||||
| merged = False | |||||
| while idx < len(self.topo_order_bef_repl): | |||||
| op_node = self.topo_order_bef_repl[idx] | |||||
| if op_node.op_type == self.pattern.head: | |||||
| windowed_nodes = self.topo_order_bef_repl[idx:idx + self.pattern.ptn_length] | |||||
| cur_window = ",".join([nd.op_type for nd in windowed_nodes]) | |||||
| if cur_window == self.pattern.pattern: | |||||
| merged_node = MergedONNXNode(name="", module_name=md_name, | |||||
| ori_nodes=windowed_nodes, inputs="", outputs="", | |||||
| known_module_name=self.pattern.known_module_name) | |||||
| new_order.append(merged_node) | |||||
| idx += self.pattern.ptn_length | |||||
| merged = True | |||||
| continue | |||||
| new_order.append(op_node) | |||||
| idx += 1 | |||||
| if merged: | |||||
| self.pattern.module_name = md_name | self.pattern.module_name = md_name | ||||
| self.topo_order_aft_repl = src_aft_repl.split(",") | |||||
| self.topo_order_aft_repl = new_order | |||||
| return md_name | return md_name | ||||
| return None | return None | ||||
| @@ -438,8 +438,8 @@ def greedy_match(topo_order, user_defined_ptn): | |||||
| """ | """ | ||||
| increment_idx = 0 | increment_idx = 0 | ||||
| prev_path = None | prev_path = None | ||||
| for md_name, ptn in user_defined_ptn: | |||||
| ptn = Pattern(",".join(ptn), len(ptn), -1, -1, ptn) | |||||
| for md_name, ptn_items in user_defined_ptn.items(): | |||||
| ptn = Pattern(",".join(ptn_items), len(ptn_items), -1, -1, ptn_items) | |||||
| ptn.known_module_name = md_name | ptn.known_module_name = md_name | ||||
| topo_order_aft_rpl = topo_order[:] if prev_path is None else prev_path.topo_order_aft_repl | topo_order_aft_rpl = topo_order[:] if prev_path is None else prev_path.topo_order_aft_repl | ||||
| repl_path = ReplacePath(ptn, topo_order_aft_rpl, prev_path=prev_path) | repl_path = ReplacePath(ptn, topo_order_aft_rpl, prev_path=prev_path) | ||||