diff --git a/mindinsight/mindconverter/cli.py b/mindinsight/mindconverter/cli.py
index 9c49d65d..2529e23d 100644
--- a/mindinsight/mindconverter/cli.py
+++ b/mindinsight/mindconverter/cli.py
@@ -19,6 +19,9 @@ import argparse
 
 import mindinsight
 from mindinsight.mindconverter.converter import main
+from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter
+
+from mindinsight.mindconverter.common.log import logger as log
 
 
 class FileDirAction(argparse.Action):
@@ -92,6 +95,26 @@ class OutputDirAction(argparse.Action):
         setattr(namespace, self.dest, output)
 
 
+class ProjectPathAction(argparse.Action):
+    """Project directory action class definition."""
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        """
+        Inherited __call__ method from argparse.Action.
+
+        Args:
+            parser (ArgumentParser): Passed-in argument parser.
+            namespace (Namespace): Namespace object to hold arguments.
+            values (object): Argument values with type depending on argument definition.
+            option_string (str): Optional string for specific argument name. Default: None.
+        """
+        outfile_dir = FileDirAction.check_path(parser, values, option_string)
+        if not os.path.isdir(outfile_dir):
+            parser.error(f'{option_string} [{outfile_dir}] should be a directory.')
+
+        setattr(namespace, self.dest, outfile_dir)
+
+
 class InFileAction(argparse.Action):
     """Input File action class definition."""
 
@@ -134,6 +157,29 @@ class LogFileAction(argparse.Action):
         setattr(namespace, self.dest, outfile_dir)
 
 
+class ShapeAction(argparse.Action):
+    """Shape action class definition."""
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        """
+        Inherited __call__ method from argparse.Action.
+
+        Args:
+            parser (ArgumentParser): Passed-in argument parser.
+            namespace (Namespace): Namespace object to hold arguments.
+            values (object): Argument values with type depending on argument definition.
+            option_string (str): Optional string for specific argument name. Default: None.
+        """
+        in_shape = None
+        shape_str = values
+        try:
+            in_shape = [int(num_shape) for num_shape in shape_str.split(',')]
+        except ValueError:
+            parser.error(
+                f"{option_string} {shape_str} should be a list of integers separated by ',', please check it.")
+        setattr(namespace, self.dest, in_shape)
+
+
 def cli_entry():
     """Entry point for mindconverter CLI."""
 
@@ -153,9 +199,36 @@ def cli_entry():
         '--in_file',
         type=str,
         action=InFileAction,
-        required=True,
+        required=False,
+        default=None,
+        help="""
+            Specify path for script file.
+        """)
+
+    parser.add_argument(
+        '--model_file',
+        type=str,
+        action=InFileAction,
+        required=False,
         help="""
-            Specify path for script file.
+            PyTorch .pth model file path, used by the graph
+            based schema to do script generation. When
+            `--in_file` and `--model_file` are both provided,
+            the AST schema is used by default.
+            Usage: --model_file ~/pytorch_file/net.pth
+        """)
+
+    parser.add_argument(
+        '--shape',
+        type=str,
+        action=ShapeAction,
+        default=None,
+        required=False,
+        help="""
+            Optional, expected input tensor shape of
+            `--model_file`. It is required when using the
+            graph based schema.
+            Usage: --shape 3,244,244
         """)
 
     parser.add_argument(
@@ -172,11 +245,24 @@ def cli_entry():
         '--report',
         type=str,
         action=LogFileAction,
-        default=os.getcwd(),
+        default=None,
         help="""
             Specify report directory. Default is the current working directory.
""") + parser.add_argument( + '--project_path', + type=str, + action=ProjectPathAction, + required=False, + default=None, + help=""" + Optional, pytorch scripts project path. If pytorch + project is not in PYTHONPATH, please assign + `--project_path' when use graph based schema. + Usage: --project_path ~/script_file/ + """) + argv = sys.argv[1:] if not argv: argv = ['-h'] @@ -185,30 +271,58 @@ def cli_entry(): args = parser.parse_args() mode = permissions << 6 os.makedirs(args.output, mode=mode, exist_ok=True) + if args.report is None: + args.report = args.output os.makedirs(args.report, mode=mode, exist_ok=True) - _run(args.in_file, args.output, args.report) + _run(args.in_file, args.model_file, args.shape, args.output, args.report, args.project_path) -def _run(in_files, out_dir, report): +def _run(in_files, model_file, shape, out_dir, report, project_path): """ Run converter command. Args: in_files (str): The file path or directory to convert. + model_file(str): The pytorch .pth to convert on graph based schema. + shape(list): The input tensor shape of module_file. out_dir (str): The output directory to save converted file. report (str): The report file path. + project_path(str): Pytorch scripts project path. """ - files_config = { - 'root_path': in_files if in_files else '', - 'in_files': [], - 'outfile_dir': out_dir, - 'report_dir': report - } - if os.path.isfile(in_files): - files_config['root_path'] = os.path.dirname(in_files) - files_config['in_files'] = [in_files] + if in_files: + files_config = { + 'root_path': in_files, + 'in_files': [], + 'outfile_dir': out_dir, + 'report_dir': report if report else out_dir + } + + if os.path.isfile(in_files): + files_config['root_path'] = os.path.dirname(in_files) + files_config['in_files'] = [in_files] + else: + for root_dir, _, files in os.walk(in_files): + for file in files: + files_config['in_files'].append(os.path.join(root_dir, file)) + main(files_config) + + elif model_file: + file_config = { + 'model_file': model_file, + 'shape': shape if shape else [], + 'outfile_dir': out_dir, + 'report_dir': report if report else out_dir + } + if project_path: + paths = sys.path + if project_path not in paths: + sys.path.append(project_path) + + main_graph_base_converter(file_config) + else: - for root_dir, _, files in os.walk(in_files): - for file in files: - files_config['in_files'].append(os.path.join(root_dir, file)) - main(files_config) + error_msg = "`--in_files` and `--model_file` should be set at least one." + error = FileNotFoundError(error_msg) + log.error(str(error)) + log.exception(error) + raise error diff --git a/mindinsight/mindconverter/graph_based_converter/framework.py b/mindinsight/mindconverter/graph_based_converter/framework.py index 5d660239..f4a7fec2 100644 --- a/mindinsight/mindconverter/graph_based_converter/framework.py +++ b/mindinsight/mindconverter/graph_based_converter/framework.py @@ -18,6 +18,7 @@ import argparse from importlib.util import find_spec import mindinsight +from mindinsight.mindconverter.common.log import logger as log from .mapper import ONNXToMindSporeMapper permissions = os.R_OK | os.W_OK | os.X_OK @@ -57,9 +58,12 @@ def torch_installation_validation(func): checkpoint_path: str = None): # Check whether pytorch is installed. 
         if not find_spec("torch"):
-            raise ModuleNotFoundError("PyTorch is required when using graph based "
-                                      "scripts converter, and PyTorch vision must "
-                                      "be consisted with model generation runtime.")
+            error = ModuleNotFoundError("PyTorch is required when using the graph based "
+                                        "scripts converter, and the PyTorch version must "
+                                        "be consistent with the model generation runtime.")
+            log.error(str(error))
+            log.exception(error)
+            raise error
 
         func(graph_path=graph_path, sample_shape=sample_shape,
              output_folder=output_folder, report_folder=report_folder,
@@ -93,10 +97,14 @@ def graph_based_converter(graph_path: str, sample_shape: tuple,
                               report_folder=report_folder)
 
 
-if __name__ == '__main__':
-    args, _ = parser.parse_known_args()
-    graph_based_converter(graph_path=args.graph,
-                          sample_shape=args.sample_shape,
-                          output_folder=args.output,
-                          report_folder=args.report,
-                          checkpoint_path=args.ckpt)
+def main_graph_base_converter(file_config):
+    """
+    Entry point of the graph based converter; script files will be generated.
+
+    Args:
+        file_config (dict): The configuration of the file to convert.
+    """
+    graph_based_converter(graph_path=file_config['model_file'],
+                          sample_shape=file_config['shape'],
+                          output_folder=file_config['outfile_dir'],
+                          report_folder=file_config['report_dir'])
diff --git a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py
index 6f235ddd..1497ee15 100644
--- a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py
+++ b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py
@@ -14,6 +14,7 @@
 # ==============================================================================
 """Define hierarchical tree."""
 import os
+import stat
 from copy import deepcopy
 from typing import NoReturn, Union
 from queue import Queue
@@ -21,6 +22,8 @@ from queue import Queue
 from yapf.yapflib.yapf_api import FormatCode
 from treelib import Tree, Node
 
+from mindinsight.mindconverter.common.log import logger as log
+
 from .name_mgr import ModuleNameMgr, GlobalVarNameMgr
 from ..mapper.base import Mapper
 from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode
@@ -34,6 +37,10 @@ GLOBAL_VAR_NAME_MGR = GlobalVarNameMgr()
 
 
 class HierarchicalTree(Tree):
     """Define hierarchical tree."""
+    flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
+    modes = stat.S_IRUSR | stat.S_IWUSR
+    modes_usr = stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR
+
     _root_created = False
     ROOT_LEVEL = 0
@@ -162,19 +169,31 @@ class HierarchicalTree(Tree):
         report_folder = os.path.abspath(report_folder)
 
         if not os.path.exists(out_folder):
-            os.makedirs(out_folder)
+            os.makedirs(out_folder, self.modes_usr)
         if not os.path.exists(report_folder):
-            os.makedirs(report_folder)
+            os.makedirs(report_folder, self.modes_usr)
 
         for file_name in code_fragments:
             code, report = code_fragments[file_name]
-            with open(os.path.join(os.path.abspath(out_folder),
-                                   f"{file_name}.py"), "w") as file:
-                file.write(code)
-
-            with open(os.path.join(report_folder,
-                                   f"report_of_{file_name}.txt"), "w") as rpt_f:
-                rpt_f.write(report)
+            try:
+                with os.fdopen(
+                        os.open(os.path.join(os.path.abspath(out_folder), f"{file_name}.py"),
+                                self.flags, self.modes), 'w') as file:
+                    file.write(code)
+            except IOError as error:
+                log.error(str(error))
+                log.exception(error)
+                raise error
+
+            try:
+                with os.fdopen(
+                        os.open(os.path.join(report_folder, f"report_of_{file_name}.txt"),
+                                self.flags, stat.S_IRUSR), "w") as rpt_f:
+                    rpt_f.write(report)
+            except IOError as error:
+                log.error(str(error))
+                log.exception(error)
+                raise error
 
     def _preprocess_node_args(self, node, module_key):
         """
@@ -625,7 +644,6 @@ class HierarchicalTree(Tree):
                 nd_inst = self.get_node(successor_name)
                 # Generate variable name here, then
                 # to generate args.
-                # if nd_inst.data.node_type == NodeType.OPERATION.value:
                 if created:
                     nd_inst.data.variable_name = self._module_vars[module_key][idx]
                 else:
diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py
index 595d3737..c2e6011c 100644
--- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py
+++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py
@@ -16,6 +16,7 @@
 import abc
 from collections import OrderedDict
 
+from mindinsight.mindconverter.common.log import logger as log
 from ..constant import SEPARATOR_IN_ONNX_OP
 from ..mapper.base import Mapper
 
@@ -66,8 +67,11 @@ class BaseGraph(metaclass=abc.ABCMeta):
         """Control the create action of graph."""
         model_param = args[0] if args else kwargs.get(cls._REQUIRED_PARAM_OF_MODEL)
         if not model_param:
-            raise ValueError(f"`{cls._REQUIRED_PARAM_OF_MODEL}` "
-                             f"can not be None.")
+            error = ValueError(f"`{cls._REQUIRED_PARAM_OF_MODEL}` "
+                               f"cannot be None.")
+            log.error(str(error))
+            log.exception(error)
+            raise error
 
         return super(BaseGraph, cls).__new__(cls)
 
diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py
index e9c027b2..1b7e45bd 100644
--- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py
+++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/graph_parser.py
@@ -14,6 +14,7 @@
 # ==============================================================================
 """Third party graph parser."""
 import os
+from mindinsight.mindconverter.common.log import logger as log
 
 from .base import GraphParser
 
@@ -34,12 +35,24 @@ class PyTorchGraphParser(GraphParser):
         import torch
 
         if not os.path.exists(model_path):
-            raise FileNotFoundError("`model_path` must be assigned with "
-                                    "an existed file path.")
-
-        if torch.cuda.is_available():
-            model = torch.load(f=model_path)
-        else:
-            model = torch.load(f=model_path, map_location="cpu")
+            error = FileNotFoundError("`model_path` must be assigned with "
+                                      "an existing file path.")
+            log.error(str(error))
+            log.exception(error)
+            raise error
+
+        try:
+            if torch.cuda.is_available():
+                model = torch.load(f=model_path)
+            else:
+                model = torch.load(f=model_path, map_location="cpu")
+        except ModuleNotFoundError:
+            error_msg = \
+                "Cannot find model scripts in the system path; " \
+                "please set `--project_path` to the model scripts folder."
+            error = ModuleNotFoundError(error_msg)
+            log.error(str(error))
+            log.exception(error)
+            raise error
 
         return model
diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py
index 078d06e8..a7a2e70a 100644
--- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py
+++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py
@@ -17,6 +17,7 @@ import warnings
 import re
 from typing import Dict, NoReturn
 
+from mindinsight.mindconverter.common.log import logger as log
 from .base import Graph
 from .input_node import InputNode
 from .pytorch_graph_node import PyTorchGraphNode
@@ -89,12 +90,18 @@ class PyTorchGraph(Graph):
         """
         if not input_shape:
-            raise ValueError("`input_shape` can not be None.")
+            error = ValueError("`input_shape` cannot be None.")
+            log.error(str(error))
+            log.exception(error)
+            raise error
 
         for item in input_shape:
             if not isinstance(item, int):
-                raise ValueError(f"Only support model with one input now, "
-                                 f"and each shape value in `input_shape` should be int.")
+                error = ValueError(f"Only models with a single input are supported now, "
+                                   f"and each value in `input_shape` should be an int.")
+                log.error(str(error))
+                log.exception(error)
+                raise error
 
     def build(self, input_shape):
         """
@@ -122,9 +129,11 @@ class PyTorchGraph(Graph):
             Returns:
                 list, shape.
             """
-            pattern = re.compile(r"\d+:\d*")
-            if not pattern.findall(shape):
+            if "," not in shape:
                 return []
+            for s in shape.split(","):
+                if not s:
+                    return []
             return [int(x.split(":")[0].replace("!", "")) for x in shape.split(',')]
 
         feed_forward_ipt_shape = (1, *input_shape)
@@ -133,10 +142,15 @@ class PyTorchGraph(Graph):
         # Assign execution mode to eval.
         self.model.eval()
 
-        with OverloadTorchModuleTemporarily() as _:
-            # In pytorch higher version, trace function has a known.
-            graph = onnx_tracer(self.model, batched_sample,
-                                OperatorExportTypes.ONNX)
+        try:
+            with OverloadTorchModuleTemporarily() as _:
+                # In higher PyTorch versions, the trace function has a known issue.
+                graph = onnx_tracer(self.model, batched_sample,
+                                    OperatorExportTypes.ONNX)
+        except RuntimeError as error:
+            log.error(str(error))
+            log.exception(error)
+            raise error
 
         nodes = list(graph.nodes())
 
@@ -190,6 +204,37 @@ class PyTorchGraph(Graph):
         """
         raise NotImplementedError()
 
+    def to_hierarchical_tree(self):
+        """
+        Generate hierarchical tree based on graph.
+        """
+        from ..hierarchical_tree import HierarchicalTree
+
+        tree = HierarchicalTree()
+        node_input = None
+        for _, node_name in enumerate(self.nodes_in_topological_order):
+            node_inst = self.get_node(node_name)
+            node_output = self._shape_dict.get(node_name)
+            if node_inst.in_degree == 0:
+                # If in-degree equals zero, then it's an input node.
+                continue
+
+            # If the node is on the top, then fetch its input
+            # from the input table.
+            if not node_input:
+                node_input = self._input_shape.get(node_name)
+
+            if not node_input:
+                error = ValueError(f"This model is not supported now. "
+                                   f"Cannot find {node_name}'s input shape.")
+                log.error(str(error))
+                log.exception(error)
+                raise error
+
+            tree.insert(node_inst, node_name, node_input, node_output)
+            node_input = node_output
+        return tree
+
     def build_connection(self, src, tgt) -> NoReturn:
         """
         Build connection between source node and target node.
@@ -229,8 +274,11 @@ class PyTorchGraph(Graph):
         """
         Load graph metadata.
""" - raise NotImplementedError("class `PyTorchGraph` has not implemented " - "`load_metadata()`.") + error = NotImplementedError("class `PyTorchGraph` has not implemented " + "`load_metadata()`.") + log.error(str(error)) + log.exception(error) + raise error @staticmethod def load_graph(graph_path: str): diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py index a0fd4306..d98b22f5 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_node.py @@ -116,8 +116,6 @@ class PyTorchGraphNode(GraphNode): """ if not self._module_name_frozen: module_name = self.tag - # if self._node_type == NodeType.CLASS.value: - # module_name = f"{module_name[0].upper()}{module_name[1:]}" return module_name return self._module_name