@@ -19,6 +19,9 @@ import argparse
import mindinsight
from mindinsight.mindconverter.converter import main
from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter
from mindinsight.mindconverter.common.log import logger as log


class FileDirAction(argparse.Action):
@@ -92,6 +95,26 @@ class OutputDirAction(argparse.Action):
        setattr(namespace, self.dest, output)


class ProjectPathAction(argparse.Action):
    """Project directory action class definition."""

    def __call__(self, parser, namespace, values, option_string=None):
        """
        Inherited __call__ method from argparse.Action.

        Args:
            parser (ArgumentParser): Passed-in argument parser.
            namespace (Namespace): Namespace object to hold arguments.
            values (object): Argument values with type depending on argument definition.
            option_string (str): Optional string for specific argument name. Default: None.
        """
        outfile_dir = FileDirAction.check_path(parser, values, option_string)
        if not os.path.isdir(outfile_dir):
            parser.error(f'{option_string} [{outfile_dir}] should be a directory.')
        setattr(namespace, self.dest, outfile_dir)


class InFileAction(argparse.Action):
    """Input File action class definition."""
@@ -134,6 +157,29 @@ class LogFileAction(argparse.Action):
        setattr(namespace, self.dest, outfile_dir)


class ShapeAction(argparse.Action):
    """Shape action class definition."""

    def __call__(self, parser, namespace, values, option_string=None):
        """
        Inherited __call__ method from argparse.Action.

        Args:
            parser (ArgumentParser): Passed-in argument parser.
            namespace (Namespace): Namespace object to hold arguments.
            values (object): Argument values with type depending on argument definition.
            option_string (str): Optional string for specific argument name. Default: None.
        """
        in_shape = None
        shape_str = values
        try:
            in_shape = [int(num_shape) for num_shape in shape_str.split(',')]
        except ValueError:
            parser.error(
                f"{option_string} {shape_str} should be a comma-separated list of integers, please check it.")
        setattr(namespace, self.dest, in_shape)
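# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patch: what ShapeAction produces for a
# typical `--shape` value. The sample string below is only an illustration.
shape_str_example = "3,224,224"
in_shape_example = [int(num_shape) for num_shape in shape_str_example.split(',')]
assert in_shape_example == [3, 224, 224]
# A non-numeric token such as "3,x,224" raises ValueError, which the action
# converts into parser.error(...) above.
# ---------------------------------------------------------------------------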


def cli_entry():
    """Entry point for mindconverter CLI."""
@@ -153,9 +199,36 @@ def cli_entry():
        '--in_file',
        type=str,
        action=InFileAction,
        required=True,
        required=False,
        default=None,
        help="""
            Specify path for script file.
        """)
    parser.add_argument(
        '--model_file',
        type=str,
        action=InFileAction,
        required=False,
        help="""
            Specify path for model file. PyTorch .pth model
            file path, to use graph based schema to do script
            generation. When `--in_file` and `--model_file`
            are both provided, use AST schema as default.
            Usage: --model_file ~/pytorch_file/net.pth
        """)
    parser.add_argument(
        '--shape',
        type=str,
        action=ShapeAction,
        default=None,
        required=False,
        help="""
            Optional, expected input tensor shape of
            `--model_file`. It is required when using the
            graph based schema.
            Usage: --shape 3,244,244
        """)
    parser.add_argument(
@@ -172,11 +245,24 @@ def cli_entry():
        '--report',
        type=str,
        action=LogFileAction,
        default=os.getcwd(),
        default=None,
        help="""
            Specify report directory. Default is the output directory.
| """) | |||
| parser.add_argument( | |||
| '--project_path', | |||
| type=str, | |||
| action=ProjectPathAction, | |||
| required=False, | |||
| default=None, | |||
| help=""" | |||
| Optional, pytorch scripts project path. If pytorch | |||
| project is not in PYTHONPATH, please assign | |||
| `--project_path' when use graph based schema. | |||
| Usage: --project_path ~/script_file/ | |||
| """) | |||
    argv = sys.argv[1:]
    if not argv:
        argv = ['-h']
@@ -185,30 +271,58 @@ def cli_entry():
        args = parser.parse_args()
    mode = permissions << 6
    os.makedirs(args.output, mode=mode, exist_ok=True)
    if args.report is None:
        args.report = args.output
    os.makedirs(args.report, mode=mode, exist_ok=True)
    _run(args.in_file, args.output, args.report)
    _run(args.in_file, args.model_file, args.shape, args.output, args.report, args.project_path)


def _run(in_files, out_dir, report):
def _run(in_files, model_file, shape, out_dir, report, project_path):
    """
    Run converter command.

    Args:
        in_files (str): The file path or directory to convert.
        model_file (str): The PyTorch .pth model file to convert using the graph based schema.
        shape (list): The input tensor shape of the model specified by `model_file`.
        out_dir (str): The output directory to save converted files.
        report (str): The report file path.
        project_path (str): PyTorch scripts project path.
    """
    files_config = {
        'root_path': in_files if in_files else '',
        'in_files': [],
        'outfile_dir': out_dir,
        'report_dir': report
    }
    if os.path.isfile(in_files):
        files_config['root_path'] = os.path.dirname(in_files)
        files_config['in_files'] = [in_files]
    if in_files:
        files_config = {
            'root_path': in_files,
            'in_files': [],
            'outfile_dir': out_dir,
            'report_dir': report if report else out_dir
        }
        if os.path.isfile(in_files):
            files_config['root_path'] = os.path.dirname(in_files)
            files_config['in_files'] = [in_files]
        else:
            for root_dir, _, files in os.walk(in_files):
                for file in files:
                    files_config['in_files'].append(os.path.join(root_dir, file))
        main(files_config)
    elif model_file:
        file_config = {
            'model_file': model_file,
            'shape': shape if shape else [],
            'outfile_dir': out_dir,
            'report_dir': report if report else out_dir
        }
        if project_path:
            paths = sys.path
            if project_path not in paths:
                sys.path.append(project_path)
        main_graph_base_converter(file_config)
    else:
        for root_dir, _, files in os.walk(in_files):
            for file in files:
                files_config['in_files'].append(os.path.join(root_dir, file))
    main(files_config)
        error_msg = "At least one of `--in_file` and `--model_file` should be provided."
        error = FileNotFoundError(error_msg)
        log.error(str(error))
        log.exception(error)
        raise error
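# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patch: how _run dispatches, shown with
# hypothetical arguments.
#
#   _run("~/scripts/lenet.py", None, None, "out/", None, None)
#       -> builds files_config and calls main(files_config)          # AST schema
#   _run(None, "~/models/lenet.pth", [3, 224, 224], "out/", None, "~/scripts/")
#       -> builds file_config and calls
#          main_graph_base_converter(file_config)                    # graph based
#   _run(None, None, None, "out/", None, None)
#       -> logs and raises FileNotFoundError
# ---------------------------------------------------------------------------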

@@ -18,6 +18,7 @@ import argparse
from importlib.util import find_spec

import mindinsight
from mindinsight.mindconverter.common.log import logger as log
from .mapper import ONNXToMindSporeMapper

permissions = os.R_OK | os.W_OK | os.X_OK
@@ -57,9 +58,12 @@ def torch_installation_validation(func):
           checkpoint_path: str = None):
        # Check whether pytorch is installed.
        if not find_spec("torch"):
            raise ModuleNotFoundError("PyTorch is required when using graph based "
                                      "scripts converter, and PyTorch vision must "
                                      "be consisted with model generation runtime.")
            error = ModuleNotFoundError("PyTorch is required when using the graph based "
                                        "scripts converter, and the PyTorch version must "
                                        "be consistent with the model generation runtime.")
            log.error(str(error))
            log.exception(error)
            raise error

        func(graph_path=graph_path, sample_shape=sample_shape,
             output_folder=output_folder, report_folder=report_folder,
@@ -93,10 +97,14 @@ def graph_based_converter(graph_path: str, sample_shape: tuple,
                              report_folder=report_folder)


if __name__ == '__main__':
    args, _ = parser.parse_known_args()
    graph_based_converter(graph_path=args.graph,
                          sample_shape=args.sample_shape,
                          output_folder=args.output,
                          report_folder=args.report,
                          checkpoint_path=args.ckpt)


def main_graph_base_converter(file_config):
    """
    The entrance for the graph based converter, by which script files are generated.

    Args:
        file_config (dict): The configuration of the file to convert.
    """
    graph_based_converter(graph_path=file_config['model_file'],
                          sample_shape=file_config['shape'],
                          output_folder=file_config['outfile_dir'],
                          report_folder=file_config['report_dir'])
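# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patch: the dictionary that
# main_graph_base_converter expects, matching the keys built in cli._run.
# The values below are hypothetical.
file_config_example = {
    'model_file': '~/pytorch_models/lenet.pth',  # path to the .pth checkpoint
    'shape': [3, 224, 224],                      # input tensor shape without batch dim
    'outfile_dir': 'out/',                       # where generated scripts go
    'report_dir': 'out/'                         # where conversion reports go
}
# main_graph_base_converter(file_config_example)  # needs a real model to run
# ---------------------------------------------------------------------------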

@@ -14,6 +14,7 @@
# ==============================================================================
"""Define hierarchical tree."""
import os
import stat
from copy import deepcopy
from typing import NoReturn, Union
from queue import Queue
@@ -21,6 +22,8 @@ from queue import Queue
from yapf.yapflib.yapf_api import FormatCode
from treelib import Tree, Node

from mindinsight.mindconverter.common.log import logger as log

from .name_mgr import ModuleNameMgr, GlobalVarNameMgr
from ..mapper.base import Mapper
from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode
@@ -34,6 +37,10 @@ GLOBAL_VAR_NAME_MGR = GlobalVarNameMgr()
class HierarchicalTree(Tree):
    """Define hierarchical tree."""
    flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
    modes = stat.S_IRUSR | stat.S_IWUSR
    modes_usr = stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR
    _root_created = False
    ROOT_LEVEL = 0
@@ -162,19 +169,31 @@ class HierarchicalTree(Tree):
        report_folder = os.path.abspath(report_folder)
        if not os.path.exists(out_folder):
            os.makedirs(out_folder)
            os.makedirs(out_folder, self.modes_usr)
        if not os.path.exists(report_folder):
            os.makedirs(report_folder)
            os.makedirs(report_folder, self.modes_usr)

        for file_name in code_fragments:
            code, report = code_fragments[file_name]
            with open(os.path.join(os.path.abspath(out_folder),
                                   f"{file_name}.py"), "w") as file:
                file.write(code)
            with open(os.path.join(report_folder,
                                   f"report_of_{file_name}.txt"), "w") as rpt_f:
                rpt_f.write(report)
            try:
                with os.fdopen(
                        os.open(os.path.join(os.path.abspath(out_folder), f"{file_name}.py"),
                                self.flags, self.modes), 'w') as file:
                    file.write(code)
            except IOError as error:
                log.error(str(error))
                log.exception(error)
                raise error

            try:
                with os.fdopen(
                        os.open(os.path.join(report_folder, f"report_of_{file_name}.txt"),
                                self.flags, stat.S_IRUSR), "w") as rpt_f:
                    rpt_f.write(report)
            except IOError as error:
                log.error(str(error))
                log.exception(error)
                raise error
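# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patch: the secure-write pattern used
# above, shown standalone. os.open with O_EXCL fails if the file already
# exists, and the mode bits restrict the created file to the current user.
import os
import stat

def write_private_file(path, content):
    """Create `path` exclusively with owner-only permissions and write content."""
    flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
    modes = stat.S_IRUSR | stat.S_IWUSR        # read/write for owner only
    with os.fdopen(os.open(path, flags, modes), 'w') as file:
        file.write(content)

# write_private_file("demo.py", "print('hello')")  # raises FileExistsError on rerun
# ---------------------------------------------------------------------------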
    def _preprocess_node_args(self, node, module_key):
        """
@@ -625,7 +644,6 @@ class HierarchicalTree(Tree):
                nd_inst = self.get_node(successor_name)
                # Generate variable name here, then
                # to generate args.
                # if nd_inst.data.node_type == NodeType.OPERATION.value:
                if created:
                    nd_inst.data.variable_name = self._module_vars[module_key][idx]
                else:

@@ -16,6 +16,7 @@
import abc
from collections import OrderedDict

from mindinsight.mindconverter.common.log import logger as log

from ..constant import SEPARATOR_IN_ONNX_OP
from ..mapper.base import Mapper
@@ -66,8 +67,11 @@ class BaseGraph(metaclass=abc.ABCMeta):
        """Control the create action of graph."""
        model_param = args[0] if args else kwargs.get(cls._REQUIRED_PARAM_OF_MODEL)
        if not model_param:
            raise ValueError(f"`{cls._REQUIRED_PARAM_OF_MODEL}` "
                             f"can not be None.")
            error = ValueError(f"`{cls._REQUIRED_PARAM_OF_MODEL}` "
                               f"can not be None.")
            log.error(str(error))
            log.exception(error)
            raise error

        return super(BaseGraph, cls).__new__(cls)

@@ -14,6 +14,7 @@
# ==============================================================================
"""Third party graph parser."""
import os

from mindinsight.mindconverter.common.log import logger as log

from .base import GraphParser
@@ -34,12 +35,24 @@ class PyTorchGraphParser(GraphParser):
        import torch

        if not os.path.exists(model_path):
            raise FileNotFoundError("`model_path` must be assigned with "
                                    "an existed file path.")
        if torch.cuda.is_available():
            model = torch.load(f=model_path)
        else:
            model = torch.load(f=model_path, map_location="cpu")
            error = FileNotFoundError("`model_path` must be assigned with "
                                      "an existing file path.")
            log.error(str(error))
            log.exception(error)
            raise error

        try:
            if torch.cuda.is_available():
                model = torch.load(f=model_path)
            else:
                model = torch.load(f=model_path, map_location="cpu")
        except ModuleNotFoundError:
            error_msg = \
                "Cannot find model scripts in system path, " \
                "set `--project_path` to the path of model scripts folder correctly."
            error = ModuleNotFoundError(error_msg)
            log.error(str(error))
            log.exception(error)
            raise error

        return model
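# ---------------------------------------------------------------------------
# Note, not part of the patch: torch.load on a whole-model .pth file unpickles
# the original model class, so the module that defines it must be importable.
# That is why cli.py appends `--project_path` to sys.path before calling the
# graph based converter. Hypothetical illustration:
#
#   import sys, torch
#   sys.path.append("~/pytorch_scripts/")      # folder containing lenet.py
#   model = torch.load("~/pytorch_models/lenet.pth", map_location="cpu")
#
# Without the sys.path entry, the load fails with ModuleNotFoundError, which is
# exactly the case handled above.
# ---------------------------------------------------------------------------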

@@ -17,6 +17,7 @@ import warnings
import re
from typing import Dict, NoReturn

from mindinsight.mindconverter.common.log import logger as log

from .base import Graph
from .input_node import InputNode
from .pytorch_graph_node import PyTorchGraphNode
@@ -89,12 +90,18 @@ class PyTorchGraph(Graph):
        """
        if not input_shape:
            raise ValueError("`input_shape` can not be None.")
            error = ValueError("`input_shape` can not be None.")
            log.error(str(error))
            log.exception(error)
            raise error

        for item in input_shape:
            if not isinstance(item, int):
                raise ValueError(f"Only support model with one input now, "
                                 f"and each shape value in `input_shape` should be int.")
                error = ValueError(f"Only support model with one input now, "
                                   f"and each shape value in `input_shape` should be int.")
                log.error(str(error))
                log.exception(error)
                raise error
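# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patch: values that pass or fail the
# check above. The shapes are only illustrations.
#   (3, 224, 224)      -> accepted, every item is an int
#   None or ()         -> "`input_shape` can not be None."
#   (3, "224", 224)    -> "each shape value in `input_shape` should be int"
# ---------------------------------------------------------------------------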

    def build(self, input_shape):
        """
@@ -122,9 +129,11 @@ class PyTorchGraph(Graph):
            Returns:
                list, shape.
            """
            pattern = re.compile(r"\d+:\d*")
            if not pattern.findall(shape):
            if "," not in shape:
                return []
            for s in shape.split(","):
                if not s:
                    return []
            return [int(x.split(":")[0].replace("!", "")) for x in shape.split(',')]
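# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patch: behaviour of the parsing above
# for a hypothetical shape string produced by the tracer.
shape_example = "1:0,3:0,224:0,224:0"
parsed = [int(x.split(":")[0].replace("!", "")) for x in shape_example.split(',')]
assert parsed == [1, 3, 224, 224]
# Strings without a comma, or with an empty element, return [] instead.
# ---------------------------------------------------------------------------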
        feed_forward_ipt_shape = (1, *input_shape)
@@ -133,10 +142,15 @@ class PyTorchGraph(Graph):
        # Assign execution mode to eval.
        self.model.eval()

        with OverloadTorchModuleTemporarily() as _:
            # In pytorch higher version, trace function has a known.
            graph = onnx_tracer(self.model, batched_sample,
                                OperatorExportTypes.ONNX)
        try:
            with OverloadTorchModuleTemporarily() as _:
                # In higher PyTorch versions, the trace function has a known issue.
                graph = onnx_tracer(self.model, batched_sample,
                                    OperatorExportTypes.ONNX)
        except RuntimeError as error:
            log.error(str(error))
            log.exception(error)
            raise error

        nodes = list(graph.nodes())
@@ -190,6 +204,37 @@ class PyTorchGraph(Graph):
        """
        raise NotImplementedError()

    def to_hierarchical_tree(self):
        """
        Generate hierarchical tree based on graph.
        """
        from ..hierarchical_tree import HierarchicalTree

        tree = HierarchicalTree()
        node_input = None
        for _, node_name in enumerate(self.nodes_in_topological_order):
            node_inst = self.get_node(node_name)
            node_output = self._shape_dict.get(node_name)
            if node_inst.in_degree == 0:
                # If in-degree equals to zero, then it's an input node.
                continue

            # If the node is on the top, then fetch its input
            # from the input table.
            if not node_input:
                node_input = self._input_shape.get(node_name)
                if not node_input:
                    error = ValueError(f"This model is not supported now. "
                                       f"Cannot find {node_name}'s input shape.")
                    log.error(str(error))
                    log.exception(error)
                    raise error

            tree.insert(node_inst, node_name, node_input, node_output)
            node_input = node_output

        return tree

    def build_connection(self, src, tgt) -> NoReturn:
        """
        Build connection between source node and target node.
@@ -229,8 +274,11 @@ class PyTorchGraph(Graph):
        """
        Load graph metadata.
        """
        raise NotImplementedError("class `PyTorchGraph` has not implemented "
                                  "`load_metadata()`.")
        error = NotImplementedError("class `PyTorchGraph` has not implemented "
                                    "`load_metadata()`.")
        log.error(str(error))
        log.exception(error)
        raise error

    @staticmethod
    def load_graph(graph_path: str):

@@ -116,8 +116,6 @@ class PyTorchGraphNode(GraphNode):
        """
        if not self._module_name_frozen:
            module_name = self.tag
            # if self._node_type == NodeType.CLASS.value:
            #     module_name = f"{module_name[0].upper()}{module_name[1:]}"
            return module_name
        return self._module_name