| @@ -19,6 +19,9 @@ import argparse | |||||
| import mindinsight | import mindinsight | ||||
| from mindinsight.mindconverter.converter import main | from mindinsight.mindconverter.converter import main | ||||
| from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter | |||||
| from mindinsight.mindconverter.common.log import logger as log | |||||
| class FileDirAction(argparse.Action): | class FileDirAction(argparse.Action): | ||||
| @@ -92,6 +95,26 @@ class OutputDirAction(argparse.Action): | |||||
| setattr(namespace, self.dest, output) | setattr(namespace, self.dest, output) | ||||
| class ProjectPathAction(argparse.Action): | |||||
| """Project directory action class definition.""" | |||||
| def __call__(self, parser, namespace, values, option_string=None): | |||||
| """ | |||||
| Inherited __call__ method from argparse.Action. | |||||
| Args: | |||||
| parser (ArgumentParser): Passed-in argument parser. | |||||
| namespace (Namespace): Namespace object to hold arguments. | |||||
| values (object): Argument values with type depending on argument definition. | |||||
| option_string (str): Optional string for specific argument name. Default: None. | |||||
| """ | |||||
| outfile_dir = FileDirAction.check_path(parser, values, option_string) | |||||
| if not os.path.isdir(outfile_dir): | |||||
| parser.error(f'{option_string} [{outfile_dir}] should be a directory.') | |||||
| setattr(namespace, self.dest, outfile_dir) | |||||
| class InFileAction(argparse.Action): | class InFileAction(argparse.Action): | ||||
| """Input File action class definition.""" | """Input File action class definition.""" | ||||
| @@ -134,6 +157,29 @@ class LogFileAction(argparse.Action): | |||||
| setattr(namespace, self.dest, outfile_dir) | setattr(namespace, self.dest, outfile_dir) | ||||
| class ShapeAction(argparse.Action): | |||||
| """Shape action class definition.""" | |||||
| def __call__(self, parser, namespace, values, option_string=None): | |||||
| """ | |||||
| Inherited __call__ method from FileDirAction. | |||||
| Args: | |||||
| parser (ArgumentParser): Passed-in argument parser. | |||||
| namespace (Namespace): Namespace object to hold arguments. | |||||
| values (object): Argument values with type depending on argument definition. | |||||
| option_string (str): Optional string for specific argument name. Default: None. | |||||
| """ | |||||
| in_shape = None | |||||
| shape_str = values | |||||
| try: | |||||
| in_shape = [int(num_shape) for num_shape in shape_str.split(',')] | |||||
| except ValueError: | |||||
| parser.error( | |||||
| f"{option_string} {shape_str} should be a list of integer split by ',', check it please.") | |||||
| setattr(namespace, self.dest, in_shape) | |||||
| def cli_entry(): | def cli_entry(): | ||||
| """Entry point for mindconverter CLI.""" | """Entry point for mindconverter CLI.""" | ||||
| @@ -153,9 +199,36 @@ def cli_entry(): | |||||
| '--in_file', | '--in_file', | ||||
| type=str, | type=str, | ||||
| action=InFileAction, | action=InFileAction, | ||||
| required=True, | |||||
| required=False, | |||||
| default=None, | |||||
| help=""" | |||||
| Specify path for script file. | |||||
| """) | |||||
| parser.add_argument( | |||||
| '--model_file', | |||||
| type=str, | |||||
| action=InFileAction, | |||||
| required=False, | |||||
| help=""" | help=""" | ||||
| Specify path for script file. | |||||
| Pytorch .pth model file path ot use graph | |||||
| based schema to do script generation. When | |||||
| `--in_file` and `--model_path` are both provided, | |||||
| use AST schema as default. | |||||
| Usage: --model_file ~/pytorch_file/net.pth. | |||||
| """) | |||||
| parser.add_argument( | |||||
| '--shape', | |||||
| type=str, | |||||
| action=ShapeAction, | |||||
| default=None, | |||||
| required=False, | |||||
| help=""" | |||||
| Optional, excepted input tensor shape of | |||||
| `--model_file`. It's required when use graph based | |||||
| schema. | |||||
| Usage: --shape 3,244,244 | |||||
| """) | """) | ||||
| parser.add_argument( | parser.add_argument( | ||||
| @@ -172,11 +245,24 @@ def cli_entry(): | |||||
| '--report', | '--report', | ||||
| type=str, | type=str, | ||||
| action=LogFileAction, | action=LogFileAction, | ||||
| default=os.getcwd(), | |||||
| default=None, | |||||
| help=""" | help=""" | ||||
| Specify report directory. Default is the current working directory. | Specify report directory. Default is the current working directory. | ||||
| """) | """) | ||||
| parser.add_argument( | |||||
| '--project_path', | |||||
| type=str, | |||||
| action=ProjectPathAction, | |||||
| required=False, | |||||
| default=None, | |||||
| help=""" | |||||
| Optional, pytorch scripts project path. If pytorch | |||||
| project is not in PYTHONPATH, please assign | |||||
| `--project_path' when use graph based schema. | |||||
| Usage: --project_path ~/script_file/ | |||||
| """) | |||||
| argv = sys.argv[1:] | argv = sys.argv[1:] | ||||
| if not argv: | if not argv: | ||||
| argv = ['-h'] | argv = ['-h'] | ||||
| @@ -185,30 +271,58 @@ def cli_entry(): | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| mode = permissions << 6 | mode = permissions << 6 | ||||
| os.makedirs(args.output, mode=mode, exist_ok=True) | os.makedirs(args.output, mode=mode, exist_ok=True) | ||||
| if args.report is None: | |||||
| args.report = args.output | |||||
| os.makedirs(args.report, mode=mode, exist_ok=True) | os.makedirs(args.report, mode=mode, exist_ok=True) | ||||
| _run(args.in_file, args.output, args.report) | |||||
| _run(args.in_file, args.model_file, args.shape, args.output, args.report, args.project_path) | |||||
| def _run(in_files, out_dir, report): | |||||
| def _run(in_files, model_file, shape, out_dir, report, project_path): | |||||
| """ | """ | ||||
| Run converter command. | Run converter command. | ||||
| Args: | Args: | ||||
| in_files (str): The file path or directory to convert. | in_files (str): The file path or directory to convert. | ||||
| model_file(str): The pytorch .pth to convert on graph based schema. | |||||
| shape(list): The input tensor shape of module_file. | |||||
| out_dir (str): The output directory to save converted file. | out_dir (str): The output directory to save converted file. | ||||
| report (str): The report file path. | report (str): The report file path. | ||||
| project_path(str): Pytorch scripts project path. | |||||
| """ | """ | ||||
| files_config = { | |||||
| 'root_path': in_files if in_files else '', | |||||
| 'in_files': [], | |||||
| 'outfile_dir': out_dir, | |||||
| 'report_dir': report | |||||
| } | |||||
| if os.path.isfile(in_files): | |||||
| files_config['root_path'] = os.path.dirname(in_files) | |||||
| files_config['in_files'] = [in_files] | |||||
| if in_files: | |||||
| files_config = { | |||||
| 'root_path': in_files, | |||||
| 'in_files': [], | |||||
| 'outfile_dir': out_dir, | |||||
| 'report_dir': report if report else out_dir | |||||
| } | |||||
| if os.path.isfile(in_files): | |||||
| files_config['root_path'] = os.path.dirname(in_files) | |||||
| files_config['in_files'] = [in_files] | |||||
| else: | |||||
| for root_dir, _, files in os.walk(in_files): | |||||
| for file in files: | |||||
| files_config['in_files'].append(os.path.join(root_dir, file)) | |||||
| main(files_config) | |||||
| elif model_file: | |||||
| file_config = { | |||||
| 'model_file': model_file, | |||||
| 'shape': shape if shape else [], | |||||
| 'outfile_dir': out_dir, | |||||
| 'report_dir': report if report else out_dir | |||||
| } | |||||
| if project_path: | |||||
| paths = sys.path | |||||
| if project_path not in paths: | |||||
| sys.path.append(project_path) | |||||
| main_graph_base_converter(file_config) | |||||
| else: | else: | ||||
| for root_dir, _, files in os.walk(in_files): | |||||
| for file in files: | |||||
| files_config['in_files'].append(os.path.join(root_dir, file)) | |||||
| main(files_config) | |||||
| error_msg = "`--in_files` and `--model_file` should be set at least one." | |||||
| error = FileNotFoundError(error_msg) | |||||
| log.error(str(error)) | |||||
| log.exception(error) | |||||
| raise error | |||||
| @@ -18,6 +18,7 @@ import argparse | |||||
| from importlib.util import find_spec | from importlib.util import find_spec | ||||
| import mindinsight | import mindinsight | ||||
| from mindinsight.mindconverter.common.log import logger as log | |||||
| from .mapper import ONNXToMindSporeMapper | from .mapper import ONNXToMindSporeMapper | ||||
| permissions = os.R_OK | os.W_OK | os.X_OK | permissions = os.R_OK | os.W_OK | os.X_OK | ||||
| @@ -57,9 +58,12 @@ def torch_installation_validation(func): | |||||
| checkpoint_path: str = None): | checkpoint_path: str = None): | ||||
| # Check whether pytorch is installed. | # Check whether pytorch is installed. | ||||
| if not find_spec("torch"): | if not find_spec("torch"): | ||||
| raise ModuleNotFoundError("PyTorch is required when using graph based " | |||||
| "scripts converter, and PyTorch vision must " | |||||
| "be consisted with model generation runtime.") | |||||
| error = ModuleNotFoundError("PyTorch is required when using graph based " | |||||
| "scripts converter, and PyTorch vision must " | |||||
| "be consisted with model generation runtime.") | |||||
| log.error(str(error)) | |||||
| log.exception(error) | |||||
| raise error | |||||
| func(graph_path=graph_path, sample_shape=sample_shape, | func(graph_path=graph_path, sample_shape=sample_shape, | ||||
| output_folder=output_folder, report_folder=report_folder, | output_folder=output_folder, report_folder=report_folder, | ||||
| @@ -93,10 +97,14 @@ def graph_based_converter(graph_path: str, sample_shape: tuple, | |||||
| report_folder=report_folder) | report_folder=report_folder) | ||||
| if __name__ == '__main__': | |||||
| args, _ = parser.parse_known_args() | |||||
| graph_based_converter(graph_path=args.graph, | |||||
| sample_shape=args.sample_shape, | |||||
| output_folder=args.output, | |||||
| report_folder=args.report, | |||||
| checkpoint_path=args.ckpt) | |||||
| def main_graph_base_converter(file_config): | |||||
| """ | |||||
| The entrance for converter, script files will be converted. | |||||
| Args: | |||||
| file_config (dict): The config of file which to convert. | |||||
| """ | |||||
| graph_based_converter(graph_path=file_config['model_file'], | |||||
| sample_shape=file_config['shape'], | |||||
| output_folder=file_config['outfile_dir'], | |||||
| report_folder=file_config['report_dir']) | |||||
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================== | # ============================================================================== | ||||
| """Define hierarchical tree.""" | """Define hierarchical tree.""" | ||||
| import os | import os | ||||
| import stat | |||||
| from copy import deepcopy | from copy import deepcopy | ||||
| from typing import NoReturn, Union | from typing import NoReturn, Union | ||||
| from queue import Queue | from queue import Queue | ||||
| @@ -21,6 +22,8 @@ from queue import Queue | |||||
| from yapf.yapflib.yapf_api import FormatCode | from yapf.yapflib.yapf_api import FormatCode | ||||
| from treelib import Tree, Node | from treelib import Tree, Node | ||||
| from mindinsight.mindconverter.common.log import logger as log | |||||
| from .name_mgr import ModuleNameMgr, GlobalVarNameMgr | from .name_mgr import ModuleNameMgr, GlobalVarNameMgr | ||||
| from ..mapper.base import Mapper | from ..mapper.base import Mapper | ||||
| from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode | from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode | ||||
| @@ -34,6 +37,10 @@ GLOBAL_VAR_NAME_MGR = GlobalVarNameMgr() | |||||
| class HierarchicalTree(Tree): | class HierarchicalTree(Tree): | ||||
| """Define hierarchical tree.""" | """Define hierarchical tree.""" | ||||
| flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL | |||||
| modes = stat.S_IRUSR | stat.S_IWUSR | |||||
| modes_usr = stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR | |||||
| _root_created = False | _root_created = False | ||||
| ROOT_LEVEL = 0 | ROOT_LEVEL = 0 | ||||
| @@ -162,19 +169,31 @@ class HierarchicalTree(Tree): | |||||
| report_folder = os.path.abspath(report_folder) | report_folder = os.path.abspath(report_folder) | ||||
| if not os.path.exists(out_folder): | if not os.path.exists(out_folder): | ||||
| os.makedirs(out_folder) | |||||
| os.makedirs(out_folder, self.modes_usr) | |||||
| if not os.path.exists(report_folder): | if not os.path.exists(report_folder): | ||||
| os.makedirs(report_folder) | |||||
| os.makedirs(report_folder, self.modes_usr) | |||||
| for file_name in code_fragments: | for file_name in code_fragments: | ||||
| code, report = code_fragments[file_name] | code, report = code_fragments[file_name] | ||||
| with open(os.path.join(os.path.abspath(out_folder), | |||||
| f"{file_name}.py"), "w") as file: | |||||
| file.write(code) | |||||
| with open(os.path.join(report_folder, | |||||
| f"report_of_{file_name}.txt"), "w") as rpt_f: | |||||
| rpt_f.write(report) | |||||
| try: | |||||
| with os.fdopen( | |||||
| os.open(os.path.join(os.path.abspath(out_folder), f"{file_name}.py"), | |||||
| self.flags, self.modes), 'w') as file: | |||||
| file.write(code) | |||||
| except IOError as error: | |||||
| log.error(str(error)) | |||||
| log.exception(error) | |||||
| raise error | |||||
| try: | |||||
| with os.fdopen( | |||||
| os.open(os.path.join(report_folder, f"report_of_{file_name}.txt"), | |||||
| self.flags, stat.S_IRUSR), "w") as rpt_f: | |||||
| rpt_f.write(report) | |||||
| except IOError as error: | |||||
| log.error(str(error)) | |||||
| log.exception(error) | |||||
| raise error | |||||
| def _preprocess_node_args(self, node, module_key): | def _preprocess_node_args(self, node, module_key): | ||||
| """ | """ | ||||
| @@ -625,7 +644,6 @@ class HierarchicalTree(Tree): | |||||
| nd_inst = self.get_node(successor_name) | nd_inst = self.get_node(successor_name) | ||||
| # Generate variable name here, then | # Generate variable name here, then | ||||
| # to generate args. | # to generate args. | ||||
| # if nd_inst.data.node_type == NodeType.OPERATION.value: | |||||
| if created: | if created: | ||||
| nd_inst.data.variable_name = self._module_vars[module_key][idx] | nd_inst.data.variable_name = self._module_vars[module_key][idx] | ||||
| else: | else: | ||||
| @@ -16,6 +16,7 @@ | |||||
| import abc | import abc | ||||
| from collections import OrderedDict | from collections import OrderedDict | ||||
| from mindinsight.mindconverter.common.log import logger as log | |||||
| from ..constant import SEPARATOR_IN_ONNX_OP | from ..constant import SEPARATOR_IN_ONNX_OP | ||||
| from ..mapper.base import Mapper | from ..mapper.base import Mapper | ||||
| @@ -66,8 +67,11 @@ class BaseGraph(metaclass=abc.ABCMeta): | |||||
| """Control the create action of graph.""" | """Control the create action of graph.""" | ||||
| model_param = args[0] if args else kwargs.get(cls._REQUIRED_PARAM_OF_MODEL) | model_param = args[0] if args else kwargs.get(cls._REQUIRED_PARAM_OF_MODEL) | ||||
| if not model_param: | if not model_param: | ||||
| raise ValueError(f"`{cls._REQUIRED_PARAM_OF_MODEL}` " | |||||
| f"can not be None.") | |||||
| error = ValueError(f"`{cls._REQUIRED_PARAM_OF_MODEL}` " | |||||
| f"can not be None.") | |||||
| log.error(str(error)) | |||||
| log.exception(error) | |||||
| raise error | |||||
| return super(BaseGraph, cls).__new__(cls) | return super(BaseGraph, cls).__new__(cls) | ||||
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================== | # ============================================================================== | ||||
| """Third party graph parser.""" | """Third party graph parser.""" | ||||
| import os | import os | ||||
| from mindinsight.mindconverter.common.log import logger as log | |||||
| from .base import GraphParser | from .base import GraphParser | ||||
| @@ -34,12 +35,24 @@ class PyTorchGraphParser(GraphParser): | |||||
| import torch | import torch | ||||
| if not os.path.exists(model_path): | if not os.path.exists(model_path): | ||||
| raise FileNotFoundError("`model_path` must be assigned with " | |||||
| "an existed file path.") | |||||
| if torch.cuda.is_available(): | |||||
| model = torch.load(f=model_path) | |||||
| else: | |||||
| model = torch.load(f=model_path, map_location="cpu") | |||||
| error = FileNotFoundError("`model_path` must be assigned with " | |||||
| "an existed file path.") | |||||
| log.error(str(error)) | |||||
| log.exception(error) | |||||
| raise error | |||||
| try: | |||||
| if torch.cuda.is_available(): | |||||
| model = torch.load(f=model_path) | |||||
| else: | |||||
| model = torch.load(f=model_path, map_location="cpu") | |||||
| except ModuleNotFoundError: | |||||
| error_msg = \ | |||||
| "Cannot find model scripts in system path, " \ | |||||
| "set `--project_path` to the path of model scripts folder correctly." | |||||
| error = ModuleNotFoundError(error_msg) | |||||
| log.error(str(error)) | |||||
| log.exception(error) | |||||
| raise error | |||||
| return model | return model | ||||
| @@ -17,6 +17,7 @@ import warnings | |||||
| import re | import re | ||||
| from typing import Dict, NoReturn | from typing import Dict, NoReturn | ||||
| from mindinsight.mindconverter.common.log import logger as log | |||||
| from .base import Graph | from .base import Graph | ||||
| from .input_node import InputNode | from .input_node import InputNode | ||||
| from .pytorch_graph_node import PyTorchGraphNode | from .pytorch_graph_node import PyTorchGraphNode | ||||
| @@ -89,12 +90,18 @@ class PyTorchGraph(Graph): | |||||
| """ | """ | ||||
| if not input_shape: | if not input_shape: | ||||
| raise ValueError("`input_shape` can not be None.") | |||||
| error = ValueError("`input_shape` can not be None.") | |||||
| log.error(str(error)) | |||||
| log.exception(error) | |||||
| raise error | |||||
| for item in input_shape: | for item in input_shape: | ||||
| if not isinstance(item, int): | if not isinstance(item, int): | ||||
| raise ValueError(f"Only support model with one input now, " | |||||
| f"and each shape value in `input_shape` should be int.") | |||||
| error = ValueError(f"Only support model with one input now, " | |||||
| f"and each shape value in `input_shape` should be int.") | |||||
| log.error(str(error)) | |||||
| log.exception(error) | |||||
| raise error | |||||
| def build(self, input_shape): | def build(self, input_shape): | ||||
| """ | """ | ||||
| @@ -122,9 +129,11 @@ class PyTorchGraph(Graph): | |||||
| Returns: | Returns: | ||||
| list, shape. | list, shape. | ||||
| """ | """ | ||||
| pattern = re.compile(r"\d+:\d*") | |||||
| if not pattern.findall(shape): | |||||
| if "," not in shape: | |||||
| return [] | return [] | ||||
| for s in shape.split(","): | |||||
| if not s: | |||||
| return [] | |||||
| return [int(x.split(":")[0].replace("!", "")) for x in shape.split(',')] | return [int(x.split(":")[0].replace("!", "")) for x in shape.split(',')] | ||||
| feed_forward_ipt_shape = (1, *input_shape) | feed_forward_ipt_shape = (1, *input_shape) | ||||
| @@ -133,10 +142,15 @@ class PyTorchGraph(Graph): | |||||
| # Assign execution mode to eval. | # Assign execution mode to eval. | ||||
| self.model.eval() | self.model.eval() | ||||
| with OverloadTorchModuleTemporarily() as _: | |||||
| # In pytorch higher version, trace function has a known. | |||||
| graph = onnx_tracer(self.model, batched_sample, | |||||
| OperatorExportTypes.ONNX) | |||||
| try: | |||||
| with OverloadTorchModuleTemporarily() as _: | |||||
| # In pytorch higher version, trace function has a known. | |||||
| graph = onnx_tracer(self.model, batched_sample, | |||||
| OperatorExportTypes.ONNX) | |||||
| except RuntimeError as error: | |||||
| log.error(str(error)) | |||||
| log.exception(error) | |||||
| raise error | |||||
| nodes = list(graph.nodes()) | nodes = list(graph.nodes()) | ||||
| @@ -190,6 +204,37 @@ class PyTorchGraph(Graph): | |||||
| """ | """ | ||||
| raise NotImplementedError() | raise NotImplementedError() | ||||
| def to_hierarchical_tree(self): | |||||
| """ | |||||
| Generate hierarchical tree based on graph. | |||||
| """ | |||||
| from ..hierarchical_tree import HierarchicalTree | |||||
| tree = HierarchicalTree() | |||||
| node_input = None | |||||
| for _, node_name in enumerate(self.nodes_in_topological_order): | |||||
| node_inst = self.get_node(node_name) | |||||
| node_output = self._shape_dict.get(node_name) | |||||
| if node_inst.in_degree == 0: | |||||
| # If in-degree equals to zero, then it's a input node. | |||||
| continue | |||||
| # If the node is on the top, then fetch its input | |||||
| # from input table. | |||||
| if not node_input: | |||||
| node_input = self._input_shape.get(node_name) | |||||
| if not node_input: | |||||
| error = ValueError(f"This model is not supported now. " | |||||
| f"Cannot find {node_name}'s input shape.") | |||||
| log.error(str(error)) | |||||
| log.exception(error) | |||||
| raise error | |||||
| tree.insert(node_inst, node_name, node_input, node_output) | |||||
| node_input = node_output | |||||
| return tree | |||||
| def build_connection(self, src, tgt) -> NoReturn: | def build_connection(self, src, tgt) -> NoReturn: | ||||
| """ | """ | ||||
| Build connection between source node and target node. | Build connection between source node and target node. | ||||
| @@ -229,8 +274,11 @@ class PyTorchGraph(Graph): | |||||
| """ | """ | ||||
| Load graph metadata. | Load graph metadata. | ||||
| """ | """ | ||||
| raise NotImplementedError("class `PyTorchGraph` has not implemented " | |||||
| "`load_metadata()`.") | |||||
| error = NotImplementedError("class `PyTorchGraph` has not implemented " | |||||
| "`load_metadata()`.") | |||||
| log.error(str(error)) | |||||
| log.exception(error) | |||||
| raise error | |||||
| @staticmethod | @staticmethod | ||||
| def load_graph(graph_path: str): | def load_graph(graph_path: str): | ||||
| @@ -116,8 +116,6 @@ class PyTorchGraphNode(GraphNode): | |||||
| """ | """ | ||||
| if not self._module_name_frozen: | if not self._module_name_frozen: | ||||
| module_name = self.tag | module_name = self.tag | ||||
| # if self._node_type == NodeType.CLASS.value: | |||||
| # module_name = f"{module_name[0].upper()}{module_name[1:]}" | |||||
| return module_name | return module_name | ||||
| return self._module_name | return self._module_name | ||||