Browse Source

!781 Add cli for tfToMS in MindConverter based on graph.

Merge pull request !781 from moran/network_validation
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
2ddbc7b529
13 changed files with 294 additions and 54 deletions
  1. +69
    -8
      mindinsight/mindconverter/cli.py
  2. +10
    -0
      mindinsight/mindconverter/common/exceptions.py
  3. +3
    -2
      mindinsight/mindconverter/graph_based_converter/__init__.py
  4. +15
    -0
      mindinsight/mindconverter/graph_based_converter/constant.py
  5. +139
    -19
      mindinsight/mindconverter/graph_based_converter/framework.py
  6. +1
    -0
      mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py
  7. +5
    -3
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py
  8. +4
    -3
      mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py
  9. +1
    -7
      mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py
  10. +0
    -2
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py
  11. +8
    -7
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph_node.py
  12. +2
    -2
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py
  13. +37
    -1
      tests/st/func/mindconverter/test_converter.py

+ 69
- 8
mindinsight/mindconverter/cli.py View File

@@ -19,6 +19,7 @@ import argparse


import mindinsight import mindinsight
from mindinsight.mindconverter.converter import main from mindinsight.mindconverter.converter import main
from mindinsight.mindconverter.graph_based_converter.constant import ARGUMENT_LENGTH_LIMIT, EXPECTED_SHAPE_NUMBER
from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter


from mindinsight.mindconverter.common.log import logger as log from mindinsight.mindconverter.common.log import logger as log
@@ -38,6 +39,11 @@ class FileDirAction(argparse.Action):
option_string (str): Optional string for specific argument name. Default: None. option_string (str): Optional string for specific argument name. Default: None.
""" """
outfile = values outfile = values

if len(outfile) > ARGUMENT_LENGTH_LIMIT:
parser_in.error(
f"The length of {option_string}{outfile} should be no more than {ARGUMENT_LENGTH_LIMIT}.")

if outfile.startswith('~'): if outfile.startswith('~'):
outfile = os.path.realpath(os.path.expanduser(outfile)) outfile = os.path.realpath(os.path.expanduser(outfile))


@@ -160,9 +166,6 @@ class ModelFileAction(argparse.Action):
if not os.path.isfile(outfile_dir): if not os.path.isfile(outfile_dir):
parser_in.error(f'{option_string} {outfile_dir} is not a file') parser_in.error(f'{option_string} {outfile_dir} is not a file')


if not outfile_dir.endswith('.pth'):
parser_in.error(f"{option_string} {outfile_dir} should be a Pytorch model, ending with '.pth'.")

setattr(namespace, self.dest, outfile_dir) setattr(namespace, self.dest, outfile_dir)




@@ -200,14 +203,42 @@ class ShapeAction(argparse.Action):
""" """
in_shape = None in_shape = None
shape_str = values shape_str = values

shape_list = shape_str.split(';')
if not len(shape_list) == EXPECTED_SHAPE_NUMBER:
parser_in.error(f"Only support one shape now, but get {len(shape_list)}.")

try: try:
in_shape = [int(num_shape) for num_shape in shape_str.split(',')]
in_shape = [int(num_shape) for num_shape in shape_list[0].split(',')]
except ValueError: except ValueError:
parser_in.error( parser_in.error(
f"{option_string} {shape_str} should be a list of integer split by ',', check it please.") f"{option_string} {shape_str} should be a list of integer split by ',', check it please.")
setattr(namespace, self.dest, in_shape) setattr(namespace, self.dest, in_shape)




class NodeAction(argparse.Action):
"""Node action class definition."""

def __call__(self, parser_in, namespace, values, option_string=None):
"""
Inherited __call__ method from FileDirAction.

Args:
parser_in (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None.

"""
node_str = values
if len(node_str) > ARGUMENT_LENGTH_LIMIT:
parser_in.error(
f"The length of {option_string}{node_str} should be no more than {ARGUMENT_LENGTH_LIMIT}."
)

setattr(namespace, self.dest, node_str)


parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
prog='mindconverter', prog='mindconverter',
description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__)) description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__))
@@ -234,7 +265,7 @@ parser.add_argument(
action=ModelFileAction, action=ModelFileAction,
required=False, required=False,
help=""" help="""
PyTorch .pth model file path to use graph
PyTorch .pth or Tensorflow .pb model file path to use graph
based schema to do script generation. When based schema to do script generation. When
`--in_file` and `--model_file` are both provided, `--in_file` and `--model_file` are both provided,
use AST schema as default. use AST schema as default.
@@ -250,7 +281,29 @@ parser.add_argument(
Optional, expected input tensor shape of Optional, expected input tensor shape of
`--model_file`. It's required when use graph based `--model_file`. It's required when use graph based
schema. schema.
Usage: --shape 3,244,244
Usage: --shape 1,3,244,244
""")

parser.add_argument(
'--input_nodes',
type=str,
action=NodeAction,
default=None,
required=False,
help="""
Optional, input node(s) name of `--model_file`. It's required when use Tensorflow model.
Usage: --input_nodes input_1:0,input_2:0
""")

parser.add_argument(
'--output_nodes',
type=str,
action=NodeAction,
default=None,
required=False,
help="""
Optional, output node(s) name of `--model_file`. It's required when use Tensorflow model.
Usage: --output_nodes output_1:0,output_2:0
""") """)


parser.add_argument( parser.add_argument(
@@ -305,10 +358,14 @@ def cli_entry():
if args.report is None: if args.report is None:
args.report = args.output args.report = args.output
os.makedirs(args.report, mode=mode, exist_ok=True) os.makedirs(args.report, mode=mode, exist_ok=True)
_run(args.in_file, args.model_file, args.shape, args.output, args.report, args.project_path)
_run(args.in_file, args.model_file,
args.shape,
args.input_nodes, args.output_nodes,
args.output, args.report,
args.project_path)




def _run(in_files, model_file, shape, out_dir, report, project_path):
def _run(in_files, model_file, shape, input_nodes, output_nodes, out_dir, report, project_path):
""" """
Run converter command. Run converter command.


@@ -316,6 +373,8 @@ def _run(in_files, model_file, shape, out_dir, report, project_path):
in_files (str): The file path or directory to convert. in_files (str): The file path or directory to convert.
model_file(str): The pytorch .pth to convert on graph based schema. model_file(str): The pytorch .pth to convert on graph based schema.
shape(list): The input tensor shape of module_file. shape(list): The input tensor shape of module_file.
input_nodes(str): The input node(s) name of Tensorflow model, split by ','.
output_nodes(str): The output node(s) name of Tensorflow model, split by ','.
out_dir (str): The output directory to save converted file. out_dir (str): The output directory to save converted file.
report (str): The report file path. report (str): The report file path.
project_path(str): Pytorch scripts project path. project_path(str): Pytorch scripts project path.
@@ -341,6 +400,8 @@ def _run(in_files, model_file, shape, out_dir, report, project_path):
file_config = { file_config = {
'model_file': model_file, 'model_file': model_file,
'shape': shape if shape else [], 'shape': shape if shape else [],
'input_nodes': input_nodes,
'output_nodes': output_nodes,
'outfile_dir': out_dir, 'outfile_dir': out_dir,
'report_dir': report if report else out_dir 'report_dir': report if report else out_dir
} }


+ 10
- 0
mindinsight/mindconverter/common/exceptions.py View File

@@ -26,6 +26,7 @@ class ConverterErrors(ScriptConverterErrors):
NODE_TYPE_NOT_SUPPORT = 2 NODE_TYPE_NOT_SUPPORT = 2
CODE_SYNTAX_ERROR = 3 CODE_SYNTAX_ERROR = 3
NODE_INPUT_TYPE_NOT_SUPPORT = 4 NODE_INPUT_TYPE_NOT_SUPPORT = 4
UNKNOWN_MODEL = 5




class ScriptNotSupport(MindInsightException): class ScriptNotSupport(MindInsightException):
@@ -62,3 +63,12 @@ class NodeInputTypeNotSupport(MindInsightException):
super(NodeInputTypeNotSupport, self).__init__(ConverterErrors.NODE_INPUT_TYPE_NOT_SUPPORT, super(NodeInputTypeNotSupport, self).__init__(ConverterErrors.NODE_INPUT_TYPE_NOT_SUPPORT,
msg, msg,
http_code=400) http_code=400)


class UnknownModel(MindInsightException):
"""The unknown model error."""

def __init__(self, msg):
super(UnknownModel, self).__init__(ConverterErrors.UNKNOWN_MODEL,
msg,
http_code=400)

+ 3
- 2
mindinsight/mindconverter/graph_based_converter/__init__.py View File

@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Graph based scripts converter definition.""" """Graph based scripts converter definition."""
from .framework import graph_based_converter
from .framework import graph_based_converter_pytorch_to_ms
from .framework import graph_based_converter_tf_to_ms


__all__ = ["graph_based_converter"]
__all__ = ["graph_based_converter_pytorch_to_ms", "graph_based_converter_tf_to_ms"]

+ 15
- 0
mindinsight/mindconverter/graph_based_converter/constant.py View File

@@ -35,6 +35,15 @@ ONNX_TYPE_FLOAT = 1
ONNX_TYPE_FLOATS = 6 ONNX_TYPE_FLOATS = 6
ONNX_TYPE_STRING = 3 ONNX_TYPE_STRING = 3


BINARY_HEADER_PYTORCH_FILE = \
b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03.\x80\x02}q\x00(X\x10\x00\x00\x00'

BINARY_HEADER_PYTORCH_BITS = 32

ARGUMENT_LENGTH_LIMIT = 512

EXPECTED_SHAPE_NUMBER = 1



@unique @unique
class CodeFormatConfig(Enum): class CodeFormatConfig(Enum):
@@ -54,3 +63,9 @@ class NodeType(Enum):
class InputType(Enum): class InputType(Enum):
TENSOR = "tensor" TENSOR = "tensor"
LIST = "list" LIST = "list"


@unique
class FrameworkType(Enum):
PYTORCH = 0
TENSORFLOW = 1

+ 139
- 19
mindinsight/mindconverter/graph_based_converter/framework.py View File

@@ -16,12 +16,16 @@
import os import os
import re import re
import argparse import argparse
from importlib import import_module
from importlib.util import find_spec from importlib.util import find_spec


import mindinsight import mindinsight
from mindinsight.mindconverter.common.log import logger as log from mindinsight.mindconverter.common.log import logger as log
from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \
BINARY_HEADER_PYTORCH_BITS
from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper
from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport
from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport, UnknownModel
from mindinsight.utils.exceptions import ParamMissError


permissions = os.R_OK | os.W_OK | os.X_OK permissions = os.R_OK | os.W_OK | os.X_OK
os.umask(permissions << 3 | permissions) os.umask(permissions << 3 | permissions)
@@ -56,8 +60,7 @@ def torch_installation_validation(func):
""" """


def _f(graph_path: str, sample_shape: tuple, def _f(graph_path: str, sample_shape: tuple,
output_folder: str, report_folder: str = None,
checkpoint_path: str = None):
output_folder: str, report_folder: str = None):
# Check whether pytorch is installed. # Check whether pytorch is installed.
if not find_spec("torch"): if not find_spec("torch"):
error = ModuleNotFoundError("PyTorch is required when using graph based " error = ModuleNotFoundError("PyTorch is required when using graph based "
@@ -67,9 +70,35 @@ def torch_installation_validation(func):
log.exception(error) log.exception(error)
raise error raise error


func(graph_path=graph_path, sample_shape=sample_shape,
output_folder=output_folder, report_folder=report_folder)

return _f


def tf_installation_validation(func):
"""
Validate args of func.

Args:
func(type): Function.

Returns:
type, inner function.
"""

def _f(graph_path: str, sample_shape: tuple,
output_folder: str, report_folder: str = None,
input_nodes: str = None, output_nodes: str = None):
# Check whether tensorflow is installed.
if not find_spec("tensorflow") or not find_spec("tf2onnx"):
error = ModuleNotFoundError("Tensorflow and tf2onnx are required when using "
"graph based scripts converter.")
log.error(str(error))
raise error
func(graph_path=graph_path, sample_shape=sample_shape, func(graph_path=graph_path, sample_shape=sample_shape,
output_folder=output_folder, report_folder=report_folder, output_folder=output_folder, report_folder=report_folder,
checkpoint_path=checkpoint_path)
input_nodes=input_nodes, output_nodes=output_nodes)


return _f return _f


@@ -85,32 +114,33 @@ def _extract_model_name(model_path):
str: Name of Converted model. str: Name of Converted model.
""" """


model_name = re.findall(r".*[/](.*).pth", model_path)[-1]
model_name = re.findall(r".*[/](.*)(?:\.pth|\.pb)", model_path)[-1]
return model_name return model_name




@torch_installation_validation @torch_installation_validation
def graph_based_converter(graph_path: str, sample_shape: tuple,
output_folder: str, report_folder: str = None,
checkpoint_path: str = None):
def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple,
output_folder: str, report_folder: str = None):
""" """
Graph based scripts converter.
Pytoch to MindSpore based on Graph.


Args: Args:
graph_path (str): Graph file path. graph_path (str): Graph file path.
sample_shape (tuple): Input shape of the model. sample_shape (tuple): Input shape of the model.
output_folder (str): Output folder. output_folder (str): Output folder.
report_folder (str): Report output folder path. report_folder (str): Report output folder path.
checkpoint_path (str): Checkpoint file path.


""" """
from .third_party_graph import GraphFactory
from .hierarchical_tree import HierarchicalTreeFactory
third_party_graph_module = import_module(
'mindinsight.mindconverter.graph_based_converter.third_party_graph')
hierarchical_tree_module = import_module(
'mindinsight.mindconverter.graph_based_converter.hierarchical_tree')
cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory')
cls_hierarchical_tree_factory = getattr(hierarchical_tree_module, 'HierarchicalTreeFactory')


graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape,
checkpoint=checkpoint_path)
graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape)
try: try:
hierarchical_tree = HierarchicalTreeFactory.create(graph_obj)
hierarchical_tree = cls_hierarchical_tree_factory.create(graph_obj)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
log.error("Error occur when create hierarchical tree.") log.error("Error occur when create hierarchical tree.")
@@ -123,6 +153,49 @@ def graph_based_converter(graph_path: str, sample_shape: tuple,
report_folder=report_folder) report_folder=report_folder)




@tf_installation_validation
def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple,
input_nodes: str, output_nodes: str,
output_folder: str, report_folder: str = None):
"""
Tensorflow to MindSpore based on Graph.

Args:
graph_path(str): Graph file path.
sample_shape(tuple): Input shape of the model.
input_nodes(str): Input node(s) of the model.
output_nodes(str): Output node(s) of the model.
output_folder(str): Output folder.
report_folder(str): Report output folder path.

"""
third_party_graph_module = import_module(
'mindinsight.mindconverter.graph_based_converter.third_party_graph')
hierarchical_tree_module = import_module(
'mindinsight.mindconverter.graph_based_converter.hierarchical_tree')
cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory')
cls_hierarchical_tree_factory = getattr(hierarchical_tree_module, 'HierarchicalTreeFactory')
# Close unnecessary log.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape,
input_nodes=input_nodes, output_nodes=output_nodes)

try:
hierarchical_tree, scope_name_map = cls_hierarchical_tree_factory.create(graph_obj)
except Exception as e:
log.exception(e)
log.error("Error occur when create hierarchical tree.")
raise NodeTypeNotSupport("This model is not supported now.")

model_name = _extract_model_name(graph_path)

hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper,
model_name=model_name,
report_folder=report_folder,
scope_name_map=scope_name_map)


def main_graph_base_converter(file_config): def main_graph_base_converter(file_config):
""" """
The entrance for converter, script files will be converted. The entrance for converter, script files will be converted.
@@ -130,7 +203,54 @@ def main_graph_base_converter(file_config):
Args: Args:
file_config (dict): The config of file which to convert. file_config (dict): The config of file which to convert.
""" """
graph_based_converter(graph_path=file_config['model_file'],
sample_shape=file_config['shape'],
output_folder=file_config['outfile_dir'],
report_folder=file_config['report_dir'])
graph_path = file_config['model_file']
frame_type = get_framework_type(graph_path)
if frame_type == FrameworkType.PYTORCH.value:
graph_based_converter_pytorch_to_ms(graph_path=graph_path,
sample_shape=file_config['shape'],
output_folder=file_config['outfile_dir'],
report_folder=file_config['report_dir'])
elif frame_type == FrameworkType.TENSORFLOW.value:
check_params = ['input_nodes', 'output_nodes']
check_params_exist(check_params, file_config)
graph_based_converter_tf_to_ms(graph_path=graph_path,
sample_shape=file_config['shape'],
input_nodes=file_config['input_nodes'],
output_nodes=file_config['output_nodes'],
output_folder=file_config['outfile_dir'],
report_folder=file_config['report_dir'])
else:
error_msg = "Get UNSUPPORTED model."
error = UnknownModel(error_msg)
log.error(str(error))
log.exception(error)
raise error


def get_framework_type(model_path):
"""Get framework type."""
try:
with open(model_path, 'rb') as f:
if f.read(BINARY_HEADER_PYTORCH_BITS) == BINARY_HEADER_PYTORCH_FILE:
framework_type = FrameworkType.PYTORCH.value
else:
framework_type = FrameworkType.TENSORFLOW.value
except IOError:
error_msg = "Get UNSUPPORTED model."
error = UnknownModel(error_msg)
log.error(str(error))
log.exception(error)
raise error

return framework_type


def check_params_exist(params: list, config):
"""Check params exist."""
miss_param_list = ''
for param in params:
if not config.get(param) or not config[param]:
miss_param_list = ', '.join((miss_param_list, param)) if miss_param_list else param

if miss_param_list:
raise ParamMissError(miss_param_list)

+ 1
- 0
mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py View File

@@ -182,6 +182,7 @@ class HierarchicalTree(Tree):
mapper (Mapper): Mapper of third party framework and mindspore. mapper (Mapper): Mapper of third party framework and mindspore.
model_name(str): Name of Converted model. model_name(str): Name of Converted model.
out_folder (str): Output folder. out_folder (str): Output folder.
scope_name_map(str): Scope name map of tensorflow.


""" """
if scope_name_map: if scope_name_map:


+ 5
- 3
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py View File

@@ -117,12 +117,14 @@ class ConvMapper(ONNXToMindSporeMapper):


@staticmethod @staticmethod
def _operation_name_in_ms(*args, **kwargs): def _operation_name_in_ms(*args, **kwargs):
if not kwargs['weights'].get('weight'): # is from tf
weight = kwargs['weights'].get('weight', 'empty')

if weight == 'empty': # is from tf
kernel_size = kwargs['params'].get('kernel_shape') kernel_size = kwargs['params'].get('kernel_shape')
dim = len(kernel_size) dim = len(kernel_size)
return f"nn.Conv{dim}d" return f"nn.Conv{dim}d"


weight = kwargs['weights']['weight'].numpy()
weight = weight.numpy()
dim = weight.ndim - 2 dim = weight.ndim - 2
return f"nn.Conv{dim}d" return f"nn.Conv{dim}d"


@@ -131,7 +133,7 @@ class ConvMapper(ONNXToMindSporeMapper):
weights = kwargs['weights'] weights = kwargs['weights']
params = kwargs['params'] params = kwargs['params']


if not weights.get('weight'): # is from tf
if weights.get('weight', 'empty') == 'empty': # is from tf
return ConvMapper.convert_params_tf(params=params, weights=weights) return ConvMapper.convert_params_tf(params=params, weights=weights)
return ConvMapper.convert_params_torch(params=params, weights=weights) return ConvMapper.convert_params_torch(params=params, weights=weights)




+ 4
- 3
mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py View File

@@ -25,15 +25,16 @@ class GraphFactory:


@classmethod @classmethod
def init(cls, graph_path: str, def init(cls, graph_path: str,
input_nodes: str, output_nodes: str,
sample_shape: tuple):
sample_shape: tuple,
input_nodes: str = None, output_nodes: str = None):
""" """
Init an instance of graph. Init an instance of graph.


Args: Args:
graph_path (str): Graph or model file path. graph_path (str): Graph or model file path.
sample_shape (tuple): Input shape of the model. sample_shape (tuple): Input shape of the model.
checkpoint (str): Checkpoint file path.
input_nodes(str): Input nodes.
output_nodes(str): Output nodes.


Returns: Returns:
Graph, graph instance. Graph, graph instance.


+ 1
- 7
mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py View File

@@ -246,17 +246,11 @@ class Graph(BaseGraph, abc.ABC):
model_path (str): Graph or model file path. model_path (str): Graph or model file path.
sample_shape (tuple): Input shape of the model. sample_shape (tuple): Input shape of the model.
checkpoint (str): Checkpoint file path. checkpoint (str): Checkpoint file path.
input_nodes (list[str]): list of input nodes' name
output_nodes (list[str]): list of output nodes' name


Returns: Returns:
cls, graph instance. cls, graph instance.
""" """
tf_input_nodes = kwargs.get('input_nodes')
tf_output_nodes = kwargs.get('output_nodes')
src_graph = cls.load_graph(graph_path=model_path,
input_nodes=tf_input_nodes,
output_nodes=tf_output_nodes)
src_graph = cls.load_graph(graph_path=model_path, **kwargs)
ckpt = cls.load_checkpoint( ckpt = cls.load_checkpoint(
ckpt_path=checkpoint) if checkpoint else None ckpt_path=checkpoint) if checkpoint else None




+ 0
- 2
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py View File

@@ -193,8 +193,6 @@ class OnnxGraph(Graph):


Args: Args:
graph_path (str): Graph path. graph_path (str): Graph path.
tf_input_nodes (str): input nodes of tf graph
tf_output_nodes (str): output nodes of tf graph


Returns: Returns:
object, ONNX model. object, ONNX model.


+ 8
- 7
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph_node.py View File

@@ -18,8 +18,9 @@ from copy import deepcopy
from .base import GraphNode from .base import GraphNode


from ..constant import NodeType, SEPARATOR_IN_SCOPE, \ from ..constant import NodeType, SEPARATOR_IN_SCOPE, \
SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, SEPARATOR_IN_ONNX_OP
SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, SEPARATOR_IN_ONNX_OP, InputType
from ..mapper.base import Mapper from ..mapper.base import Mapper
from ...common.exceptions import NodeInputTypeNotSupport




class OnnxGraphNode(GraphNode): class OnnxGraphNode(GraphNode):
@@ -160,10 +161,10 @@ class OnnxGraphNode(GraphNode):
Args: Args:
op_name (str): Add the tensor to args if the current node has this op_name (str): Add the tensor to args if the current node has this
op_name. op_name.
t_identifier (str): The unique strinf appeared in the target tensor
t_identifier (str): The unique string appeared in the target tensor
name. name.
declare_s (str): Declare statement generated in to_code().
init_s (str): init statement generated in to_code().
declare (str): Declare statement generated in to_code().
args (str): Args statement generated in to_code().


Returns: Returns:
declare_list list, multiple declare statements. declare_list list, multiple declare statements.
@@ -226,9 +227,9 @@ class OnnxGraphNode(GraphNode):
declare, ipt_args_settings_in_construct = self._add_tensor_args_to_code( declare, ipt_args_settings_in_construct = self._add_tensor_args_to_code(
'onnx::MatMul', 'MatMul', declare, ipt_args_settings_in_construct) 'onnx::MatMul', 'MatMul', declare, ipt_args_settings_in_construct)


# Extra Tensor generator for onnx::BiasAdd
# Extra Tensor generator for onnx::Add
declare, ipt_args_settings_in_construct = self._add_tensor_args_to_code( declare, ipt_args_settings_in_construct = self._add_tensor_args_to_code(
'onnx::MatMul', 'BiasAdd', declare, ipt_args_settings_in_construct)
'onnx::Add', 'BiasAdd', declare, ipt_args_settings_in_construct)


call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_settings_in_construct})" call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_settings_in_construct})"


@@ -320,7 +321,7 @@ class OnnxGraphNode(GraphNode):


def param_transform(self, mapper: Mapper): def param_transform(self, mapper: Mapper):
""" """
Transform torch params into mindspore.
Transform tensorflow params into mindspore.


Args: Args:
mapper (Mapper): Mapper of params. mapper (Mapper): Mapper of params.


+ 2
- 2
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py View File

@@ -173,7 +173,7 @@ class PyTorchGraph(Graph):
""" """
self._check_input_shape(input_shape) self._check_input_shape(input_shape)


feed_forward_ipt_shape = (1, *input_shape)
feed_forward_ipt_shape = tuple(input_shape)
graph = self._trace_torch_graph(feed_forward_ipt_shape) graph = self._trace_torch_graph(feed_forward_ipt_shape)
nodes = list(graph.nodes()) nodes = list(graph.nodes())


@@ -283,7 +283,7 @@ class PyTorchGraph(Graph):
raise NotImplementedError(err_msg) raise NotImplementedError(err_msg)


@staticmethod @staticmethod
def load_graph(graph_path: str):
def load_graph(graph_path: str, **kwargs):
""" """
Load graph. Load graph.




+ 37
- 1
tests/st/func/mindconverter/test_converter.py View File

@@ -21,11 +21,13 @@ Usage:
""" """
import difflib import difflib
import os import os
import re
import sys import sys


import pytest import pytest


from mindinsight.mindconverter.converter import main from mindinsight.mindconverter.converter import main
from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter




@pytest.mark.usefixtures('create_output_dir') @pytest.mark.usefixtures('create_output_dir')
@@ -36,7 +38,9 @@ class TestConverter:
def setup_class(cls): def setup_class(cls):
"""Setup method.""" """Setup method."""
cls.script_dir = os.path.join(os.path.dirname(__file__), 'data') cls.script_dir = os.path.join(os.path.dirname(__file__), 'data')
cls.pytorch_dir = '/home/test/mindinsight_sample'
pytorch_base_dir = os.path.dirname(__file__).split('/')[:3]
cls.pytorch_dir = \
'/'.join(pytorch_base_dir + ['share-data', 'dataset', 'mindinsight_dataset', 'resnet50'])
sys.path.insert(0, cls.script_dir) sys.path.insert(0, cls.script_dir)


@classmethod @classmethod
@@ -78,3 +82,35 @@ class TestConverter:


converted_ratio = 100 - (diff_lines * 100) / (len(expect_source)) converted_ratio = 100 - (diff_lines * 100) / (len(expect_source))
assert converted_ratio >= 80 assert converted_ratio >= 80

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_main_graph_based_converter(self, output):
"""Test main graph based converter."""
pytorch_filename = "resnet50.pth"
expected_model_filename = "resnet50.py"
expected_report_filename = "report_of_resnet50.txt"
file_config = {
'model_file': os.path.join(self.pytorch_dir, pytorch_filename),
'shape': (1, 3, 224, 224),
'outfile_dir': output,
'report_dir': output
}
with pytest.raises(ValueError) as e:
main_graph_base_converter(file_config=file_config)

assert os.path.isfile(os.path.join(output, expected_model_filename))
assert os.path.isfile(os.path.join(output, expected_report_filename))

with open(os.path.join(output, expected_report_filename)) as converted_r:
converted_report = converted_r.readlines()
converted_rate = re.findall(r".*(?:Converted Rate: )(.*)[.]", converted_report[-1])

assert converted_rate[0] == '100.00%'

exec_msg = e.value.args[0]
assert exec_msg == "torch.__spec__ is None"

Loading…
Cancel
Save