diff --git a/mindinsight/mindconverter/cli.py b/mindinsight/mindconverter/cli.py
index b3818ed2..c0443fd2 100644
--- a/mindinsight/mindconverter/cli.py
+++ b/mindinsight/mindconverter/cli.py
@@ -352,6 +352,6 @@ def _run(in_files, model_file, shape, out_dir, report, project_path):
         main_graph_base_converter(file_config)
     else:
-        error_msg = "`--in_files` and `--model_file` should be set at least one."
+        error_msg = "At least one of `--in_file` and `--model_file` should be specified."
         error = FileNotFoundError(error_msg)
         log.error(str(error))
         log.exception(error)
diff --git a/mindinsight/mindconverter/graph_based_converter/framework.py b/mindinsight/mindconverter/graph_based_converter/framework.py
index 70999bea..0fde4e35 100644
--- a/mindinsight/mindconverter/graph_based_converter/framework.py
+++ b/mindinsight/mindconverter/graph_based_converter/framework.py
@@ -14,6 +14,7 @@
 # ==============================================================================
 """Graph based scripts converter workflow."""
 import os
+import re
 import argparse
 
 from importlib.util import find_spec
@@ -73,6 +74,25 @@ def torch_installation_validation(func):
     return _f
 
 
+def _extract_model_name(model_path):
+    """
+    Extract the model name from a model file path.
+
+    Args:
+        model_path (str): Path of the model file to be converted.
+
+    Returns:
+        str: Name of the model, without directory prefix and ".pth" suffix.
+
+    Raises:
+        ValueError: If no model name can be extracted from `model_path`.
+    """
+    names = re.findall(r".*[/](.*)\.pth", model_path)
+    if not names:
+        raise ValueError("Cannot extract model name from %s." % model_path)
+    return names[-1]
+
+
 @torch_installation_validation
 def graph_based_converter(graph_path: str, sample_shape: tuple,
                           output_folder: str, report_folder: str = None,
@@ -100,5 +120,8 @@
         log.error("Error occur when create hierarchical tree.")
         raise NodeTypeNotSupport("This model is not supported now.")
 
+    model_name = _extract_model_name(graph_path)
+
     hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper,
+                                        model_name=model_name,
                                         report_folder=report_folder)
diff --git a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py
index 12772434..094b9440 100644
--- a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py
+++ b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py
@@ -153,6 +153,7 @@ class HierarchicalTree(Tree):
         parent_node.set_successors(brothers, tree_id=self.tree_identifier)
 
     def save_source_files(self, out_folder: str, mapper: Mapper,
+                          model_name: str,
                           report_folder: str = None) -> NoReturn:
         """
         Save source codes to target folder.
@@ -160,6 +161,7 @@
 
         Args:
             report_folder (str): Report folder.
             mapper (Mapper): Mapper of third party framework and mindspore.
+            model_name (str): Name of the converted model.
             out_folder (str): Output folder.
         """
@@ -171,11 +173,11 @@
             log.error("Error occur when create hierarchical tree.")
             raise NodeTypeNotSupport("This model is not supported now.")
 
-        out_folder = os.path.abspath(out_folder)
+        out_folder = os.path.realpath(out_folder)
         if not report_folder:
             report_folder = out_folder
         else:
-            report_folder = os.path.abspath(report_folder)
+            report_folder = os.path.realpath(report_folder)
 
         if not os.path.exists(out_folder):
             os.makedirs(out_folder, self.modes_usr)
@@ -185,8 +187,8 @@
         for file_name in code_fragments:
            code, report = code_fragments[file_name]
             try:
-                with os.fdopen(os.open(os.path.join(os.path.abspath(out_folder), f"{file_name}.py"),
-                                       self.flags, self.modes), "w") as file:
+                with os.fdopen(os.open(os.path.realpath(os.path.join(out_folder, f"{model_name}.py")),
+                                       self.flags, self.modes), "w") as file:
                     file.write(code)
             except IOError as error:
                 log.error(str(error))
@@ -194,7 +196,8 @@
                 log.exception(error)
                 raise error
             try:
-                with os.fdopen(os.open(os.path.join(report_folder, f"report_of_{file_name}.txt"),
+                with os.fdopen(os.open(os.path.realpath(os.path.join(report_folder,
+                                                                     f"report_of_{model_name}.txt")),
                                        self.flags, stat.S_IRUSR), "w") as rpt_f:
                     rpt_f.write(report)
             except IOError as error:
diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py
index 1b7e45bd..bf234376 100644
--- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py
+++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py
@@ -54,5 +54,10 @@
             log.error(str(error))
             log.exception(error)
             raise error
+        except Exception as e:
+            error_msg = "Error occurred when loading the model; please make sure the .pth model file is correct."
+            log.error(error_msg)
+            log.exception(e)
+            raise Exception(error_msg) from e
 
         return model