Browse Source

!1331 Add user-defined code structure python interface.

From: @liuchongming74
Reviewed-by: @yelihua,@ouwenchang
Signed-off-by: @ouwenchang
pull/1331/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
620479138e
4 changed files with 33 additions and 12 deletions
  1. +5
    -0
      mindinsight/mindconverter/__init__.py
  2. +5
    -5
      mindinsight/mindconverter/cli.py
  3. +21
    -5
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py
  4. +2
    -2
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py

+ 5
- 0
mindinsight/mindconverter/__init__.py View File

@@ -18,3 +18,8 @@ MindConverter.
MindConverter is a migration tool to transform the model scripts from PyTorch to Mindspore. MindConverter is a migration tool to transform the model scripts from PyTorch to Mindspore.
Users can migrate their PyTorch models to Mindspore rapidly with minor changes according to the conversion report. Users can migrate their PyTorch models to Mindspore rapidly with minor changes according to the conversion report.
""" """

__all__ = ["user_defined_pattern", "main_entry"]

from mindinsight.mindconverter.cli import run as main_entry
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import user_defined_pattern

+ 5
- 5
mindinsight/mindconverter/cli.py View File

@@ -387,13 +387,13 @@ def cli_entry():
if args.report is None: if args.report is None:
args.report = args.output args.report = args.output
os.makedirs(args.report, mode=mode, exist_ok=True) os.makedirs(args.report, mode=mode, exist_ok=True)
_run(args.in_file, args.model_file,
args.shape,
args.input_nodes, args.output_nodes,
args.output, args.report)
run(args.in_file, args.model_file,
args.shape,
args.input_nodes, args.output_nodes,
args.output, args.report)




def _run(in_files, model_file, shape, input_nodes, output_nodes, out_dir, report):
def run(in_files, model_file, shape, input_nodes, output_nodes, out_dir, report):
""" """
Run converter command. Run converter command.




+ 21
- 5
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py View File

@@ -710,12 +710,28 @@ class ReplacePath(BasePath):
Args: Args:
increment_idx (int): To deduplicate module name. increment_idx (int): To deduplicate module name.
""" """
src = ",".join(self.topo_order_bef_repl)
tgt = self.pattern.pattern
md_name = f"Module{increment_idx}" md_name = f"Module{increment_idx}"
src_aft_repl = src.replace(tgt, md_name)
if src != src_aft_repl:
new_order = []
idx = 0
merged = False
while idx < len(self.topo_order_bef_repl):
op_node = self.topo_order_bef_repl[idx]
if op_node.op_type == self.pattern.head:
windowed_nodes = self.topo_order_bef_repl[idx:idx + self.pattern.ptn_length]
cur_window = ",".join([nd.op_type for nd in windowed_nodes])
if cur_window == self.pattern.pattern:
merged_node = MergedONNXNode(name="", module_name=md_name,
ori_nodes=windowed_nodes, inputs="", outputs="",
known_module_name=self.pattern.known_module_name)
new_order.append(merged_node)
idx += self.pattern.ptn_length
merged = True
continue
new_order.append(op_node)
idx += 1

if merged:
self.pattern.module_name = md_name self.pattern.module_name = md_name
self.topo_order_aft_repl = src_aft_repl.split(",")
self.topo_order_aft_repl = new_order
return md_name return md_name
return None return None

+ 2
- 2
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py View File

@@ -438,8 +438,8 @@ def greedy_match(topo_order, user_defined_ptn):
""" """
increment_idx = 0 increment_idx = 0
prev_path = None prev_path = None
for md_name, ptn in user_defined_ptn:
ptn = Pattern(",".join(ptn), len(ptn), -1, -1, ptn)
for md_name, ptn_items in user_defined_ptn.items():
ptn = Pattern(",".join(ptn_items), len(ptn_items), -1, -1, ptn_items)
ptn.known_module_name = md_name ptn.known_module_name = md_name
topo_order_aft_rpl = topo_order[:] if prev_path is None else prev_path.topo_order_aft_repl topo_order_aft_rpl = topo_order[:] if prev_path is None else prev_path.topo_order_aft_repl
repl_path = ReplacePath(ptn, topo_order_aft_rpl, prev_path=prev_path) repl_path = ReplacePath(ptn, topo_order_aft_rpl, prev_path=prev_path)


Loading…
Cancel
Save