Browse Source

!665 Fix error message, converted model and report name in Graph based Mindspore scripts generation.

Merge pull request !665 from moran/master
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
ad2ef0e484
4 changed files with 33 additions and 6 deletions
  1. +1
    -1
      mindinsight/mindconverter/cli.py
  2. +19
    -0
      mindinsight/mindconverter/graph_based_converter/framework.py
  3. +8
    -5
      mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py
  4. +5
    -0
      mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py

+ 1
- 1
mindinsight/mindconverter/cli.py View File

@@ -352,7 +352,7 @@ def _run(in_files, model_file, shape, out_dir, report, project_path):
main_graph_base_converter(file_config)

else:
error_msg = "`--in_files` and `--model_file` should be set at least one."
error_msg = "`--in_file` and `--model_file` should be set at least one."
error = FileNotFoundError(error_msg)
log.error(str(error))
log.exception(error)


+ 19
- 0
mindinsight/mindconverter/graph_based_converter/framework.py View File

@@ -14,6 +14,7 @@
# ==============================================================================
"""Graph based scripts converter workflow."""
import os
import re
import argparse
from importlib.util import find_spec

@@ -73,6 +74,21 @@ def torch_installation_validation(func):
return _f


def _extract_model_name(model_path):
"""
Extract model name from model path.

Args:
model_path(str): Path of Converted model.

Returns:
str: Name of Converted model.
"""

model_name = re.findall(r".*[/](.*).pth", model_path)[-1]
return model_name


@torch_installation_validation
def graph_based_converter(graph_path: str, sample_shape: tuple,
output_folder: str, report_folder: str = None,
@@ -100,7 +116,10 @@ def graph_based_converter(graph_path: str, sample_shape: tuple,
log.error("Error occur when create hierarchical tree.")
raise NodeTypeNotSupport("This model is not supported now.")

model_name = _extract_model_name(graph_path)

hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper,
model_name=model_name,
report_folder=report_folder)




+ 8
- 5
mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py View File

@@ -153,6 +153,7 @@ class HierarchicalTree(Tree):
parent_node.set_successors(brothers, tree_id=self.tree_identifier)

def save_source_files(self, out_folder: str, mapper: Mapper,
model_name: str,
report_folder: str = None) -> NoReturn:
"""
Save source codes to target folder.
@@ -160,6 +161,7 @@ class HierarchicalTree(Tree):
Args:
report_folder (str): Report folder.
mapper (Mapper): Mapper of third party framework and mindspore.
model_name(str): Name of Converted model.
out_folder (str): Output folder.

"""
@@ -171,11 +173,11 @@ class HierarchicalTree(Tree):
log.error("Error occur when create hierarchical tree.")
raise NodeTypeNotSupport("This model is not supported now.")

out_folder = os.path.abspath(out_folder)
out_folder = os.path.realpath(out_folder)
if not report_folder:
report_folder = out_folder
else:
report_folder = os.path.abspath(report_folder)
report_folder = os.path.realpath(report_folder)

if not os.path.exists(out_folder):
os.makedirs(out_folder, self.modes_usr)
@@ -185,8 +187,8 @@ class HierarchicalTree(Tree):
for file_name in code_fragments:
code, report = code_fragments[file_name]
try:
with os.fdopen(os.open(os.path.join(os.path.abspath(out_folder), f"{file_name}.py"),
self.flags, self.modes), "w") as file:
with os.fdopen(os.open(os.path.realpath(os.path.join(out_folder, f"{model_name}.py")),
self.flags, self.modes), 'w') as file:
file.write(code)
except IOError as error:
log.error(str(error))
@@ -194,7 +196,8 @@ class HierarchicalTree(Tree):
raise error

try:
with os.fdopen(os.open(os.path.join(report_folder, f"report_of_{file_name}.txt"),
with os.fdopen(os.open(os.path.realpath(os.path.join(report_folder,
f"report_of_{model_name}.txt")),
self.flags, stat.S_IRUSR), "w") as rpt_f:
rpt_f.write(report)
except IOError as error:


+ 5
- 0
mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py View File

@@ -54,5 +54,10 @@ class PyTorchGraphParser(GraphParser):
log.error(str(error))
log.exception(error)
raise error
except Exception as e:
error_msg = "Error occurs in loading model, make sure model.pth correct."
log.error(error_msg)
log.exception(e)
raise Exception(error_msg)

return model

Loading…
Cancel
Save