Browse Source

Allow user-defined code structure to be set for codegen..

pull/1331/head
liuchongming 4 years ago
parent
commit
0b4c8212e2
4 changed files with 33 additions and 12 deletions
  1. +5
    -0
      mindinsight/mindconverter/__init__.py
  2. +5
    -5
      mindinsight/mindconverter/cli.py
  3. +21
    -5
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py
  4. +2
    -2
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py

+ 5
- 0
mindinsight/mindconverter/__init__.py View File

@@ -18,3 +18,8 @@ MindConverter.
MindConverter is a migration tool to transform the model scripts from PyTorch to Mindspore.
Users can migrate their PyTorch models to Mindspore rapidly with minor changes according to the conversion report.
"""

__all__ = ["user_defined_pattern", "main_entry"]

from mindinsight.mindconverter.cli import run as main_entry
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import user_defined_pattern

+ 5
- 5
mindinsight/mindconverter/cli.py View File

@@ -387,13 +387,13 @@ def cli_entry():
if args.report is None:
args.report = args.output
os.makedirs(args.report, mode=mode, exist_ok=True)
_run(args.in_file, args.model_file,
args.shape,
args.input_nodes, args.output_nodes,
args.output, args.report)
run(args.in_file, args.model_file,
args.shape,
args.input_nodes, args.output_nodes,
args.output, args.report)


def _run(in_files, model_file, shape, input_nodes, output_nodes, out_dir, report):
def run(in_files, model_file, shape, input_nodes, output_nodes, out_dir, report):
"""
Run converter command.



+ 21
- 5
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py View File

@@ -710,12 +710,28 @@ class ReplacePath(BasePath):
Args:
increment_idx (int): To deduplicate module name.
"""
src = ",".join(self.topo_order_bef_repl)
tgt = self.pattern.pattern
md_name = f"Module{increment_idx}"
src_aft_repl = src.replace(tgt, md_name)
if src != src_aft_repl:
new_order = []
idx = 0
merged = False
while idx < len(self.topo_order_bef_repl):
op_node = self.topo_order_bef_repl[idx]
if op_node.op_type == self.pattern.head:
windowed_nodes = self.topo_order_bef_repl[idx:idx + self.pattern.ptn_length]
cur_window = ",".join([nd.op_type for nd in windowed_nodes])
if cur_window == self.pattern.pattern:
merged_node = MergedONNXNode(name="", module_name=md_name,
ori_nodes=windowed_nodes, inputs="", outputs="",
known_module_name=self.pattern.known_module_name)
new_order.append(merged_node)
idx += self.pattern.ptn_length
merged = True
continue
new_order.append(op_node)
idx += 1

if merged:
self.pattern.module_name = md_name
self.topo_order_aft_repl = src_aft_repl.split(",")
self.topo_order_aft_repl = new_order
return md_name
return None

+ 2
- 2
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py View File

@@ -438,8 +438,8 @@ def greedy_match(topo_order, user_defined_ptn):
"""
increment_idx = 0
prev_path = None
for md_name, ptn in user_defined_ptn:
ptn = Pattern(",".join(ptn), len(ptn), -1, -1, ptn)
for md_name, ptn_items in user_defined_ptn.items():
ptn = Pattern(",".join(ptn_items), len(ptn_items), -1, -1, ptn_items)
ptn.known_module_name = md_name
topo_order_aft_rpl = topo_order[:] if prev_path is None else prev_path.topo_order_aft_repl
repl_path = ReplacePath(ptn, topo_order_aft_rpl, prev_path=prev_path)


Loading…
Cancel
Save