Merge pull request !781 from moran/network_validationtags/v1.1.0
| @@ -19,6 +19,7 @@ import argparse | |||||
| import mindinsight | import mindinsight | ||||
| from mindinsight.mindconverter.converter import main | from mindinsight.mindconverter.converter import main | ||||
| from mindinsight.mindconverter.graph_based_converter.constant import ARGUMENT_LENGTH_LIMIT, EXPECTED_SHAPE_NUMBER | |||||
| from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter | from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter | ||||
| from mindinsight.mindconverter.common.log import logger as log | from mindinsight.mindconverter.common.log import logger as log | ||||
| @@ -38,6 +39,11 @@ class FileDirAction(argparse.Action): | |||||
| option_string (str): Optional string for specific argument name. Default: None. | option_string (str): Optional string for specific argument name. Default: None. | ||||
| """ | """ | ||||
| outfile = values | outfile = values | ||||
| if len(outfile) > ARGUMENT_LENGTH_LIMIT: | |||||
| parser_in.error( | |||||
| f"The length of {option_string}{outfile} should be no more than {ARGUMENT_LENGTH_LIMIT}.") | |||||
| if outfile.startswith('~'): | if outfile.startswith('~'): | ||||
| outfile = os.path.realpath(os.path.expanduser(outfile)) | outfile = os.path.realpath(os.path.expanduser(outfile)) | ||||
| @@ -160,9 +166,6 @@ class ModelFileAction(argparse.Action): | |||||
| if not os.path.isfile(outfile_dir): | if not os.path.isfile(outfile_dir): | ||||
| parser_in.error(f'{option_string} {outfile_dir} is not a file') | parser_in.error(f'{option_string} {outfile_dir} is not a file') | ||||
| if not outfile_dir.endswith('.pth'): | |||||
| parser_in.error(f"{option_string} {outfile_dir} should be a Pytorch model, ending with '.pth'.") | |||||
| setattr(namespace, self.dest, outfile_dir) | setattr(namespace, self.dest, outfile_dir) | ||||
| @@ -200,14 +203,42 @@ class ShapeAction(argparse.Action): | |||||
| """ | """ | ||||
| in_shape = None | in_shape = None | ||||
| shape_str = values | shape_str = values | ||||
| shape_list = shape_str.split(';') | |||||
| if not len(shape_list) == EXPECTED_SHAPE_NUMBER: | |||||
| parser_in.error(f"Only support one shape now, but get {len(shape_list)}.") | |||||
| try: | try: | ||||
| in_shape = [int(num_shape) for num_shape in shape_str.split(',')] | |||||
| in_shape = [int(num_shape) for num_shape in shape_list[0].split(',')] | |||||
| except ValueError: | except ValueError: | ||||
| parser_in.error( | parser_in.error( | ||||
| f"{option_string} {shape_str} should be a list of integer split by ',', check it please.") | f"{option_string} {shape_str} should be a list of integer split by ',', check it please.") | ||||
| setattr(namespace, self.dest, in_shape) | setattr(namespace, self.dest, in_shape) | ||||
| class NodeAction(argparse.Action): | |||||
| """Node action class definition.""" | |||||
| def __call__(self, parser_in, namespace, values, option_string=None): | |||||
| """ | |||||
| Inherited __call__ method from FileDirAction. | |||||
| Args: | |||||
| parser_in (ArgumentParser): Passed-in argument parser. | |||||
| namespace (Namespace): Namespace object to hold arguments. | |||||
| values (object): Argument values with type depending on argument definition. | |||||
| option_string (str): Optional string for specific argument name. Default: None. | |||||
| """ | |||||
| node_str = values | |||||
| if len(node_str) > ARGUMENT_LENGTH_LIMIT: | |||||
| parser_in.error( | |||||
| f"The length of {option_string}{node_str} should be no more than {ARGUMENT_LENGTH_LIMIT}." | |||||
| ) | |||||
| setattr(namespace, self.dest, node_str) | |||||
| parser = argparse.ArgumentParser( | parser = argparse.ArgumentParser( | ||||
| prog='mindconverter', | prog='mindconverter', | ||||
| description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__)) | description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__)) | ||||
| @@ -234,7 +265,7 @@ parser.add_argument( | |||||
| action=ModelFileAction, | action=ModelFileAction, | ||||
| required=False, | required=False, | ||||
| help=""" | help=""" | ||||
| PyTorch .pth model file path to use graph | |||||
| PyTorch .pth or Tensorflow .pb model file path to use graph | |||||
| based schema to do script generation. When | based schema to do script generation. When | ||||
| `--in_file` and `--model_file` are both provided, | `--in_file` and `--model_file` are both provided, | ||||
| use AST schema as default. | use AST schema as default. | ||||
| @@ -250,7 +281,29 @@ parser.add_argument( | |||||
| Optional, expected input tensor shape of | Optional, expected input tensor shape of | ||||
| `--model_file`. It's required when use graph based | `--model_file`. It's required when use graph based | ||||
| schema. | schema. | ||||
| Usage: --shape 3,244,244 | |||||
| Usage: --shape 1,3,244,244 | |||||
| """) | |||||
| parser.add_argument( | |||||
| '--input_nodes', | |||||
| type=str, | |||||
| action=NodeAction, | |||||
| default=None, | |||||
| required=False, | |||||
| help=""" | |||||
| Optional, input node(s) name of `--model_file`. It's required when use Tensorflow model. | |||||
| Usage: --input_nodes input_1:0,input_2:0 | |||||
| """) | |||||
| parser.add_argument( | |||||
| '--output_nodes', | |||||
| type=str, | |||||
| action=NodeAction, | |||||
| default=None, | |||||
| required=False, | |||||
| help=""" | |||||
| Optional, output node(s) name of `--model_file`. It's required when use Tensorflow model. | |||||
| Usage: --output_nodes output_1:0,output_2:0 | |||||
| """) | """) | ||||
| parser.add_argument( | parser.add_argument( | ||||
| @@ -305,10 +358,14 @@ def cli_entry(): | |||||
| if args.report is None: | if args.report is None: | ||||
| args.report = args.output | args.report = args.output | ||||
| os.makedirs(args.report, mode=mode, exist_ok=True) | os.makedirs(args.report, mode=mode, exist_ok=True) | ||||
| _run(args.in_file, args.model_file, args.shape, args.output, args.report, args.project_path) | |||||
| _run(args.in_file, args.model_file, | |||||
| args.shape, | |||||
| args.input_nodes, args.output_nodes, | |||||
| args.output, args.report, | |||||
| args.project_path) | |||||
| def _run(in_files, model_file, shape, out_dir, report, project_path): | |||||
| def _run(in_files, model_file, shape, input_nodes, output_nodes, out_dir, report, project_path): | |||||
| """ | """ | ||||
| Run converter command. | Run converter command. | ||||
| @@ -316,6 +373,8 @@ def _run(in_files, model_file, shape, out_dir, report, project_path): | |||||
| in_files (str): The file path or directory to convert. | in_files (str): The file path or directory to convert. | ||||
| model_file(str): The pytorch .pth to convert on graph based schema. | model_file(str): The pytorch .pth to convert on graph based schema. | ||||
| shape(list): The input tensor shape of module_file. | shape(list): The input tensor shape of module_file. | ||||
| input_nodes(str): The input node(s) name of Tensorflow model, split by ','. | |||||
| output_nodes(str): The output node(s) name of Tensorflow model, split by ','. | |||||
| out_dir (str): The output directory to save converted file. | out_dir (str): The output directory to save converted file. | ||||
| report (str): The report file path. | report (str): The report file path. | ||||
| project_path(str): Pytorch scripts project path. | project_path(str): Pytorch scripts project path. | ||||
| @@ -341,6 +400,8 @@ def _run(in_files, model_file, shape, out_dir, report, project_path): | |||||
| file_config = { | file_config = { | ||||
| 'model_file': model_file, | 'model_file': model_file, | ||||
| 'shape': shape if shape else [], | 'shape': shape if shape else [], | ||||
| 'input_nodes': input_nodes, | |||||
| 'output_nodes': output_nodes, | |||||
| 'outfile_dir': out_dir, | 'outfile_dir': out_dir, | ||||
| 'report_dir': report if report else out_dir | 'report_dir': report if report else out_dir | ||||
| } | } | ||||
| @@ -26,6 +26,7 @@ class ConverterErrors(ScriptConverterErrors): | |||||
| NODE_TYPE_NOT_SUPPORT = 2 | NODE_TYPE_NOT_SUPPORT = 2 | ||||
| CODE_SYNTAX_ERROR = 3 | CODE_SYNTAX_ERROR = 3 | ||||
| NODE_INPUT_TYPE_NOT_SUPPORT = 4 | NODE_INPUT_TYPE_NOT_SUPPORT = 4 | ||||
| UNKNOWN_MODEL = 5 | |||||
| class ScriptNotSupport(MindInsightException): | class ScriptNotSupport(MindInsightException): | ||||
| @@ -62,3 +63,12 @@ class NodeInputTypeNotSupport(MindInsightException): | |||||
| super(NodeInputTypeNotSupport, self).__init__(ConverterErrors.NODE_INPUT_TYPE_NOT_SUPPORT, | super(NodeInputTypeNotSupport, self).__init__(ConverterErrors.NODE_INPUT_TYPE_NOT_SUPPORT, | ||||
| msg, | msg, | ||||
| http_code=400) | http_code=400) | ||||
| class UnknownModel(MindInsightException): | |||||
| """The unknown model error.""" | |||||
| def __init__(self, msg): | |||||
| super(UnknownModel, self).__init__(ConverterErrors.UNKNOWN_MODEL, | |||||
| msg, | |||||
| http_code=400) | |||||
| @@ -13,6 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Graph based scripts converter definition.""" | """Graph based scripts converter definition.""" | ||||
| from .framework import graph_based_converter | |||||
| from .framework import graph_based_converter_pytorch_to_ms | |||||
| from .framework import graph_based_converter_tf_to_ms | |||||
| __all__ = ["graph_based_converter"] | |||||
| __all__ = ["graph_based_converter_pytorch_to_ms", "graph_based_converter_tf_to_ms"] | |||||
| @@ -35,6 +35,15 @@ ONNX_TYPE_FLOAT = 1 | |||||
| ONNX_TYPE_FLOATS = 6 | ONNX_TYPE_FLOATS = 6 | ||||
| ONNX_TYPE_STRING = 3 | ONNX_TYPE_STRING = 3 | ||||
| BINARY_HEADER_PYTORCH_FILE = \ | |||||
| b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03.\x80\x02}q\x00(X\x10\x00\x00\x00' | |||||
| BINARY_HEADER_PYTORCH_BITS = 32 | |||||
| ARGUMENT_LENGTH_LIMIT = 512 | |||||
| EXPECTED_SHAPE_NUMBER = 1 | |||||
| @unique | @unique | ||||
| class CodeFormatConfig(Enum): | class CodeFormatConfig(Enum): | ||||
| @@ -54,3 +63,9 @@ class NodeType(Enum): | |||||
| class InputType(Enum): | class InputType(Enum): | ||||
| TENSOR = "tensor" | TENSOR = "tensor" | ||||
| LIST = "list" | LIST = "list" | ||||
| @unique | |||||
| class FrameworkType(Enum): | |||||
| PYTORCH = 0 | |||||
| TENSORFLOW = 1 | |||||
| @@ -16,12 +16,16 @@ | |||||
| import os | import os | ||||
| import re | import re | ||||
| import argparse | import argparse | ||||
| from importlib import import_module | |||||
| from importlib.util import find_spec | from importlib.util import find_spec | ||||
| import mindinsight | import mindinsight | ||||
| from mindinsight.mindconverter.common.log import logger as log | from mindinsight.mindconverter.common.log import logger as log | ||||
| from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \ | |||||
| BINARY_HEADER_PYTORCH_BITS | |||||
| from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper | from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper | ||||
| from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport | |||||
| from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport, UnknownModel | |||||
| from mindinsight.utils.exceptions import ParamMissError | |||||
| permissions = os.R_OK | os.W_OK | os.X_OK | permissions = os.R_OK | os.W_OK | os.X_OK | ||||
| os.umask(permissions << 3 | permissions) | os.umask(permissions << 3 | permissions) | ||||
| @@ -56,8 +60,7 @@ def torch_installation_validation(func): | |||||
| """ | """ | ||||
| def _f(graph_path: str, sample_shape: tuple, | def _f(graph_path: str, sample_shape: tuple, | ||||
| output_folder: str, report_folder: str = None, | |||||
| checkpoint_path: str = None): | |||||
| output_folder: str, report_folder: str = None): | |||||
| # Check whether pytorch is installed. | # Check whether pytorch is installed. | ||||
| if not find_spec("torch"): | if not find_spec("torch"): | ||||
| error = ModuleNotFoundError("PyTorch is required when using graph based " | error = ModuleNotFoundError("PyTorch is required when using graph based " | ||||
| @@ -67,9 +70,35 @@ def torch_installation_validation(func): | |||||
| log.exception(error) | log.exception(error) | ||||
| raise error | raise error | ||||
| func(graph_path=graph_path, sample_shape=sample_shape, | |||||
| output_folder=output_folder, report_folder=report_folder) | |||||
| return _f | |||||
| def tf_installation_validation(func): | |||||
| """ | |||||
| Validate args of func. | |||||
| Args: | |||||
| func(type): Function. | |||||
| Returns: | |||||
| type, inner function. | |||||
| """ | |||||
| def _f(graph_path: str, sample_shape: tuple, | |||||
| output_folder: str, report_folder: str = None, | |||||
| input_nodes: str = None, output_nodes: str = None): | |||||
| # Check whether tensorflow is installed. | |||||
| if not find_spec("tensorflow") or not find_spec("tf2onnx"): | |||||
| error = ModuleNotFoundError("Tensorflow and tf2onnx are required when using " | |||||
| "graph based scripts converter.") | |||||
| log.error(str(error)) | |||||
| raise error | |||||
| func(graph_path=graph_path, sample_shape=sample_shape, | func(graph_path=graph_path, sample_shape=sample_shape, | ||||
| output_folder=output_folder, report_folder=report_folder, | output_folder=output_folder, report_folder=report_folder, | ||||
| checkpoint_path=checkpoint_path) | |||||
| input_nodes=input_nodes, output_nodes=output_nodes) | |||||
| return _f | return _f | ||||
| @@ -85,32 +114,33 @@ def _extract_model_name(model_path): | |||||
| str: Name of Converted model. | str: Name of Converted model. | ||||
| """ | """ | ||||
| model_name = re.findall(r".*[/](.*).pth", model_path)[-1] | |||||
| model_name = re.findall(r".*[/](.*)(?:\.pth|\.pb)", model_path)[-1] | |||||
| return model_name | return model_name | ||||
| @torch_installation_validation | @torch_installation_validation | ||||
| def graph_based_converter(graph_path: str, sample_shape: tuple, | |||||
| output_folder: str, report_folder: str = None, | |||||
| checkpoint_path: str = None): | |||||
| def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, | |||||
| output_folder: str, report_folder: str = None): | |||||
| """ | """ | ||||
| Graph based scripts converter. | |||||
| Pytoch to MindSpore based on Graph. | |||||
| Args: | Args: | ||||
| graph_path (str): Graph file path. | graph_path (str): Graph file path. | ||||
| sample_shape (tuple): Input shape of the model. | sample_shape (tuple): Input shape of the model. | ||||
| output_folder (str): Output folder. | output_folder (str): Output folder. | ||||
| report_folder (str): Report output folder path. | report_folder (str): Report output folder path. | ||||
| checkpoint_path (str): Checkpoint file path. | |||||
| """ | """ | ||||
| from .third_party_graph import GraphFactory | |||||
| from .hierarchical_tree import HierarchicalTreeFactory | |||||
| third_party_graph_module = import_module( | |||||
| 'mindinsight.mindconverter.graph_based_converter.third_party_graph') | |||||
| hierarchical_tree_module = import_module( | |||||
| 'mindinsight.mindconverter.graph_based_converter.hierarchical_tree') | |||||
| cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory') | |||||
| cls_hierarchical_tree_factory = getattr(hierarchical_tree_module, 'HierarchicalTreeFactory') | |||||
| graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape, | |||||
| checkpoint=checkpoint_path) | |||||
| graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape) | |||||
| try: | try: | ||||
| hierarchical_tree = HierarchicalTreeFactory.create(graph_obj) | |||||
| hierarchical_tree = cls_hierarchical_tree_factory.create(graph_obj) | |||||
| except Exception as e: | except Exception as e: | ||||
| log.exception(e) | log.exception(e) | ||||
| log.error("Error occur when create hierarchical tree.") | log.error("Error occur when create hierarchical tree.") | ||||
| @@ -123,6 +153,49 @@ def graph_based_converter(graph_path: str, sample_shape: tuple, | |||||
| report_folder=report_folder) | report_folder=report_folder) | ||||
| @tf_installation_validation | |||||
| def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple, | |||||
| input_nodes: str, output_nodes: str, | |||||
| output_folder: str, report_folder: str = None): | |||||
| """ | |||||
| Tensorflow to MindSpore based on Graph. | |||||
| Args: | |||||
| graph_path(str): Graph file path. | |||||
| sample_shape(tuple): Input shape of the model. | |||||
| input_nodes(str): Input node(s) of the model. | |||||
| output_nodes(str): Output node(s) of the model. | |||||
| output_folder(str): Output folder. | |||||
| report_folder(str): Report output folder path. | |||||
| """ | |||||
| third_party_graph_module = import_module( | |||||
| 'mindinsight.mindconverter.graph_based_converter.third_party_graph') | |||||
| hierarchical_tree_module = import_module( | |||||
| 'mindinsight.mindconverter.graph_based_converter.hierarchical_tree') | |||||
| cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory') | |||||
| cls_hierarchical_tree_factory = getattr(hierarchical_tree_module, 'HierarchicalTreeFactory') | |||||
| # Close unnecessary log. | |||||
| os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' | |||||
| graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape, | |||||
| input_nodes=input_nodes, output_nodes=output_nodes) | |||||
| try: | |||||
| hierarchical_tree, scope_name_map = cls_hierarchical_tree_factory.create(graph_obj) | |||||
| except Exception as e: | |||||
| log.exception(e) | |||||
| log.error("Error occur when create hierarchical tree.") | |||||
| raise NodeTypeNotSupport("This model is not supported now.") | |||||
| model_name = _extract_model_name(graph_path) | |||||
| hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper, | |||||
| model_name=model_name, | |||||
| report_folder=report_folder, | |||||
| scope_name_map=scope_name_map) | |||||
| def main_graph_base_converter(file_config): | def main_graph_base_converter(file_config): | ||||
| """ | """ | ||||
| The entrance for converter, script files will be converted. | The entrance for converter, script files will be converted. | ||||
| @@ -130,7 +203,54 @@ def main_graph_base_converter(file_config): | |||||
| Args: | Args: | ||||
| file_config (dict): The config of file which to convert. | file_config (dict): The config of file which to convert. | ||||
| """ | """ | ||||
| graph_based_converter(graph_path=file_config['model_file'], | |||||
| sample_shape=file_config['shape'], | |||||
| output_folder=file_config['outfile_dir'], | |||||
| report_folder=file_config['report_dir']) | |||||
| graph_path = file_config['model_file'] | |||||
| frame_type = get_framework_type(graph_path) | |||||
| if frame_type == FrameworkType.PYTORCH.value: | |||||
| graph_based_converter_pytorch_to_ms(graph_path=graph_path, | |||||
| sample_shape=file_config['shape'], | |||||
| output_folder=file_config['outfile_dir'], | |||||
| report_folder=file_config['report_dir']) | |||||
| elif frame_type == FrameworkType.TENSORFLOW.value: | |||||
| check_params = ['input_nodes', 'output_nodes'] | |||||
| check_params_exist(check_params, file_config) | |||||
| graph_based_converter_tf_to_ms(graph_path=graph_path, | |||||
| sample_shape=file_config['shape'], | |||||
| input_nodes=file_config['input_nodes'], | |||||
| output_nodes=file_config['output_nodes'], | |||||
| output_folder=file_config['outfile_dir'], | |||||
| report_folder=file_config['report_dir']) | |||||
| else: | |||||
| error_msg = "Get UNSUPPORTED model." | |||||
| error = UnknownModel(error_msg) | |||||
| log.error(str(error)) | |||||
| log.exception(error) | |||||
| raise error | |||||
| def get_framework_type(model_path): | |||||
| """Get framework type.""" | |||||
| try: | |||||
| with open(model_path, 'rb') as f: | |||||
| if f.read(BINARY_HEADER_PYTORCH_BITS) == BINARY_HEADER_PYTORCH_FILE: | |||||
| framework_type = FrameworkType.PYTORCH.value | |||||
| else: | |||||
| framework_type = FrameworkType.TENSORFLOW.value | |||||
| except IOError: | |||||
| error_msg = "Get UNSUPPORTED model." | |||||
| error = UnknownModel(error_msg) | |||||
| log.error(str(error)) | |||||
| log.exception(error) | |||||
| raise error | |||||
| return framework_type | |||||
| def check_params_exist(params: list, config): | |||||
| """Check params exist.""" | |||||
| miss_param_list = '' | |||||
| for param in params: | |||||
| if not config.get(param) or not config[param]: | |||||
| miss_param_list = ', '.join((miss_param_list, param)) if miss_param_list else param | |||||
| if miss_param_list: | |||||
| raise ParamMissError(miss_param_list) | |||||
| @@ -182,6 +182,7 @@ class HierarchicalTree(Tree): | |||||
| mapper (Mapper): Mapper of third party framework and mindspore. | mapper (Mapper): Mapper of third party framework and mindspore. | ||||
| model_name(str): Name of Converted model. | model_name(str): Name of Converted model. | ||||
| out_folder (str): Output folder. | out_folder (str): Output folder. | ||||
| scope_name_map(str): Scope name map of tensorflow. | |||||
| """ | """ | ||||
| if scope_name_map: | if scope_name_map: | ||||
| @@ -117,12 +117,14 @@ class ConvMapper(ONNXToMindSporeMapper): | |||||
| @staticmethod | @staticmethod | ||||
| def _operation_name_in_ms(*args, **kwargs): | def _operation_name_in_ms(*args, **kwargs): | ||||
| if not kwargs['weights'].get('weight'): # is from tf | |||||
| weight = kwargs['weights'].get('weight', 'empty') | |||||
| if weight == 'empty': # is from tf | |||||
| kernel_size = kwargs['params'].get('kernel_shape') | kernel_size = kwargs['params'].get('kernel_shape') | ||||
| dim = len(kernel_size) | dim = len(kernel_size) | ||||
| return f"nn.Conv{dim}d" | return f"nn.Conv{dim}d" | ||||
| weight = kwargs['weights']['weight'].numpy() | |||||
| weight = weight.numpy() | |||||
| dim = weight.ndim - 2 | dim = weight.ndim - 2 | ||||
| return f"nn.Conv{dim}d" | return f"nn.Conv{dim}d" | ||||
| @@ -131,7 +133,7 @@ class ConvMapper(ONNXToMindSporeMapper): | |||||
| weights = kwargs['weights'] | weights = kwargs['weights'] | ||||
| params = kwargs['params'] | params = kwargs['params'] | ||||
| if not weights.get('weight'): # is from tf | |||||
| if weights.get('weight', 'empty') == 'empty': # is from tf | |||||
| return ConvMapper.convert_params_tf(params=params, weights=weights) | return ConvMapper.convert_params_tf(params=params, weights=weights) | ||||
| return ConvMapper.convert_params_torch(params=params, weights=weights) | return ConvMapper.convert_params_torch(params=params, weights=weights) | ||||
| @@ -25,15 +25,16 @@ class GraphFactory: | |||||
| @classmethod | @classmethod | ||||
| def init(cls, graph_path: str, | def init(cls, graph_path: str, | ||||
| input_nodes: str, output_nodes: str, | |||||
| sample_shape: tuple): | |||||
| sample_shape: tuple, | |||||
| input_nodes: str = None, output_nodes: str = None): | |||||
| """ | """ | ||||
| Init an instance of graph. | Init an instance of graph. | ||||
| Args: | Args: | ||||
| graph_path (str): Graph or model file path. | graph_path (str): Graph or model file path. | ||||
| sample_shape (tuple): Input shape of the model. | sample_shape (tuple): Input shape of the model. | ||||
| checkpoint (str): Checkpoint file path. | |||||
| input_nodes(str): Input nodes. | |||||
| output_nodes(str): Output nodes. | |||||
| Returns: | Returns: | ||||
| Graph, graph instance. | Graph, graph instance. | ||||
| @@ -246,17 +246,11 @@ class Graph(BaseGraph, abc.ABC): | |||||
| model_path (str): Graph or model file path. | model_path (str): Graph or model file path. | ||||
| sample_shape (tuple): Input shape of the model. | sample_shape (tuple): Input shape of the model. | ||||
| checkpoint (str): Checkpoint file path. | checkpoint (str): Checkpoint file path. | ||||
| input_nodes (list[str]): list of input nodes' name | |||||
| output_nodes (list[str]): list of output nodes' name | |||||
| Returns: | Returns: | ||||
| cls, graph instance. | cls, graph instance. | ||||
| """ | """ | ||||
| tf_input_nodes = kwargs.get('input_nodes') | |||||
| tf_output_nodes = kwargs.get('output_nodes') | |||||
| src_graph = cls.load_graph(graph_path=model_path, | |||||
| input_nodes=tf_input_nodes, | |||||
| output_nodes=tf_output_nodes) | |||||
| src_graph = cls.load_graph(graph_path=model_path, **kwargs) | |||||
| ckpt = cls.load_checkpoint( | ckpt = cls.load_checkpoint( | ||||
| ckpt_path=checkpoint) if checkpoint else None | ckpt_path=checkpoint) if checkpoint else None | ||||
| @@ -193,8 +193,6 @@ class OnnxGraph(Graph): | |||||
| Args: | Args: | ||||
| graph_path (str): Graph path. | graph_path (str): Graph path. | ||||
| tf_input_nodes (str): input nodes of tf graph | |||||
| tf_output_nodes (str): output nodes of tf graph | |||||
| Returns: | Returns: | ||||
| object, ONNX model. | object, ONNX model. | ||||
| @@ -18,8 +18,9 @@ from copy import deepcopy | |||||
| from .base import GraphNode | from .base import GraphNode | ||||
| from ..constant import NodeType, SEPARATOR_IN_SCOPE, \ | from ..constant import NodeType, SEPARATOR_IN_SCOPE, \ | ||||
| SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, SEPARATOR_IN_ONNX_OP | |||||
| SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, SEPARATOR_IN_ONNX_OP, InputType | |||||
| from ..mapper.base import Mapper | from ..mapper.base import Mapper | ||||
| from ...common.exceptions import NodeInputTypeNotSupport | |||||
| class OnnxGraphNode(GraphNode): | class OnnxGraphNode(GraphNode): | ||||
| @@ -160,10 +161,10 @@ class OnnxGraphNode(GraphNode): | |||||
| Args: | Args: | ||||
| op_name (str): Add the tensor to args if the current node has this | op_name (str): Add the tensor to args if the current node has this | ||||
| op_name. | op_name. | ||||
| t_identifier (str): The unique strinf appeared in the target tensor | |||||
| t_identifier (str): The unique string appeared in the target tensor | |||||
| name. | name. | ||||
| declare_s (str): Declare statement generated in to_code(). | |||||
| init_s (str): init statement generated in to_code(). | |||||
| declare (str): Declare statement generated in to_code(). | |||||
| args (str): Args statement generated in to_code(). | |||||
| Returns: | Returns: | ||||
| declare_list list, multiple declare statements. | declare_list list, multiple declare statements. | ||||
| @@ -226,9 +227,9 @@ class OnnxGraphNode(GraphNode): | |||||
| declare, ipt_args_settings_in_construct = self._add_tensor_args_to_code( | declare, ipt_args_settings_in_construct = self._add_tensor_args_to_code( | ||||
| 'onnx::MatMul', 'MatMul', declare, ipt_args_settings_in_construct) | 'onnx::MatMul', 'MatMul', declare, ipt_args_settings_in_construct) | ||||
| # Extra Tensor generator for onnx::BiasAdd | |||||
| # Extra Tensor generator for onnx::Add | |||||
| declare, ipt_args_settings_in_construct = self._add_tensor_args_to_code( | declare, ipt_args_settings_in_construct = self._add_tensor_args_to_code( | ||||
| 'onnx::MatMul', 'BiasAdd', declare, ipt_args_settings_in_construct) | |||||
| 'onnx::Add', 'BiasAdd', declare, ipt_args_settings_in_construct) | |||||
| call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_settings_in_construct})" | call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_settings_in_construct})" | ||||
| @@ -320,7 +321,7 @@ class OnnxGraphNode(GraphNode): | |||||
| def param_transform(self, mapper: Mapper): | def param_transform(self, mapper: Mapper): | ||||
| """ | """ | ||||
| Transform torch params into mindspore. | |||||
| Transform tensorflow params into mindspore. | |||||
| Args: | Args: | ||||
| mapper (Mapper): Mapper of params. | mapper (Mapper): Mapper of params. | ||||
| @@ -173,7 +173,7 @@ class PyTorchGraph(Graph): | |||||
| """ | """ | ||||
| self._check_input_shape(input_shape) | self._check_input_shape(input_shape) | ||||
| feed_forward_ipt_shape = (1, *input_shape) | |||||
| feed_forward_ipt_shape = tuple(input_shape) | |||||
| graph = self._trace_torch_graph(feed_forward_ipt_shape) | graph = self._trace_torch_graph(feed_forward_ipt_shape) | ||||
| nodes = list(graph.nodes()) | nodes = list(graph.nodes()) | ||||
| @@ -283,7 +283,7 @@ class PyTorchGraph(Graph): | |||||
| raise NotImplementedError(err_msg) | raise NotImplementedError(err_msg) | ||||
| @staticmethod | @staticmethod | ||||
| def load_graph(graph_path: str): | |||||
| def load_graph(graph_path: str, **kwargs): | |||||
| """ | """ | ||||
| Load graph. | Load graph. | ||||
| @@ -21,11 +21,13 @@ Usage: | |||||
| """ | """ | ||||
| import difflib | import difflib | ||||
| import os | import os | ||||
| import re | |||||
| import sys | import sys | ||||
| import pytest | import pytest | ||||
| from mindinsight.mindconverter.converter import main | from mindinsight.mindconverter.converter import main | ||||
| from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter | |||||
| @pytest.mark.usefixtures('create_output_dir') | @pytest.mark.usefixtures('create_output_dir') | ||||
| @@ -36,7 +38,9 @@ class TestConverter: | |||||
| def setup_class(cls): | def setup_class(cls): | ||||
| """Setup method.""" | """Setup method.""" | ||||
| cls.script_dir = os.path.join(os.path.dirname(__file__), 'data') | cls.script_dir = os.path.join(os.path.dirname(__file__), 'data') | ||||
| cls.pytorch_dir = '/home/test/mindinsight_sample' | |||||
| pytorch_base_dir = os.path.dirname(__file__).split('/')[:3] | |||||
| cls.pytorch_dir = \ | |||||
| '/'.join(pytorch_base_dir + ['share-data', 'dataset', 'mindinsight_dataset', 'resnet50']) | |||||
| sys.path.insert(0, cls.script_dir) | sys.path.insert(0, cls.script_dir) | ||||
| @classmethod | @classmethod | ||||
| @@ -78,3 +82,35 @@ class TestConverter: | |||||
| converted_ratio = 100 - (diff_lines * 100) / (len(expect_source)) | converted_ratio = 100 - (diff_lines * 100) / (len(expect_source)) | ||||
| assert converted_ratio >= 80 | assert converted_ratio >= 80 | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.env_single | |||||
| def test_main_graph_based_converter(self, output): | |||||
| """Test main graph based converter.""" | |||||
| pytorch_filename = "resnet50.pth" | |||||
| expected_model_filename = "resnet50.py" | |||||
| expected_report_filename = "report_of_resnet50.txt" | |||||
| file_config = { | |||||
| 'model_file': os.path.join(self.pytorch_dir, pytorch_filename), | |||||
| 'shape': (1, 3, 224, 224), | |||||
| 'outfile_dir': output, | |||||
| 'report_dir': output | |||||
| } | |||||
| with pytest.raises(ValueError) as e: | |||||
| main_graph_base_converter(file_config=file_config) | |||||
| assert os.path.isfile(os.path.join(output, expected_model_filename)) | |||||
| assert os.path.isfile(os.path.join(output, expected_report_filename)) | |||||
| with open(os.path.join(output, expected_report_filename)) as converted_r: | |||||
| converted_report = converted_r.readlines() | |||||
| converted_rate = re.findall(r".*(?:Converted Rate: )(.*)[.]", converted_report[-1]) | |||||
| assert converted_rate[0] == '100.00%' | |||||
| exec_msg = e.value.args[0] | |||||
| assert exec_msg == "torch.__spec__ is None" | |||||