Browse Source

Check model file format ahead.

tags/v1.1.0
liuchongming 4 years ago
parent
commit
c47e653911
3 changed files with 36 additions and 26 deletions
  1. +8
    -1
      mindinsight/mindconverter/cli.py
  2. +25
    -3
      mindinsight/mindconverter/graph_based_converter/common/utils.py
  3. +3
    -22
      mindinsight/mindconverter/graph_based_converter/framework.py

+ 8
- 1
mindinsight/mindconverter/cli.py View File

@@ -19,7 +19,9 @@ import argparse

import mindinsight
from mindinsight.mindconverter.converter import main
from mindinsight.mindconverter.graph_based_converter.constant import ARGUMENT_LENGTH_LIMIT, EXPECTED_NUMBER
from mindinsight.mindconverter.graph_based_converter.common.utils import get_framework_type
from mindinsight.mindconverter.graph_based_converter.constant import ARGUMENT_LENGTH_LIMIT, EXPECTED_NUMBER, \
FrameworkType
from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter

from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console
@@ -198,6 +200,11 @@ class ModelFileAction(argparse.Action):
if not os.path.isfile(outfile_dir):
parser_in.error(f'{option_string} {outfile_dir} is not a file')

frame_type = get_framework_type(outfile_dir)
if frame_type == FrameworkType.UNKNOWN.value:
parser_in.error(f'{option_string} {outfile_dir} should be an valid '
f'TensorFlow pb or PyTorch pth model file')

setattr(namespace, self.dest, outfile_dir)




+ 25
- 3
mindinsight/mindconverter/graph_based_converter/common/utils.py View File

@@ -18,8 +18,10 @@ import stat
from importlib import import_module
from typing import List, Tuple, Mapping

from mindinsight.mindconverter.common.exceptions import ScriptGenerationError, ReportGenerationError
from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP
from mindinsight.mindconverter.common.exceptions import ScriptGenerationError, ReportGenerationError, UnknownModelError
from mindinsight.mindconverter.common.log import logger as log
from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP, BINARY_HEADER_PYTORCH_BITS, \
FrameworkType, BINARY_HEADER_PYTORCH_FILE, TENSORFLOW_MODEL_SUFFIX


def is_converted(operation: str):
@@ -174,6 +176,7 @@ def get_dict_key_by_value(val, dic):
return d_key
return None


def convert_bytes_string_to_string(bytes_str):
"""
Convert a byte string to string by utf-8.
@@ -186,4 +189,23 @@ def convert_bytes_string_to_string(bytes_str):
"""
if isinstance(bytes_str, bytes):
return bytes_str.decode('utf-8')
return bytes_str
return bytes_str


def get_framework_type(model_path):
"""Get framework type."""
try:
with open(model_path, 'rb') as f:
if f.read(BINARY_HEADER_PYTORCH_BITS) == BINARY_HEADER_PYTORCH_FILE:
framework_type = FrameworkType.PYTORCH.value
elif os.path.basename(model_path).split(".")[-1].lower() == TENSORFLOW_MODEL_SUFFIX:
framework_type = FrameworkType.TENSORFLOW.value
else:
framework_type = FrameworkType.UNKNOWN.value
except IOError:
error_msg = "Get UNSUPPORTED model."
error = UnknownModelError(error_msg)
log.error(str(error))
raise error

return framework_type

+ 3
- 22
mindinsight/mindconverter/graph_based_converter/framework.py View File

@@ -22,9 +22,9 @@ from importlib.util import find_spec

import mindinsight
from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, \
save_code_file_and_report
from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \
BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER, TENSORFLOW_MODEL_SUFFIX
save_code_file_and_report, get_framework_type
from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \
ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER
from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper
from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console
from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCreationError, SourceFilesSaveError, \
@@ -264,25 +264,6 @@ def main_graph_base_converter(file_config):
raise error


def get_framework_type(model_path):
"""Get framework type."""
try:
with open(model_path, 'rb') as f:
if f.read(BINARY_HEADER_PYTORCH_BITS) == BINARY_HEADER_PYTORCH_FILE:
framework_type = FrameworkType.PYTORCH.value
elif os.path.basename(model_path).split(".")[-1].lower() == TENSORFLOW_MODEL_SUFFIX:
framework_type = FrameworkType.TENSORFLOW.value
else:
framework_type = FrameworkType.UNKNOWN.value
except IOError:
error_msg = "Get UNSUPPORTED model."
error = UnknownModelError(error_msg)
log.error(str(error))
raise error

return framework_type


def check_params_exist(params: list, config):
"""Check params exist."""
miss_param_list = ''


Loading…
Cancel
Save