Browse Source

Add cli for tfToMs

tags/v1.1.0
moran 5 years ago
parent
commit
969abd4cf7
13 changed files with 294 additions and 54 deletions
  1. +69
    -8
      mindinsight/mindconverter/cli.py
  2. +10
    -0
      mindinsight/mindconverter/common/exceptions.py
  3. +3
    -2
      mindinsight/mindconverter/graph_based_converter/__init__.py
  4. +15
    -0
      mindinsight/mindconverter/graph_based_converter/constant.py
  5. +139
    -19
      mindinsight/mindconverter/graph_based_converter/framework.py
  6. +1
    -0
      mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py
  7. +5
    -3
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py
  8. +4
    -3
      mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py
  9. +1
    -7
      mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py
  10. +0
    -2
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py
  11. +8
    -7
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph_node.py
  12. +2
    -2
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py
  13. +37
    -1
      tests/st/func/mindconverter/test_converter.py

+ 69
- 8
mindinsight/mindconverter/cli.py View File

@@ -19,6 +19,7 @@ import argparse

import mindinsight
from mindinsight.mindconverter.converter import main
from mindinsight.mindconverter.graph_based_converter.constant import ARGUMENT_LENGTH_LIMIT, EXPECTED_SHAPE_NUMBER
from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter

from mindinsight.mindconverter.common.log import logger as log
@@ -38,6 +39,11 @@ class FileDirAction(argparse.Action):
option_string (str): Optional string for specific argument name. Default: None.
"""
outfile = values

if len(outfile) > ARGUMENT_LENGTH_LIMIT:
parser_in.error(
f"The length of {option_string}{outfile} should be no more than {ARGUMENT_LENGTH_LIMIT}.")

if outfile.startswith('~'):
outfile = os.path.realpath(os.path.expanduser(outfile))

@@ -160,9 +166,6 @@ class ModelFileAction(argparse.Action):
if not os.path.isfile(outfile_dir):
parser_in.error(f'{option_string} {outfile_dir} is not a file')

if not outfile_dir.endswith('.pth'):
parser_in.error(f"{option_string} {outfile_dir} should be a Pytorch model, ending with '.pth'.")

setattr(namespace, self.dest, outfile_dir)


@@ -200,14 +203,42 @@ class ShapeAction(argparse.Action):
"""
in_shape = None
shape_str = values

shape_list = shape_str.split(';')
if not len(shape_list) == EXPECTED_SHAPE_NUMBER:
parser_in.error(f"Only support one shape now, but get {len(shape_list)}.")

try:
in_shape = [int(num_shape) for num_shape in shape_str.split(',')]
in_shape = [int(num_shape) for num_shape in shape_list[0].split(',')]
except ValueError:
parser_in.error(
f"{option_string} {shape_str} should be a list of integer split by ',', check it please.")
setattr(namespace, self.dest, in_shape)


class NodeAction(argparse.Action):
"""Node action class definition."""

def __call__(self, parser_in, namespace, values, option_string=None):
"""
Inherited __call__ method from FileDirAction.

Args:
parser_in (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None.

"""
node_str = values
if len(node_str) > ARGUMENT_LENGTH_LIMIT:
parser_in.error(
f"The length of {option_string}{node_str} should be no more than {ARGUMENT_LENGTH_LIMIT}."
)

setattr(namespace, self.dest, node_str)


parser = argparse.ArgumentParser(
prog='mindconverter',
description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__))
@@ -234,7 +265,7 @@ parser.add_argument(
action=ModelFileAction,
required=False,
help="""
PyTorch .pth model file path to use graph
PyTorch .pth or Tensorflow .pb model file path to use graph
based schema to do script generation. When
`--in_file` and `--model_file` are both provided,
use AST schema as default.
@@ -250,7 +281,29 @@ parser.add_argument(
Optional, expected input tensor shape of
`--model_file`. It's required when use graph based
schema.
Usage: --shape 3,244,244
Usage: --shape 1,3,244,244
""")

parser.add_argument(
'--input_nodes',
type=str,
action=NodeAction,
default=None,
required=False,
help="""
Optional, input node(s) name of `--model_file`. It's required when use Tensorflow model.
Usage: --input_nodes input_1:0,input_2:0
""")

parser.add_argument(
'--output_nodes',
type=str,
action=NodeAction,
default=None,
required=False,
help="""
Optional, output node(s) name of `--model_file`. It's required when use Tensorflow model.
Usage: --output_nodes output_1:0,output_2:0
""")

parser.add_argument(
@@ -305,10 +358,14 @@ def cli_entry():
if args.report is None:
args.report = args.output
os.makedirs(args.report, mode=mode, exist_ok=True)
_run(args.in_file, args.model_file, args.shape, args.output, args.report, args.project_path)
_run(args.in_file, args.model_file,
args.shape,
args.input_nodes, args.output_nodes,
args.output, args.report,
args.project_path)


def _run(in_files, model_file, shape, out_dir, report, project_path):
def _run(in_files, model_file, shape, input_nodes, output_nodes, out_dir, report, project_path):
"""
Run converter command.

@@ -316,6 +373,8 @@ def _run(in_files, model_file, shape, out_dir, report, project_path):
in_files (str): The file path or directory to convert.
model_file(str): The pytorch .pth to convert on graph based schema.
shape(list): The input tensor shape of module_file.
input_nodes(str): The input node(s) name of Tensorflow model, split by ','.
output_nodes(str): The output node(s) name of Tensorflow model, split by ','.
out_dir (str): The output directory to save converted file.
report (str): The report file path.
project_path(str): Pytorch scripts project path.
@@ -341,6 +400,8 @@ def _run(in_files, model_file, shape, out_dir, report, project_path):
file_config = {
'model_file': model_file,
'shape': shape if shape else [],
'input_nodes': input_nodes,
'output_nodes': output_nodes,
'outfile_dir': out_dir,
'report_dir': report if report else out_dir
}


+ 10
- 0
mindinsight/mindconverter/common/exceptions.py View File

@@ -26,6 +26,7 @@ class ConverterErrors(ScriptConverterErrors):
NODE_TYPE_NOT_SUPPORT = 2
CODE_SYNTAX_ERROR = 3
NODE_INPUT_TYPE_NOT_SUPPORT = 4
UNKNOWN_MODEL = 5


class ScriptNotSupport(MindInsightException):
@@ -62,3 +63,12 @@ class NodeInputTypeNotSupport(MindInsightException):
super(NodeInputTypeNotSupport, self).__init__(ConverterErrors.NODE_INPUT_TYPE_NOT_SUPPORT,
msg,
http_code=400)


class UnknownModel(MindInsightException):
"""The unknown model error."""

def __init__(self, msg):
super(UnknownModel, self).__init__(ConverterErrors.UNKNOWN_MODEL,
msg,
http_code=400)

+ 3
- 2
mindinsight/mindconverter/graph_based_converter/__init__.py View File

@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Graph based scripts converter definition."""
from .framework import graph_based_converter
from .framework import graph_based_converter_pytorch_to_ms
from .framework import graph_based_converter_tf_to_ms

__all__ = ["graph_based_converter"]
__all__ = ["graph_based_converter_pytorch_to_ms", "graph_based_converter_tf_to_ms"]

+ 15
- 0
mindinsight/mindconverter/graph_based_converter/constant.py View File

@@ -35,6 +35,15 @@ ONNX_TYPE_FLOAT = 1
ONNX_TYPE_FLOATS = 6
ONNX_TYPE_STRING = 3

BINARY_HEADER_PYTORCH_FILE = \
b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03.\x80\x02}q\x00(X\x10\x00\x00\x00'

BINARY_HEADER_PYTORCH_BITS = 32

ARGUMENT_LENGTH_LIMIT = 512

EXPECTED_SHAPE_NUMBER = 1


@unique
class CodeFormatConfig(Enum):
@@ -54,3 +63,9 @@ class NodeType(Enum):
class InputType(Enum):
TENSOR = "tensor"
LIST = "list"


@unique
class FrameworkType(Enum):
PYTORCH = 0
TENSORFLOW = 1

+ 139
- 19
mindinsight/mindconverter/graph_based_converter/framework.py View File

@@ -16,12 +16,16 @@
import os
import re
import argparse
from importlib import import_module
from importlib.util import find_spec

import mindinsight
from mindinsight.mindconverter.common.log import logger as log
from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \
BINARY_HEADER_PYTORCH_BITS
from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper
from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport
from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport, UnknownModel
from mindinsight.utils.exceptions import ParamMissError

permissions = os.R_OK | os.W_OK | os.X_OK
os.umask(permissions << 3 | permissions)
@@ -56,8 +60,7 @@ def torch_installation_validation(func):
"""

def _f(graph_path: str, sample_shape: tuple,
output_folder: str, report_folder: str = None,
checkpoint_path: str = None):
output_folder: str, report_folder: str = None):
# Check whether pytorch is installed.
if not find_spec("torch"):
error = ModuleNotFoundError("PyTorch is required when using graph based "
@@ -67,9 +70,35 @@ def torch_installation_validation(func):
log.exception(error)
raise error

func(graph_path=graph_path, sample_shape=sample_shape,
output_folder=output_folder, report_folder=report_folder)

return _f


def tf_installation_validation(func):
"""
Validate args of func.

Args:
func(type): Function.

Returns:
type, inner function.
"""

def _f(graph_path: str, sample_shape: tuple,
output_folder: str, report_folder: str = None,
input_nodes: str = None, output_nodes: str = None):
# Check whether tensorflow is installed.
if not find_spec("tensorflow") or not find_spec("tf2onnx"):
error = ModuleNotFoundError("Tensorflow and tf2onnx are required when using "
"graph based scripts converter.")
log.error(str(error))
raise error
func(graph_path=graph_path, sample_shape=sample_shape,
output_folder=output_folder, report_folder=report_folder,
checkpoint_path=checkpoint_path)
input_nodes=input_nodes, output_nodes=output_nodes)

return _f

@@ -85,32 +114,33 @@ def _extract_model_name(model_path):
str: Name of Converted model.
"""

model_name = re.findall(r".*[/](.*).pth", model_path)[-1]
model_name = re.findall(r".*[/](.*)(?:\.pth|\.pb)", model_path)[-1]
return model_name


@torch_installation_validation
def graph_based_converter(graph_path: str, sample_shape: tuple,
output_folder: str, report_folder: str = None,
checkpoint_path: str = None):
def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple,
output_folder: str, report_folder: str = None):
"""
Graph based scripts converter.
Pytoch to MindSpore based on Graph.

Args:
graph_path (str): Graph file path.
sample_shape (tuple): Input shape of the model.
output_folder (str): Output folder.
report_folder (str): Report output folder path.
checkpoint_path (str): Checkpoint file path.

"""
from .third_party_graph import GraphFactory
from .hierarchical_tree import HierarchicalTreeFactory
third_party_graph_module = import_module(
'mindinsight.mindconverter.graph_based_converter.third_party_graph')
hierarchical_tree_module = import_module(
'mindinsight.mindconverter.graph_based_converter.hierarchical_tree')
cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory')
cls_hierarchical_tree_factory = getattr(hierarchical_tree_module, 'HierarchicalTreeFactory')

graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape,
checkpoint=checkpoint_path)
graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape)
try:
hierarchical_tree = HierarchicalTreeFactory.create(graph_obj)
hierarchical_tree = cls_hierarchical_tree_factory.create(graph_obj)
except Exception as e:
log.exception(e)
log.error("Error occur when create hierarchical tree.")
@@ -123,6 +153,49 @@ def graph_based_converter(graph_path: str, sample_shape: tuple,
report_folder=report_folder)


@tf_installation_validation
def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple,
input_nodes: str, output_nodes: str,
output_folder: str, report_folder: str = None):
"""
Tensorflow to MindSpore based on Graph.

Args:
graph_path(str): Graph file path.
sample_shape(tuple): Input shape of the model.
input_nodes(str): Input node(s) of the model.
output_nodes(str): Output node(s) of the model.
output_folder(str): Output folder.
report_folder(str): Report output folder path.

"""
third_party_graph_module = import_module(
'mindinsight.mindconverter.graph_based_converter.third_party_graph')
hierarchical_tree_module = import_module(
'mindinsight.mindconverter.graph_based_converter.hierarchical_tree')
cls_graph_factory = getattr(third_party_graph_module, 'GraphFactory')
cls_hierarchical_tree_factory = getattr(hierarchical_tree_module, 'HierarchicalTreeFactory')
# Close unnecessary log.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

graph_obj = cls_graph_factory.init(graph_path, sample_shape=sample_shape,
input_nodes=input_nodes, output_nodes=output_nodes)

try:
hierarchical_tree, scope_name_map = cls_hierarchical_tree_factory.create(graph_obj)
except Exception as e:
log.exception(e)
log.error("Error occur when create hierarchical tree.")
raise NodeTypeNotSupport("This model is not supported now.")

model_name = _extract_model_name(graph_path)

hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper,
model_name=model_name,
report_folder=report_folder,
scope_name_map=scope_name_map)


def main_graph_base_converter(file_config):
"""
The entrance for converter, script files will be converted.
@@ -130,7 +203,54 @@ def main_graph_base_converter(file_config):
Args:
file_config (dict): The config of file which to convert.
"""
graph_based_converter(graph_path=file_config['model_file'],
sample_shape=file_config['shape'],
output_folder=file_config['outfile_dir'],
report_folder=file_config['report_dir'])
graph_path = file_config['model_file']
frame_type = get_framework_type(graph_path)
if frame_type == FrameworkType.PYTORCH.value:
graph_based_converter_pytorch_to_ms(graph_path=graph_path,
sample_shape=file_config['shape'],
output_folder=file_config['outfile_dir'],
report_folder=file_config['report_dir'])
elif frame_type == FrameworkType.TENSORFLOW.value:
check_params = ['input_nodes', 'output_nodes']
check_params_exist(check_params, file_config)
graph_based_converter_tf_to_ms(graph_path=graph_path,
sample_shape=file_config['shape'],
input_nodes=file_config['input_nodes'],
output_nodes=file_config['output_nodes'],
output_folder=file_config['outfile_dir'],
report_folder=file_config['report_dir'])
else:
error_msg = "Get UNSUPPORTED model."
error = UnknownModel(error_msg)
log.error(str(error))
log.exception(error)
raise error


def get_framework_type(model_path):
"""Get framework type."""
try:
with open(model_path, 'rb') as f:
if f.read(BINARY_HEADER_PYTORCH_BITS) == BINARY_HEADER_PYTORCH_FILE:
framework_type = FrameworkType.PYTORCH.value
else:
framework_type = FrameworkType.TENSORFLOW.value
except IOError:
error_msg = "Get UNSUPPORTED model."
error = UnknownModel(error_msg)
log.error(str(error))
log.exception(error)
raise error

return framework_type


def check_params_exist(params: list, config):
"""Check params exist."""
miss_param_list = ''
for param in params:
if not config.get(param) or not config[param]:
miss_param_list = ', '.join((miss_param_list, param)) if miss_param_list else param

if miss_param_list:
raise ParamMissError(miss_param_list)

+ 1
- 0
mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py View File

@@ -182,6 +182,7 @@ class HierarchicalTree(Tree):
mapper (Mapper): Mapper of third party framework and mindspore.
model_name(str): Name of Converted model.
out_folder (str): Output folder.
scope_name_map(str): Scope name map of tensorflow.

"""
if scope_name_map:


+ 5
- 3
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py View File

@@ -117,12 +117,14 @@ class ConvMapper(ONNXToMindSporeMapper):

@staticmethod
def _operation_name_in_ms(*args, **kwargs):
if not kwargs['weights'].get('weight'): # is from tf
weight = kwargs['weights'].get('weight', 'empty')

if weight == 'empty': # is from tf
kernel_size = kwargs['params'].get('kernel_shape')
dim = len(kernel_size)
return f"nn.Conv{dim}d"

weight = kwargs['weights']['weight'].numpy()
weight = weight.numpy()
dim = weight.ndim - 2
return f"nn.Conv{dim}d"

@@ -131,7 +133,7 @@ class ConvMapper(ONNXToMindSporeMapper):
weights = kwargs['weights']
params = kwargs['params']

if not weights.get('weight'): # is from tf
if weights.get('weight', 'empty') == 'empty': # is from tf
return ConvMapper.convert_params_tf(params=params, weights=weights)
return ConvMapper.convert_params_torch(params=params, weights=weights)



+ 4
- 3
mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py View File

@@ -25,15 +25,16 @@ class GraphFactory:

@classmethod
def init(cls, graph_path: str,
input_nodes: str, output_nodes: str,
sample_shape: tuple):
sample_shape: tuple,
input_nodes: str = None, output_nodes: str = None):
"""
Init an instance of graph.

Args:
graph_path (str): Graph or model file path.
sample_shape (tuple): Input shape of the model.
checkpoint (str): Checkpoint file path.
input_nodes(str): Input nodes.
output_nodes(str): Output nodes.

Returns:
Graph, graph instance.


+ 1
- 7
mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py View File

@@ -246,17 +246,11 @@ class Graph(BaseGraph, abc.ABC):
model_path (str): Graph or model file path.
sample_shape (tuple): Input shape of the model.
checkpoint (str): Checkpoint file path.
input_nodes (list[str]): list of input nodes' name
output_nodes (list[str]): list of output nodes' name

Returns:
cls, graph instance.
"""
tf_input_nodes = kwargs.get('input_nodes')
tf_output_nodes = kwargs.get('output_nodes')
src_graph = cls.load_graph(graph_path=model_path,
input_nodes=tf_input_nodes,
output_nodes=tf_output_nodes)
src_graph = cls.load_graph(graph_path=model_path, **kwargs)
ckpt = cls.load_checkpoint(
ckpt_path=checkpoint) if checkpoint else None



+ 0
- 2
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py View File

@@ -193,8 +193,6 @@ class OnnxGraph(Graph):

Args:
graph_path (str): Graph path.
tf_input_nodes (str): input nodes of tf graph
tf_output_nodes (str): output nodes of tf graph

Returns:
object, ONNX model.


+ 8
- 7
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph_node.py View File

@@ -18,8 +18,9 @@ from copy import deepcopy
from .base import GraphNode

from ..constant import NodeType, SEPARATOR_IN_SCOPE, \
SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, SEPARATOR_IN_ONNX_OP
SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, SEPARATOR_IN_ONNX_OP, InputType
from ..mapper.base import Mapper
from ...common.exceptions import NodeInputTypeNotSupport


class OnnxGraphNode(GraphNode):
@@ -160,10 +161,10 @@ class OnnxGraphNode(GraphNode):
Args:
op_name (str): Add the tensor to args if the current node has this
op_name.
t_identifier (str): The unique strinf appeared in the target tensor
t_identifier (str): The unique string appeared in the target tensor
name.
declare_s (str): Declare statement generated in to_code().
init_s (str): init statement generated in to_code().
declare (str): Declare statement generated in to_code().
args (str): Args statement generated in to_code().

Returns:
declare_list list, multiple declare statements.
@@ -226,9 +227,9 @@ class OnnxGraphNode(GraphNode):
declare, ipt_args_settings_in_construct = self._add_tensor_args_to_code(
'onnx::MatMul', 'MatMul', declare, ipt_args_settings_in_construct)

# Extra Tensor generator for onnx::BiasAdd
# Extra Tensor generator for onnx::Add
declare, ipt_args_settings_in_construct = self._add_tensor_args_to_code(
'onnx::MatMul', 'BiasAdd', declare, ipt_args_settings_in_construct)
'onnx::Add', 'BiasAdd', declare, ipt_args_settings_in_construct)

call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_settings_in_construct})"

@@ -320,7 +321,7 @@ class OnnxGraphNode(GraphNode):

def param_transform(self, mapper: Mapper):
"""
Transform torch params into mindspore.
Transform tensorflow params into mindspore.

Args:
mapper (Mapper): Mapper of params.


+ 2
- 2
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py View File

@@ -173,7 +173,7 @@ class PyTorchGraph(Graph):
"""
self._check_input_shape(input_shape)

feed_forward_ipt_shape = (1, *input_shape)
feed_forward_ipt_shape = tuple(input_shape)
graph = self._trace_torch_graph(feed_forward_ipt_shape)
nodes = list(graph.nodes())

@@ -283,7 +283,7 @@ class PyTorchGraph(Graph):
raise NotImplementedError(err_msg)

@staticmethod
def load_graph(graph_path: str):
def load_graph(graph_path: str, **kwargs):
"""
Load graph.



+ 37
- 1
tests/st/func/mindconverter/test_converter.py View File

@@ -21,11 +21,13 @@ Usage:
"""
import difflib
import os
import re
import sys

import pytest

from mindinsight.mindconverter.converter import main
from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter


@pytest.mark.usefixtures('create_output_dir')
@@ -36,7 +38,9 @@ class TestConverter:
def setup_class(cls):
"""Setup method."""
cls.script_dir = os.path.join(os.path.dirname(__file__), 'data')
cls.pytorch_dir = '/home/test/mindinsight_sample'
pytorch_base_dir = os.path.dirname(__file__).split('/')[:3]
cls.pytorch_dir = \
'/'.join(pytorch_base_dir + ['share-data', 'dataset', 'mindinsight_dataset', 'resnet50'])
sys.path.insert(0, cls.script_dir)

@classmethod
@@ -78,3 +82,35 @@ class TestConverter:

converted_ratio = 100 - (diff_lines * 100) / (len(expect_source))
assert converted_ratio >= 80

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_main_graph_based_converter(self, output):
"""Test main graph based converter."""
pytorch_filename = "resnet50.pth"
expected_model_filename = "resnet50.py"
expected_report_filename = "report_of_resnet50.txt"
file_config = {
'model_file': os.path.join(self.pytorch_dir, pytorch_filename),
'shape': (1, 3, 224, 224),
'outfile_dir': output,
'report_dir': output
}
with pytest.raises(ValueError) as e:
main_graph_base_converter(file_config=file_config)

assert os.path.isfile(os.path.join(output, expected_model_filename))
assert os.path.isfile(os.path.join(output, expected_report_filename))

with open(os.path.join(output, expected_report_filename)) as converted_r:
converted_report = converted_r.readlines()
converted_rate = re.findall(r".*(?:Converted Rate: )(.*)[.]", converted_report[-1])

assert converted_rate[0] == '100.00%'

exec_msg = e.value.args[0]
assert exec_msg == "torch.__spec__ is None"

Loading…
Cancel
Save