diff --git a/mindinsight/mindconverter/cli.py b/mindinsight/mindconverter/cli.py index b4e1d56f..c8a7b675 100644 --- a/mindinsight/mindconverter/cli.py +++ b/mindinsight/mindconverter/cli.py @@ -19,6 +19,7 @@ import argparse import mindinsight from mindinsight.mindconverter.converter import main +from mindinsight.mindconverter.graph_based_converter.constant import ARGUMENT_LENGTH_LIMIT, EXPECTED_SHAPE_NUMBER from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter from mindinsight.mindconverter.common.log import logger as log @@ -38,6 +39,11 @@ class FileDirAction(argparse.Action): option_string (str): Optional string for specific argument name. Default: None. """ outfile = values + + if len(outfile) > ARGUMENT_LENGTH_LIMIT: + parser_in.error( + f"The length of {option_string}{outfile} should be no more than {ARGUMENT_LENGTH_LIMIT}.") + if outfile.startswith('~'): outfile = os.path.realpath(os.path.expanduser(outfile)) @@ -160,9 +166,6 @@ class ModelFileAction(argparse.Action): if not os.path.isfile(outfile_dir): parser_in.error(f'{option_string} {outfile_dir} is not a file') - if not outfile_dir.endswith('.pth'): - parser_in.error(f"{option_string} {outfile_dir} should be a Pytorch model, ending with '.pth'.") - setattr(namespace, self.dest, outfile_dir) @@ -200,14 +203,42 @@ class ShapeAction(argparse.Action): """ in_shape = None shape_str = values + + shape_list = shape_str.split(';') + if not len(shape_list) == EXPECTED_SHAPE_NUMBER: + parser_in.error(f"Only support one shape now, but get {len(shape_list)}.") + try: - in_shape = [int(num_shape) for num_shape in shape_str.split(',')] + in_shape = [int(num_shape) for num_shape in shape_list[0].split(',')] except ValueError: parser_in.error( f"{option_string} {shape_str} should be a list of integer split by ',', check it please.") setattr(namespace, self.dest, in_shape) +class NodeAction(argparse.Action): + """Node action class definition.""" + + def __call__(self, parser_in, namespace, values, option_string=None): + """ + Inherited __call__ method from FileDirAction. + + Args: + parser_in (ArgumentParser): Passed-in argument parser. + namespace (Namespace): Namespace object to hold arguments. + values (object): Argument values with type depending on argument definition. + option_string (str): Optional string for specific argument name. Default: None. + + """ + node_str = values + if len(node_str) > ARGUMENT_LENGTH_LIMIT: + parser_in.error( + f"The length of {option_string}{node_str} should be no more than {ARGUMENT_LENGTH_LIMIT}." + ) + + setattr(namespace, self.dest, node_str) + + parser = argparse.ArgumentParser( prog='mindconverter', description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__)) @@ -234,7 +265,7 @@ parser.add_argument( action=ModelFileAction, required=False, help=""" - PyTorch .pth model file path to use graph + PyTorch .pth or Tensorflow .pb model file path to use graph based schema to do script generation. When `--in_file` and `--model_file` are both provided, use AST schema as default. @@ -250,7 +281,29 @@ parser.add_argument( Optional, expected input tensor shape of `--model_file`. It's required when use graph based schema. - Usage: --shape 3,244,244 + Usage: --shape 1,3,244,244 + """) + +parser.add_argument( + '--input_nodes', + type=str, + action=NodeAction, + default=None, + required=False, + help=""" + Optional, input node(s) name of `--model_file`. It's required when use Tensorflow model. + Usage: --input_nodes input_1:0,input_2:0 + """) + +parser.add_argument( + '--output_nodes', + type=str, + action=NodeAction, + default=None, + required=False, + help=""" + Optional, output node(s) name of `--model_file`. It's required when use Tensorflow model. + Usage: --output_nodes output_1:0,output_2:0 """) parser.add_argument( @@ -305,10 +358,14 @@ def cli_entry(): if args.report is None: args.report = args.output os.makedirs(args.report, mode=mode, exist_ok=True) - _run(args.in_file, args.model_file, args.shape, args.output, args.report, args.project_path) + _run(args.in_file, args.model_file, + args.shape, + args.input_nodes, args.output_nodes, + args.output, args.report, + args.project_path) -def _run(in_files, model_file, shape, out_dir, report, project_path): +def _run(in_files, model_file, shape, input_nodes, output_nodes, out_dir, report, project_path): """ Run converter command. @@ -316,6 +373,8 @@ def _run(in_files, model_file, shape, out_dir, report, project_path): in_files (str): The file path or directory to convert. model_file(str): The pytorch .pth to convert on graph based schema. shape(list): The input tensor shape of module_file. + input_nodes(str): The input node(s) name of Tensorflow model, split by ','. + output_nodes(str): The output node(s) name of Tensorflow model, split by ','. out_dir (str): The output directory to save converted file. report (str): The report file path. project_path(str): Pytorch scripts project path. @@ -341,6 +400,8 @@ def _run(in_files, model_file, shape, out_dir, report, project_path): file_config = { 'model_file': model_file, 'shape': shape if shape else [], + 'input_nodes': input_nodes, + 'output_nodes': output_nodes, 'outfile_dir': out_dir, 'report_dir': report if report else out_dir } diff --git a/mindinsight/mindconverter/common/exceptions.py b/mindinsight/mindconverter/common/exceptions.py index cc7694f1..cf27bf4d 100644 --- a/mindinsight/mindconverter/common/exceptions.py +++ b/mindinsight/mindconverter/common/exceptions.py @@ -26,6 +26,7 @@ class ConverterErrors(ScriptConverterErrors): NODE_TYPE_NOT_SUPPORT = 2 CODE_SYNTAX_ERROR = 3 NODE_INPUT_TYPE_NOT_SUPPORT = 4 + UNKNOWN_MODEL = 5 class ScriptNotSupport(MindInsightException): @@ -62,3 +63,12 @@ class NodeInputTypeNotSupport(MindInsightException): super(NodeInputTypeNotSupport, self).__init__(ConverterErrors.NODE_INPUT_TYPE_NOT_SUPPORT, msg, http_code=400) + + +class UnknownModel(MindInsightException): + """The unknown model error.""" + + def __init__(self, msg): + super(UnknownModel, self).__init__(ConverterErrors.UNKNOWN_MODEL, + msg, + http_code=400) diff --git a/mindinsight/mindconverter/graph_based_converter/__init__.py b/mindinsight/mindconverter/graph_based_converter/__init__.py index 1ebd1394..c1a43578 100644 --- a/mindinsight/mindconverter/graph_based_converter/__init__.py +++ b/mindinsight/mindconverter/graph_based_converter/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== """Graph based scripts converter definition.""" -from .framework import graph_based_converter +from .framework import graph_based_converter_pytorch_to_ms +from .framework import graph_based_converter_tf_to_ms -__all__ = ["graph_based_converter"] +__all__ = ["graph_based_converter_pytorch_to_ms", "graph_based_converter_tf_to_ms"] diff --git a/mindinsight/mindconverter/graph_based_converter/constant.py b/mindinsight/mindconverter/graph_based_converter/constant.py index cc7ee4bc..d3b52c2d 100644 --- a/mindinsight/mindconverter/graph_based_converter/constant.py +++ b/mindinsight/mindconverter/graph_based_converter/constant.py @@ -35,6 +35,15 @@ ONNX_TYPE_FLOAT = 1 ONNX_TYPE_FLOATS = 6 ONNX_TYPE_STRING = 3 +BINARY_HEADER_PYTORCH_FILE = \ + b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03.\x80\x02}q\x00(X\x10\x00\x00\x00' + +BINARY_HEADER_PYTORCH_BITS = 32 + +ARGUMENT_LENGTH_LIMIT = 512 + +EXPECTED_SHAPE_NUMBER = 1 + @unique class CodeFormatConfig(Enum): @@ -54,3 +63,9 @@ class NodeType(Enum): class InputType(Enum): TENSOR = "tensor" LIST = "list" + + +@unique +class FrameworkType(Enum): + PYTORCH = 0 + TENSORFLOW = 1 diff --git a/mindinsight/mindconverter/graph_based_converter/framework.py b/mindinsight/mindconverter/graph_based_converter/framework.py index 0dd64248..7aa7f79b 100644 --- a/mindinsight/mindconverter/graph_based_converter/framework.py +++ b/mindinsight/mindconverter/graph_based_converter/framework.py @@ -16,12 +16,16 @@ import os import re import argparse +from importlib import import_module from importlib.util import find_spec import mindinsight from mindinsight.mindconverter.common.log import logger as log +from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \ + BINARY_HEADER_PYTORCH_BITS from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper -from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport +from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport, UnknownModel +from mindinsight.utils.exceptions import ParamMissError permissions = os.R_OK | os.W_OK | os.X_OK os.umask(permissions << 3 | permissions) @@ -56,8 +60,7 @@ def torch_installation_validation(func): """ def _f(graph_path: str, sample_shape: tuple, - output_folder: str, report_folder: str = None, - checkpoint_path: str = None): + output_folder: str, report_folder: str = None): # Check whether pytorch is installed. if not find_spec("torch"): error = ModuleNotFoundError("PyTorch is required when using graph based " @@ -67,9 +70,35 @@ def torch_installation_validation(func): log.exception(error) raise error + func(graph_path=graph_path, sample_shape=sample_shape, + output_folder=output_folder, report_folder=report_folder) + + return _f + + +def tf_installation_validation(func): + """ + Validate args of func. + + Args: + func(type): Function. + + Returns: + type, inner function. + """ + + def _f(graph_path: str, sample_shape: tuple, + output_folder: str, report_folder: str = None, + input_nodes: str = None, output_nodes: str = None): + # Check whether tensorflow is installed. + if not find_spec("tensorflow") or not find_spec("tf2onnx"): + error = ModuleNotFoundError("Tensorflow and tf2onnx are required when using " + "graph based scripts converter.") + log.error(str(error)) + raise error func(graph_path=graph_path, sample_shape=sample_shape, output_folder=output_folder, report_folder=report_folder, - checkpoint_path=checkpoint_path) + input_nodes=input_nodes, output_nodes=output_nodes) return _f @@ -85,32 +114,33 @@ def _extract_model_name(model_path): str: Name of Converted model. """ - model_name = re.findall(r".*[/](.*).pth", model_path)[-1] + model_name = re.findall(r".*[/](.*)(?:\.pth|\.pb)", model_path)[-1] return model_name @torch_installation_validation -def graph_based_converter(graph_path: str, sample_shape: tuple, - output_folder: str, report_folder: str = None, - checkpoint_path: str = None): +def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple, + output_folder: str, report_folder: str = None): """ - Graph based scripts converter. + Pytoch to MindSpore based on Graph. Args: graph_path (str): Graph file path. sample_shape (tuple): Input shape of the model. output_folder (str): Output folder. report_folder (str): Report output folder path. - checkpoint_path (str): Checkpoint file path. """ - from .third_party_graph import GraphFactory - from .hierarchical_tree import HierarchicalTreeFactory + third_party_graph_module = import_module( + 'mindinsight.mindconverter.graph_based_converter.third_party_graph') + hierarchical_tree_module = import_module( + 'mindinsight.mindconverter.graph_based_converter.hierarchical_tree') + cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory') + cls_hierarchical_tree_factory = getattr(hierarchical_tree_module, 'HierarchicalTreeFactory') - graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape, - checkpoint=checkpoint_path) + graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape) try: - hierarchical_tree = HierarchicalTreeFactory.create(graph_obj) + hierarchical_tree = cls_hierarchical_tree_factory.create(graph_obj) except Exception as e: log.exception(e) log.error("Error occur when create hierarchical tree.") @@ -123,6 +153,49 @@ def graph_based_converter(graph_path: str, sample_shape: tuple, report_folder=report_folder) +@tf_installation_validation +def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple, + input_nodes: str, output_nodes: str, + output_folder: str, report_folder: str = None): + """ + Tensorflow to MindSpore based on Graph. + + Args: + graph_path(str): Graph file path. + sample_shape(tuple): Input shape of the model. + input_nodes(str): Input node(s) of the model. + output_nodes(str): Output node(s) of the model. + output_folder(str): Output folder. + report_folder(str): Report output folder path. + + """ + third_party_graph_module = import_module( + 'mindinsight.mindconverter.graph_based_converter.third_party_graph') + hierarchical_tree_module = import_module( + 'mindinsight.mindconverter.graph_based_converter.hierarchical_tree') + cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory') + cls_hierarchical_tree_factory = getattr(hierarchical_tree_module, 'HierarchicalTreeFactory') + # Close unnecessary log. + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' + + graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape, + input_nodes=input_nodes, output_nodes=output_nodes) + + try: + hierarchical_tree, scope_name_map = cls_hierarchical_tree_factory.create(graph_obj) + except Exception as e: + log.exception(e) + log.error("Error occur when create hierarchical tree.") + raise NodeTypeNotSupport("This model is not supported now.") + + model_name = _extract_model_name(graph_path) + + hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper, + model_name=model_name, + report_folder=report_folder, + scope_name_map=scope_name_map) + + def main_graph_base_converter(file_config): """ The entrance for converter, script files will be converted. @@ -130,7 +203,54 @@ def main_graph_base_converter(file_config): Args: file_config (dict): The config of file which to convert. """ - graph_based_converter(graph_path=file_config['model_file'], - sample_shape=file_config['shape'], - output_folder=file_config['outfile_dir'], - report_folder=file_config['report_dir']) + graph_path = file_config['model_file'] + frame_type = get_framework_type(graph_path) + if frame_type == FrameworkType.PYTORCH.value: + graph_based_converter_pytorch_to_ms(graph_path=graph_path, + sample_shape=file_config['shape'], + output_folder=file_config['outfile_dir'], + report_folder=file_config['report_dir']) + elif frame_type == FrameworkType.TENSORFLOW.value: + check_params = ['input_nodes', 'output_nodes'] + check_params_exist(check_params, file_config) + graph_based_converter_tf_to_ms(graph_path=graph_path, + sample_shape=file_config['shape'], + input_nodes=file_config['input_nodes'], + output_nodes=file_config['output_nodes'], + output_folder=file_config['outfile_dir'], + report_folder=file_config['report_dir']) + else: + error_msg = "Get UNSUPPORTED model." + error = UnknownModel(error_msg) + log.error(str(error)) + log.exception(error) + raise error + + +def get_framework_type(model_path): + """Get framework type.""" + try: + with open(model_path, 'rb') as f: + if f.read(BINARY_HEADER_PYTORCH_BITS) == BINARY_HEADER_PYTORCH_FILE: + framework_type = FrameworkType.PYTORCH.value + else: + framework_type = FrameworkType.TENSORFLOW.value + except IOError: + error_msg = "Get UNSUPPORTED model." + error = UnknownModel(error_msg) + log.error(str(error)) + log.exception(error) + raise error + + return framework_type + + +def check_params_exist(params: list, config): + """Check params exist.""" + miss_param_list = '' + for param in params: + if not config.get(param) or not config[param]: + miss_param_list = ', '.join((miss_param_list, param)) if miss_param_list else param + + if miss_param_list: + raise ParamMissError(miss_param_list) diff --git a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py index 7867dded..d922c966 100644 --- a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py +++ b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py @@ -182,6 +182,7 @@ class HierarchicalTree(Tree): mapper (Mapper): Mapper of third party framework and mindspore. model_name(str): Name of Converted model. out_folder (str): Output folder. + scope_name_map(str): Scope name map of tensorflow. """ if scope_name_map: diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py index bad5f135..e7e7622a 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py @@ -117,12 +117,14 @@ class ConvMapper(ONNXToMindSporeMapper): @staticmethod def _operation_name_in_ms(*args, **kwargs): - if not kwargs['weights'].get('weight'): # is from tf + weight = kwargs['weights'].get('weight', 'empty') + + if weight == 'empty': # is from tf kernel_size = kwargs['params'].get('kernel_shape') dim = len(kernel_size) return f"nn.Conv{dim}d" - weight = kwargs['weights']['weight'].numpy() + weight = weight.numpy() dim = weight.ndim - 2 return f"nn.Conv{dim}d" @@ -131,7 +133,7 @@ class ConvMapper(ONNXToMindSporeMapper): weights = kwargs['weights'] params = kwargs['params'] - if not weights.get('weight'): # is from tf + if weights.get('weight', 'empty') == 'empty': # is from tf return ConvMapper.convert_params_tf(params=params, weights=weights) return ConvMapper.convert_params_torch(params=params, weights=weights) diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py index 8dd1e28d..4e39c285 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py @@ -25,15 +25,16 @@ class GraphFactory: @classmethod def init(cls, graph_path: str, - input_nodes: str, output_nodes: str, - sample_shape: tuple): + sample_shape: tuple, + input_nodes: str = None, output_nodes: str = None): """ Init an instance of graph. Args: graph_path (str): Graph or model file path. sample_shape (tuple): Input shape of the model. - checkpoint (str): Checkpoint file path. + input_nodes(str): Input nodes. + output_nodes(str): Output nodes. Returns: Graph, graph instance. diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py index 4303100a..811f1f8a 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py @@ -246,17 +246,11 @@ class Graph(BaseGraph, abc.ABC): model_path (str): Graph or model file path. sample_shape (tuple): Input shape of the model. checkpoint (str): Checkpoint file path. - input_nodes (list[str]): list of input nodes' name - output_nodes (list[str]): list of output nodes' name Returns: cls, graph instance. """ - tf_input_nodes = kwargs.get('input_nodes') - tf_output_nodes = kwargs.get('output_nodes') - src_graph = cls.load_graph(graph_path=model_path, - input_nodes=tf_input_nodes, - output_nodes=tf_output_nodes) + src_graph = cls.load_graph(graph_path=model_path, **kwargs) ckpt = cls.load_checkpoint( ckpt_path=checkpoint) if checkpoint else None diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py index 197d1543..27e52c68 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py @@ -193,8 +193,6 @@ class OnnxGraph(Graph): Args: graph_path (str): Graph path. - tf_input_nodes (str): input nodes of tf graph - tf_output_nodes (str): output nodes of tf graph Returns: object, ONNX model. diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph_node.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph_node.py index 7c616d7f..4b4383e0 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph_node.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph_node.py @@ -18,8 +18,9 @@ from copy import deepcopy from .base import GraphNode from ..constant import NodeType, SEPARATOR_IN_SCOPE, \ - SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, SEPARATOR_IN_ONNX_OP + SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, SEPARATOR_IN_ONNX_OP, InputType from ..mapper.base import Mapper +from ...common.exceptions import NodeInputTypeNotSupport class OnnxGraphNode(GraphNode): @@ -160,10 +161,10 @@ class OnnxGraphNode(GraphNode): Args: op_name (str): Add the tensor to args if the current node has this op_name. - t_identifier (str): The unique strinf appeared in the target tensor + t_identifier (str): The unique string appeared in the target tensor name. - declare_s (str): Declare statement generated in to_code(). - init_s (str): init statement generated in to_code(). + declare (str): Declare statement generated in to_code(). + args (str): Args statement generated in to_code(). Returns: declare_list list, multiple declare statements. @@ -226,9 +227,9 @@ class OnnxGraphNode(GraphNode): declare, ipt_args_settings_in_construct = self._add_tensor_args_to_code( 'onnx::MatMul', 'MatMul', declare, ipt_args_settings_in_construct) - # Extra Tensor generator for onnx::BiasAdd + # Extra Tensor generator for onnx::Add declare, ipt_args_settings_in_construct = self._add_tensor_args_to_code( - 'onnx::MatMul', 'BiasAdd', declare, ipt_args_settings_in_construct) + 'onnx::Add', 'BiasAdd', declare, ipt_args_settings_in_construct) call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_settings_in_construct})" @@ -320,7 +321,7 @@ class OnnxGraphNode(GraphNode): def param_transform(self, mapper: Mapper): """ - Transform torch params into mindspore. + Transform tensorflow params into mindspore. Args: mapper (Mapper): Mapper of params. diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py index e86ba252..0e4e6164 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py @@ -173,7 +173,7 @@ class PyTorchGraph(Graph): """ self._check_input_shape(input_shape) - feed_forward_ipt_shape = (1, *input_shape) + feed_forward_ipt_shape = tuple(input_shape) graph = self._trace_torch_graph(feed_forward_ipt_shape) nodes = list(graph.nodes()) @@ -283,7 +283,7 @@ class PyTorchGraph(Graph): raise NotImplementedError(err_msg) @staticmethod - def load_graph(graph_path: str): + def load_graph(graph_path: str, **kwargs): """ Load graph. diff --git a/tests/st/func/mindconverter/test_converter.py b/tests/st/func/mindconverter/test_converter.py index 5846f7b5..699b44f0 100644 --- a/tests/st/func/mindconverter/test_converter.py +++ b/tests/st/func/mindconverter/test_converter.py @@ -21,11 +21,13 @@ Usage: """ import difflib import os +import re import sys import pytest from mindinsight.mindconverter.converter import main +from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter @pytest.mark.usefixtures('create_output_dir') @@ -36,7 +38,9 @@ class TestConverter: def setup_class(cls): """Setup method.""" cls.script_dir = os.path.join(os.path.dirname(__file__), 'data') - cls.pytorch_dir = '/home/test/mindinsight_sample' + pytorch_base_dir = os.path.dirname(__file__).split('/')[:3] + cls.pytorch_dir = \ + '/'.join(pytorch_base_dir + ['share-data', 'dataset', 'mindinsight_dataset', 'resnet50']) sys.path.insert(0, cls.script_dir) @classmethod @@ -78,3 +82,35 @@ class TestConverter: converted_ratio = 100 - (diff_lines * 100) / (len(expect_source)) assert converted_ratio >= 80 + + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend_training + @pytest.mark.platform_x86_gpu_training + @pytest.mark.platform_x86_ascend_training + @pytest.mark.platform_x86_cpu + @pytest.mark.env_single + def test_main_graph_based_converter(self, output): + """Test main graph based converter.""" + pytorch_filename = "resnet50.pth" + expected_model_filename = "resnet50.py" + expected_report_filename = "report_of_resnet50.txt" + file_config = { + 'model_file': os.path.join(self.pytorch_dir, pytorch_filename), + 'shape': (1, 3, 224, 224), + 'outfile_dir': output, + 'report_dir': output + } + with pytest.raises(ValueError) as e: + main_graph_base_converter(file_config=file_config) + + assert os.path.isfile(os.path.join(output, expected_model_filename)) + assert os.path.isfile(os.path.join(output, expected_report_filename)) + + with open(os.path.join(output, expected_report_filename)) as converted_r: + converted_report = converted_r.readlines() + converted_rate = re.findall(r".*(?:Converted Rate: )(.*)[.]", converted_report[-1]) + + assert converted_rate[0] == '100.00%' + + exec_msg = e.value.args[0] + assert exec_msg == "torch.__spec__ is None"