Merge pull request !781 from moran/network_validationtags/v1.1.0
| @@ -19,6 +19,7 @@ import argparse | |||
| import mindinsight | |||
| from mindinsight.mindconverter.converter import main | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ARGUMENT_LENGTH_LIMIT, EXPECTED_SHAPE_NUMBER | |||
| from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter | |||
| from mindinsight.mindconverter.common.log import logger as log | |||
| @@ -38,6 +39,11 @@ class FileDirAction(argparse.Action): | |||
| option_string (str): Optional string for specific argument name. Default: None. | |||
| """ | |||
| outfile = values | |||
| if len(outfile) > ARGUMENT_LENGTH_LIMIT: | |||
| parser_in.error( | |||
| f"The length of {option_string}{outfile} should be no more than {ARGUMENT_LENGTH_LIMIT}.") | |||
| if outfile.startswith('~'): | |||
| outfile = os.path.realpath(os.path.expanduser(outfile)) | |||
| @@ -160,9 +166,6 @@ class ModelFileAction(argparse.Action): | |||
| if not os.path.isfile(outfile_dir): | |||
| parser_in.error(f'{option_string} {outfile_dir} is not a file') | |||
| if not outfile_dir.endswith('.pth'): | |||
| parser_in.error(f"{option_string} {outfile_dir} should be a Pytorch model, ending with '.pth'.") | |||
| setattr(namespace, self.dest, outfile_dir) | |||
| @@ -200,14 +203,42 @@ class ShapeAction(argparse.Action): | |||
| """ | |||
| in_shape = None | |||
| shape_str = values | |||
| shape_list = shape_str.split(';') | |||
| if not len(shape_list) == EXPECTED_SHAPE_NUMBER: | |||
| parser_in.error(f"Only support one shape now, but get {len(shape_list)}.") | |||
| try: | |||
| in_shape = [int(num_shape) for num_shape in shape_str.split(',')] | |||
| in_shape = [int(num_shape) for num_shape in shape_list[0].split(',')] | |||
| except ValueError: | |||
| parser_in.error( | |||
| f"{option_string} {shape_str} should be a list of integer split by ',', check it please.") | |||
| setattr(namespace, self.dest, in_shape) | |||
| class NodeAction(argparse.Action): | |||
| """Node action class definition.""" | |||
| def __call__(self, parser_in, namespace, values, option_string=None): | |||
| """ | |||
| Inherited __call__ method from FileDirAction. | |||
| Args: | |||
| parser_in (ArgumentParser): Passed-in argument parser. | |||
| namespace (Namespace): Namespace object to hold arguments. | |||
| values (object): Argument values with type depending on argument definition. | |||
| option_string (str): Optional string for specific argument name. Default: None. | |||
| """ | |||
| node_str = values | |||
| if len(node_str) > ARGUMENT_LENGTH_LIMIT: | |||
| parser_in.error( | |||
| f"The length of {option_string}{node_str} should be no more than {ARGUMENT_LENGTH_LIMIT}." | |||
| ) | |||
| setattr(namespace, self.dest, node_str) | |||
| parser = argparse.ArgumentParser( | |||
| prog='mindconverter', | |||
| description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__)) | |||
| @@ -234,7 +265,7 @@ parser.add_argument( | |||
| action=ModelFileAction, | |||
| required=False, | |||
| help=""" | |||
| PyTorch .pth model file path to use graph | |||
| PyTorch .pth or Tensorflow .pb model file path to use graph | |||
| based schema to do script generation. When | |||
| `--in_file` and `--model_file` are both provided, | |||
| use AST schema as default. | |||
| @@ -250,7 +281,29 @@ parser.add_argument( | |||
| Optional, expected input tensor shape of | |||
| `--model_file`. It's required when use graph based | |||
| schema. | |||
| Usage: --shape 3,244,244 | |||
| Usage: --shape 1,3,244,244 | |||
| """) | |||
| parser.add_argument( | |||
| '--input_nodes', | |||
| type=str, | |||
| action=NodeAction, | |||
| default=None, | |||
| required=False, | |||
| help=""" | |||
| Optional, input node(s) name of `--model_file`. It's required when use Tensorflow model. | |||
| Usage: --input_nodes input_1:0,input_2:0 | |||
| """) | |||
| parser.add_argument( | |||
| '--output_nodes', | |||
| type=str, | |||
| action=NodeAction, | |||
| default=None, | |||
| required=False, | |||
| help=""" | |||
| Optional, output node(s) name of `--model_file`. It's required when use Tensorflow model. | |||
| Usage: --output_nodes output_1:0,output_2:0 | |||
| """) | |||
| parser.add_argument( | |||
| @@ -305,10 +358,14 @@ def cli_entry(): | |||
| if args.report is None: | |||
| args.report = args.output | |||
| os.makedirs(args.report, mode=mode, exist_ok=True) | |||
| _run(args.in_file, args.model_file, args.shape, args.output, args.report, args.project_path) | |||
| _run(args.in_file, args.model_file, | |||
| args.shape, | |||
| args.input_nodes, args.output_nodes, | |||
| args.output, args.report, | |||
| args.project_path) | |||
| def _run(in_files, model_file, shape, out_dir, report, project_path): | |||
| def _run(in_files, model_file, shape, input_nodes, output_nodes, out_dir, report, project_path): | |||
| """ | |||
| Run converter command. | |||
| @@ -316,6 +373,8 @@ def _run(in_files, model_file, shape, out_dir, report, project_path): | |||
| in_files (str): The file path or directory to convert. | |||
| model_file(str): The pytorch .pth to convert on graph based schema. | |||
| shape(list): The input tensor shape of module_file. | |||
| input_nodes(str): The input node(s) name of Tensorflow model, split by ','. | |||
| output_nodes(str): The output node(s) name of Tensorflow model, split by ','. | |||
| out_dir (str): The output directory to save converted file. | |||
| report (str): The report file path. | |||
| project_path(str): Pytorch scripts project path. | |||
| @@ -341,6 +400,8 @@ def _run(in_files, model_file, shape, out_dir, report, project_path): | |||
| file_config = { | |||
| 'model_file': model_file, | |||
| 'shape': shape if shape else [], | |||
| 'input_nodes': input_nodes, | |||
| 'output_nodes': output_nodes, | |||
| 'outfile_dir': out_dir, | |||
| 'report_dir': report if report else out_dir | |||
| } | |||
| @@ -26,6 +26,7 @@ class ConverterErrors(ScriptConverterErrors): | |||
| NODE_TYPE_NOT_SUPPORT = 2 | |||
| CODE_SYNTAX_ERROR = 3 | |||
| NODE_INPUT_TYPE_NOT_SUPPORT = 4 | |||
| UNKNOWN_MODEL = 5 | |||
| class ScriptNotSupport(MindInsightException): | |||
| @@ -62,3 +63,12 @@ class NodeInputTypeNotSupport(MindInsightException): | |||
| super(NodeInputTypeNotSupport, self).__init__(ConverterErrors.NODE_INPUT_TYPE_NOT_SUPPORT, | |||
| msg, | |||
| http_code=400) | |||
| class UnknownModel(MindInsightException): | |||
| """The unknown model error.""" | |||
| def __init__(self, msg): | |||
| super(UnknownModel, self).__init__(ConverterErrors.UNKNOWN_MODEL, | |||
| msg, | |||
| http_code=400) | |||
| @@ -13,6 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Graph based scripts converter definition.""" | |||
| from .framework import graph_based_converter | |||
| from .framework import graph_based_converter_pytorch_to_ms | |||
| from .framework import graph_based_converter_tf_to_ms | |||
| __all__ = ["graph_based_converter"] | |||
| __all__ = ["graph_based_converter_pytorch_to_ms", "graph_based_converter_tf_to_ms"] | |||
| @@ -35,6 +35,15 @@ ONNX_TYPE_FLOAT = 1 | |||
| ONNX_TYPE_FLOATS = 6 | |||
| ONNX_TYPE_STRING = 3 | |||
| BINARY_HEADER_PYTORCH_FILE = \ | |||
| b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03.\x80\x02}q\x00(X\x10\x00\x00\x00' | |||
| BINARY_HEADER_PYTORCH_BITS = 32 | |||
| ARGUMENT_LENGTH_LIMIT = 512 | |||
| EXPECTED_SHAPE_NUMBER = 1 | |||
| @unique | |||
| class CodeFormatConfig(Enum): | |||
| @@ -54,3 +63,9 @@ class NodeType(Enum): | |||
| class InputType(Enum): | |||
| TENSOR = "tensor" | |||
| LIST = "list" | |||
| @unique | |||
| class FrameworkType(Enum): | |||
| PYTORCH = 0 | |||
| TENSORFLOW = 1 | |||
| @@ -16,12 +16,16 @@ | |||
| import os | |||
| import re | |||
| import argparse | |||
| from importlib import import_module | |||
| from importlib.util import find_spec | |||
| import mindinsight | |||
| from mindinsight.mindconverter.common.log import logger as log | |||
| from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \ | |||
| BINARY_HEADER_PYTORCH_BITS | |||
| from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport | |||
| from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport, UnknownModel | |||
| from mindinsight.utils.exceptions import ParamMissError | |||
| permissions = os.R_OK | os.W_OK | os.X_OK | |||
| os.umask(permissions << 3 | permissions) | |||
| @@ -56,8 +60,7 @@ def torch_installation_validation(func): | |||
| """ | |||
| def _f(graph_path: str, sample_shape: tuple, | |||
| output_folder: str, report_folder: str = None, | |||
| checkpoint_path: str = None): | |||
| output_folder: str, report_folder: str = None): | |||
| # Check whether pytorch is installed. | |||
| if not find_spec("torch"): | |||
| error = ModuleNotFoundError("PyTorch is required when using graph based " | |||
| @@ -67,9 +70,35 @@ def torch_installation_validation(func): | |||
| log.exception(error) | |||
| raise error | |||
| func(graph_path=graph_path, sample_shape=sample_shape, | |||
| output_folder=output_folder, report_folder=report_folder) | |||
| return _f | |||
| def tf_installation_validation(func): | |||
| """ | |||
| Validate args of func. | |||
| Args: | |||
| func(type): Function. | |||
| Returns: | |||
| type, inner function. | |||
| """ | |||
| def _f(graph_path: str, sample_shape: tuple, | |||
| output_folder: str, report_folder: str = None, | |||
| input_nodes: str = None, output_nodes: str = None): | |||
| # Check whether tensorflow is installed. | |||
| if not find_spec("tensorflow") or not find_spec("tf2onnx"): | |||
| error = ModuleNotFoundError("Tensorflow and tf2onnx are required when using " | |||
| "graph based scripts converter.") | |||
| log.error(str(error)) | |||
| raise error | |||
| func(graph_path=graph_path, sample_shape=sample_shape, | |||
| output_folder=output_folder, report_folder=report_folder, | |||
| checkpoint_path=checkpoint_path) | |||
| input_nodes=input_nodes, output_nodes=output_nodes) | |||
| return _f | |||
| @@ -85,32 +114,33 @@ def _extract_model_name(model_path): | |||
| str: Name of Converted model. | |||
| """ | |||
| model_name = re.findall(r".*[/](.*).pth", model_path)[-1] | |||
| model_name = re.findall(r".*[/](.*)(?:\.pth|\.pb)", model_path)[-1] | |||
| return model_name | |||
| @torch_installation_validation | |||
| def graph_based_converter(graph_path: str, sample_shape: tuple, | |||
| output_folder: str, report_folder: str = None, | |||
| checkpoint_path: str = None): | |||
| def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, | |||
| output_folder: str, report_folder: str = None): | |||
| """ | |||
| Graph based scripts converter. | |||
| Pytoch to MindSpore based on Graph. | |||
| Args: | |||
| graph_path (str): Graph file path. | |||
| sample_shape (tuple): Input shape of the model. | |||
| output_folder (str): Output folder. | |||
| report_folder (str): Report output folder path. | |||
| checkpoint_path (str): Checkpoint file path. | |||
| """ | |||
| from .third_party_graph import GraphFactory | |||
| from .hierarchical_tree import HierarchicalTreeFactory | |||
| third_party_graph_module = import_module( | |||
| 'mindinsight.mindconverter.graph_based_converter.third_party_graph') | |||
| hierarchical_tree_module = import_module( | |||
| 'mindinsight.mindconverter.graph_based_converter.hierarchical_tree') | |||
| cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory') | |||
| cls_hierarchical_tree_factory = getattr(hierarchical_tree_module, 'HierarchicalTreeFactory') | |||
| graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape, | |||
| checkpoint=checkpoint_path) | |||
| graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape) | |||
| try: | |||
| hierarchical_tree = HierarchicalTreeFactory.create(graph_obj) | |||
| hierarchical_tree = cls_hierarchical_tree_factory.create(graph_obj) | |||
| except Exception as e: | |||
| log.exception(e) | |||
| log.error("Error occur when create hierarchical tree.") | |||
| @@ -123,6 +153,49 @@ def graph_based_converter(graph_path: str, sample_shape: tuple, | |||
| report_folder=report_folder) | |||
| @tf_installation_validation | |||
| def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple, | |||
| input_nodes: str, output_nodes: str, | |||
| output_folder: str, report_folder: str = None): | |||
| """ | |||
| Tensorflow to MindSpore based on Graph. | |||
| Args: | |||
| graph_path(str): Graph file path. | |||
| sample_shape(tuple): Input shape of the model. | |||
| input_nodes(str): Input node(s) of the model. | |||
| output_nodes(str): Output node(s) of the model. | |||
| output_folder(str): Output folder. | |||
| report_folder(str): Report output folder path. | |||
| """ | |||
| third_party_graph_module = import_module( | |||
| 'mindinsight.mindconverter.graph_based_converter.third_party_graph') | |||
| hierarchical_tree_module = import_module( | |||
| 'mindinsight.mindconverter.graph_based_converter.hierarchical_tree') | |||
| cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory') | |||
| cls_hierarchical_tree_factory = getattr(hierarchical_tree_module, 'HierarchicalTreeFactory') | |||
| # Close unnecessary log. | |||
| os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' | |||
| graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape, | |||
| input_nodes=input_nodes, output_nodes=output_nodes) | |||
| try: | |||
| hierarchical_tree, scope_name_map = cls_hierarchical_tree_factory.create(graph_obj) | |||
| except Exception as e: | |||
| log.exception(e) | |||
| log.error("Error occur when create hierarchical tree.") | |||
| raise NodeTypeNotSupport("This model is not supported now.") | |||
| model_name = _extract_model_name(graph_path) | |||
| hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper, | |||
| model_name=model_name, | |||
| report_folder=report_folder, | |||
| scope_name_map=scope_name_map) | |||
| def main_graph_base_converter(file_config): | |||
| """ | |||
| The entrance for converter, script files will be converted. | |||
| @@ -130,7 +203,54 @@ def main_graph_base_converter(file_config): | |||
| Args: | |||
| file_config (dict): The config of file which to convert. | |||
| """ | |||
| graph_based_converter(graph_path=file_config['model_file'], | |||
| sample_shape=file_config['shape'], | |||
| output_folder=file_config['outfile_dir'], | |||
| report_folder=file_config['report_dir']) | |||
| graph_path = file_config['model_file'] | |||
| frame_type = get_framework_type(graph_path) | |||
| if frame_type == FrameworkType.PYTORCH.value: | |||
| graph_based_converter_pytorch_to_ms(graph_path=graph_path, | |||
| sample_shape=file_config['shape'], | |||
| output_folder=file_config['outfile_dir'], | |||
| report_folder=file_config['report_dir']) | |||
| elif frame_type == FrameworkType.TENSORFLOW.value: | |||
| check_params = ['input_nodes', 'output_nodes'] | |||
| check_params_exist(check_params, file_config) | |||
| graph_based_converter_tf_to_ms(graph_path=graph_path, | |||
| sample_shape=file_config['shape'], | |||
| input_nodes=file_config['input_nodes'], | |||
| output_nodes=file_config['output_nodes'], | |||
| output_folder=file_config['outfile_dir'], | |||
| report_folder=file_config['report_dir']) | |||
| else: | |||
| error_msg = "Get UNSUPPORTED model." | |||
| error = UnknownModel(error_msg) | |||
| log.error(str(error)) | |||
| log.exception(error) | |||
| raise error | |||
| def get_framework_type(model_path): | |||
| """Get framework type.""" | |||
| try: | |||
| with open(model_path, 'rb') as f: | |||
| if f.read(BINARY_HEADER_PYTORCH_BITS) == BINARY_HEADER_PYTORCH_FILE: | |||
| framework_type = FrameworkType.PYTORCH.value | |||
| else: | |||
| framework_type = FrameworkType.TENSORFLOW.value | |||
| except IOError: | |||
| error_msg = "Get UNSUPPORTED model." | |||
| error = UnknownModel(error_msg) | |||
| log.error(str(error)) | |||
| log.exception(error) | |||
| raise error | |||
| return framework_type | |||
| def check_params_exist(params: list, config): | |||
| """Check params exist.""" | |||
| miss_param_list = '' | |||
| for param in params: | |||
| if not config.get(param) or not config[param]: | |||
| miss_param_list = ', '.join((miss_param_list, param)) if miss_param_list else param | |||
| if miss_param_list: | |||
| raise ParamMissError(miss_param_list) | |||
| @@ -182,6 +182,7 @@ class HierarchicalTree(Tree): | |||
| mapper (Mapper): Mapper of third party framework and mindspore. | |||
| model_name(str): Name of Converted model. | |||
| out_folder (str): Output folder. | |||
| scope_name_map(str): Scope name map of tensorflow. | |||
| """ | |||
| if scope_name_map: | |||
| @@ -117,12 +117,14 @@ class ConvMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| if not kwargs['weights'].get('weight'): # is from tf | |||
| weight = kwargs['weights'].get('weight', 'empty') | |||
| if weight == 'empty': # is from tf | |||
| kernel_size = kwargs['params'].get('kernel_shape') | |||
| dim = len(kernel_size) | |||
| return f"nn.Conv{dim}d" | |||
| weight = kwargs['weights']['weight'].numpy() | |||
| weight = weight.numpy() | |||
| dim = weight.ndim - 2 | |||
| return f"nn.Conv{dim}d" | |||
| @@ -131,7 +133,7 @@ class ConvMapper(ONNXToMindSporeMapper): | |||
| weights = kwargs['weights'] | |||
| params = kwargs['params'] | |||
| if not weights.get('weight'): # is from tf | |||
| if weights.get('weight', 'empty') == 'empty': # is from tf | |||
| return ConvMapper.convert_params_tf(params=params, weights=weights) | |||
| return ConvMapper.convert_params_torch(params=params, weights=weights) | |||
| @@ -25,15 +25,16 @@ class GraphFactory: | |||
| @classmethod | |||
| def init(cls, graph_path: str, | |||
| input_nodes: str, output_nodes: str, | |||
| sample_shape: tuple): | |||
| sample_shape: tuple, | |||
| input_nodes: str = None, output_nodes: str = None): | |||
| """ | |||
| Init an instance of graph. | |||
| Args: | |||
| graph_path (str): Graph or model file path. | |||
| sample_shape (tuple): Input shape of the model. | |||
| checkpoint (str): Checkpoint file path. | |||
| input_nodes(str): Input nodes. | |||
| output_nodes(str): Output nodes. | |||
| Returns: | |||
| Graph, graph instance. | |||
| @@ -246,17 +246,11 @@ class Graph(BaseGraph, abc.ABC): | |||
| model_path (str): Graph or model file path. | |||
| sample_shape (tuple): Input shape of the model. | |||
| checkpoint (str): Checkpoint file path. | |||
| input_nodes (list[str]): list of input nodes' name | |||
| output_nodes (list[str]): list of output nodes' name | |||
| Returns: | |||
| cls, graph instance. | |||
| """ | |||
| tf_input_nodes = kwargs.get('input_nodes') | |||
| tf_output_nodes = kwargs.get('output_nodes') | |||
| src_graph = cls.load_graph(graph_path=model_path, | |||
| input_nodes=tf_input_nodes, | |||
| output_nodes=tf_output_nodes) | |||
| src_graph = cls.load_graph(graph_path=model_path, **kwargs) | |||
| ckpt = cls.load_checkpoint( | |||
| ckpt_path=checkpoint) if checkpoint else None | |||
| @@ -193,8 +193,6 @@ class OnnxGraph(Graph): | |||
| Args: | |||
| graph_path (str): Graph path. | |||
| tf_input_nodes (str): input nodes of tf graph | |||
| tf_output_nodes (str): output nodes of tf graph | |||
| Returns: | |||
| object, ONNX model. | |||
| @@ -18,8 +18,9 @@ from copy import deepcopy | |||
| from .base import GraphNode | |||
| from ..constant import NodeType, SEPARATOR_IN_SCOPE, \ | |||
| SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, SEPARATOR_IN_ONNX_OP | |||
| SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, SEPARATOR_IN_ONNX_OP, InputType | |||
| from ..mapper.base import Mapper | |||
| from ...common.exceptions import NodeInputTypeNotSupport | |||
| class OnnxGraphNode(GraphNode): | |||
| @@ -160,10 +161,10 @@ class OnnxGraphNode(GraphNode): | |||
| Args: | |||
| op_name (str): Add the tensor to args if the current node has this | |||
| op_name. | |||
| t_identifier (str): The unique strinf appeared in the target tensor | |||
| t_identifier (str): The unique string appeared in the target tensor | |||
| name. | |||
| declare_s (str): Declare statement generated in to_code(). | |||
| init_s (str): init statement generated in to_code(). | |||
| declare (str): Declare statement generated in to_code(). | |||
| args (str): Args statement generated in to_code(). | |||
| Returns: | |||
| declare_list list, multiple declare statements. | |||
| @@ -226,9 +227,9 @@ class OnnxGraphNode(GraphNode): | |||
| declare, ipt_args_settings_in_construct = self._add_tensor_args_to_code( | |||
| 'onnx::MatMul', 'MatMul', declare, ipt_args_settings_in_construct) | |||
| # Extra Tensor generator for onnx::BiasAdd | |||
| # Extra Tensor generator for onnx::Add | |||
| declare, ipt_args_settings_in_construct = self._add_tensor_args_to_code( | |||
| 'onnx::MatMul', 'BiasAdd', declare, ipt_args_settings_in_construct) | |||
| 'onnx::Add', 'BiasAdd', declare, ipt_args_settings_in_construct) | |||
| call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_settings_in_construct})" | |||
| @@ -320,7 +321,7 @@ class OnnxGraphNode(GraphNode): | |||
| def param_transform(self, mapper: Mapper): | |||
| """ | |||
| Transform torch params into mindspore. | |||
| Transform tensorflow params into mindspore. | |||
| Args: | |||
| mapper (Mapper): Mapper of params. | |||
| @@ -173,7 +173,7 @@ class PyTorchGraph(Graph): | |||
| """ | |||
| self._check_input_shape(input_shape) | |||
| feed_forward_ipt_shape = (1, *input_shape) | |||
| feed_forward_ipt_shape = tuple(input_shape) | |||
| graph = self._trace_torch_graph(feed_forward_ipt_shape) | |||
| nodes = list(graph.nodes()) | |||
| @@ -283,7 +283,7 @@ class PyTorchGraph(Graph): | |||
| raise NotImplementedError(err_msg) | |||
| @staticmethod | |||
| def load_graph(graph_path: str): | |||
| def load_graph(graph_path: str, **kwargs): | |||
| """ | |||
| Load graph. | |||
| @@ -21,11 +21,13 @@ Usage: | |||
| """ | |||
| import difflib | |||
| import os | |||
| import re | |||
| import sys | |||
| import pytest | |||
| from mindinsight.mindconverter.converter import main | |||
| from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter | |||
| @pytest.mark.usefixtures('create_output_dir') | |||
| @@ -36,7 +38,9 @@ class TestConverter: | |||
| def setup_class(cls): | |||
| """Setup method.""" | |||
| cls.script_dir = os.path.join(os.path.dirname(__file__), 'data') | |||
| cls.pytorch_dir = '/home/test/mindinsight_sample' | |||
| pytorch_base_dir = os.path.dirname(__file__).split('/')[:3] | |||
| cls.pytorch_dir = \ | |||
| '/'.join(pytorch_base_dir + ['share-data', 'dataset', 'mindinsight_dataset', 'resnet50']) | |||
| sys.path.insert(0, cls.script_dir) | |||
| @classmethod | |||
| @@ -78,3 +82,35 @@ class TestConverter: | |||
| converted_ratio = 100 - (diff_lines * 100) / (len(expect_source)) | |||
| assert converted_ratio >= 80 | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_single | |||
| def test_main_graph_based_converter(self, output): | |||
| """Test main graph based converter.""" | |||
| pytorch_filename = "resnet50.pth" | |||
| expected_model_filename = "resnet50.py" | |||
| expected_report_filename = "report_of_resnet50.txt" | |||
| file_config = { | |||
| 'model_file': os.path.join(self.pytorch_dir, pytorch_filename), | |||
| 'shape': (1, 3, 224, 224), | |||
| 'outfile_dir': output, | |||
| 'report_dir': output | |||
| } | |||
| with pytest.raises(ValueError) as e: | |||
| main_graph_base_converter(file_config=file_config) | |||
| assert os.path.isfile(os.path.join(output, expected_model_filename)) | |||
| assert os.path.isfile(os.path.join(output, expected_report_filename)) | |||
| with open(os.path.join(output, expected_report_filename)) as converted_r: | |||
| converted_report = converted_r.readlines() | |||
| converted_rate = re.findall(r".*(?:Converted Rate: )(.*)[.]", converted_report[-1]) | |||
| assert converted_rate[0] == '100.00%' | |||
| exec_msg = e.value.args[0] | |||
| assert exec_msg == "torch.__spec__ is None" | |||