Browse Source

!1032 Modify generated ast report file name and fix error msg when requirments are not installed

From: @liuchongming74
Reviewed-by: @lilongfei15,@lilongfei15,@wangyue01
Signed-off-by: @wangyue01
tags/v1.1.0
mindspore-ci-bot Gitee 4 years ago
parent
commit
672dca9ca4
6 changed files with 40 additions and 28 deletions
  1. +14
    -1
      mindinsight/mindconverter/common/exceptions.py
  2. +2
    -3
      mindinsight/mindconverter/converter.py
  3. +2
    -1
      mindinsight/mindconverter/graph_based_converter/constant.py
  4. +10
    -9
      mindinsight/mindconverter/graph_based_converter/framework.py
  5. +6
    -8
      mindinsight/mindconverter/graph_based_converter/generator/generator.py
  6. +6
    -6
      mindinsight/mindconverter/graph_based_converter/generator/module_struct.py

+ 14
- 1
mindinsight/mindconverter/common/exceptions.py View File

@@ -182,6 +182,7 @@ class BaseConverterError(MindConverterException):
"""Define error code of BaseConverterError."""
UNKNOWN_ERROR = 0
UNKNOWN_MODEL = 1
PARAM_MISSING = 2

BASE_ERROR_CODE = ConverterErrors.BASE_CONVERTER_FAIL.value
ERROR_CODE = ErrCode.UNKNOWN_ERROR.value
@@ -193,7 +194,7 @@ class BaseConverterError(MindConverterException):
@classmethod
def raise_from(cls):
"""Raise from exceptions below."""
except_source = Exception, cls
except_source = Exception, UnknownModelError, ParamMissingError, cls
return except_source


@@ -209,6 +210,18 @@ class UnknownModelError(BaseConverterError):
return cls


class ParamMissingError(BaseConverterError):
"""Define cli params missing error."""
ERROR_CODE = BaseConverterError.ErrCode.PARAM_MISSING.value

def __init__(self, msg):
super(ParamMissingError, self).__init__(msg=msg)

@classmethod
def raise_from(cls):
return cls


class GraphInitError(MindConverterException):
"""The graph init fail error."""



+ 2
- 3
mindinsight/mindconverter/converter.py View File

@@ -77,8 +77,7 @@ class Converter:
raise error
finally:
if self._report:
dest_report_file = os.path.join(report_dir,
'_'.join(os.path.basename(infile).split('.')[:-1]) + '_report.txt')
dest_report_file = os.path.join(report_dir, f"report_of_{os.path.basename(infile).split('.')[0]}.txt")
with os.fdopen(os.open(dest_report_file, self.flags, self.modes), 'a') as file:
file.write('\n'.join(self._report))
logger.info("Convert report is saved in %s", dest_report_file)
@@ -180,7 +179,7 @@ def _path_split(file):
"""
file_dir, name = os.path.split(file)
if file_dir:
sep = file[len(file_dir)-1]
sep = file[len(file_dir) - 1]
if file_dir.startswith(sep):
return file.split(sep)[1:]



+ 2
- 1
mindinsight/mindconverter/graph_based_converter/constant.py View File

@@ -44,7 +44,7 @@ ONNXRUNTIME_MIN_VER = "1.5.2"

BINARY_HEADER_PYTORCH_FILE = \
b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03.\x80\x02}q\x00(X\x10\x00\x00\x00'
TENSORFLOW_MODEL_SUFFIX = "pb"
BINARY_HEADER_PYTORCH_BITS = 32

ARGUMENT_LENGTH_LIMIT = 512
@@ -82,6 +82,7 @@ class InputType(Enum):
class FrameworkType(Enum):
PYTORCH = 0
TENSORFLOW = 1
UNKNOWN = 2


def get_imported_module():


+ 10
- 9
mindinsight/mindconverter/graph_based_converter/framework.py View File

@@ -24,12 +24,11 @@ import mindinsight
from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, \
save_code_file_and_report
from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \
BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER
BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER, TENSORFLOW_MODEL_SUFFIX
from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper
from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console
from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCreationError, SourceFilesSaveError, \
BaseConverterError, UnknownModelError, GeneratorError, TfRuntimeError, RuntimeIntegrityError
from mindinsight.utils.exceptions import ParamMissError
BaseConverterError, UnknownModelError, GeneratorError, TfRuntimeError, RuntimeIntegrityError, ParamMissingError

permissions = os.R_OK | os.W_OK | os.X_OK
os.umask(permissions << 3 | permissions)
@@ -68,7 +67,7 @@ def torch_installation_validation(func):
# Check whether pytorch is installed.
if not find_spec("torch"):
error = RuntimeIntegrityError("PyTorch is required when using graph based "
"scripts converter, and PyTorch vision must "
"scripts converter, and PyTorch version must "
"be consisted with model generation runtime.")
log.error(error)
log_console.error("\n")
@@ -242,6 +241,9 @@ def main_graph_base_converter(file_config):
"""
graph_path = file_config['model_file']
frame_type = get_framework_type(graph_path)
if not file_config.get("shape"):
raise ParamMissingError("Param missing, `--shape` is required when using graph mode.")

if frame_type == FrameworkType.PYTORCH.value:
graph_based_converter_pytorch_to_ms(graph_path=graph_path,
sample_shape=file_config['shape'],
@@ -259,7 +261,6 @@ def main_graph_base_converter(file_config):
else:
error_msg = "Get UNSUPPORTED model."
error = UnknownModelError(error_msg)
log.error(str(error))
raise error


@@ -269,8 +270,10 @@ def get_framework_type(model_path):
with open(model_path, 'rb') as f:
if f.read(BINARY_HEADER_PYTORCH_BITS) == BINARY_HEADER_PYTORCH_FILE:
framework_type = FrameworkType.PYTORCH.value
else:
elif os.path.basename(model_path).split(".")[-1].lower() == TENSORFLOW_MODEL_SUFFIX:
framework_type = FrameworkType.TENSORFLOW.value
else:
framework_type = FrameworkType.UNKNOWN.value
except IOError:
error_msg = "Get UNSUPPORTED model."
error = UnknownModelError(error_msg)
@@ -288,6 +291,4 @@ def check_params_exist(params: list, config):
miss_param_list = ', '.join((miss_param_list, param)) if miss_param_list else param

if miss_param_list:
error = ParamMissError(miss_param_list)
log.error(str(error))
raise error
raise ParamMissingError(f"Param(s) missing, {miss_param_list} is(are) required when using graph mode.")

+ 6
- 8
mindinsight/mindconverter/graph_based_converter/generator/generator.py View File

@@ -220,9 +220,9 @@ class Generator:
path = Scope.path_str_to_list(scope_path)
if len(path) < depth_control:
continue
else: # depth control within path length
module_num = path[depth_control - 1][0]
repeated_submodules_at_this_depth.add(module_num)
# depth control within path length.
module_num = path[depth_control - 1][0]
repeated_submodules_at_this_depth.add(module_num)
ret[depth_control] = repeated_submodules_at_this_depth

self._repeated_submodules = ret
@@ -249,9 +249,8 @@ class Generator:
compared_value = nd_struct.fragment.actual_args.get(base_parameter)
if compared_value == base_value:
continue
else:
formal_args.add(base_parameter)
break
formal_args.add(base_parameter)
break

return formal_args

@@ -310,8 +309,7 @@ class Generator:
for module_num in module_nums:
if module_num in checked_module: # module already checked
continue
else:
checked_module.add(module_num)
checked_module.add(module_num)
map_filtered = self.module_map_filter(module_num=module_num)
formal_args_in_this_module = self._list_formal_parameters_in_a_module(map_filtered)
formal_args_in_each_submodule[module_num] = formal_args_in_this_module


+ 6
- 6
mindinsight/mindconverter/graph_based_converter/generator/module_struct.py View File

@@ -679,12 +679,12 @@ class ModuleStruct:
if submodule_precursor in self.onnx_names: # if internal, match with local nodes/submodules return
# but do nothing here
continue
else: # if external, match with current module construct header x
if submodule_precursor in self.construct_header_x.values():
local_x = get_dict_key_by_value(submodule_precursor, self.construct_header_x)
md_struct.set_inputs_in_construct_header(local_x, submodule_precursor)
else: # Extra precursor nodes, raise error
raise ValueError("Found external inputs of the submodule but the module does not have it.")
# if external, match with current module construct header x
if submodule_precursor in self.construct_header_x.values():
local_x = get_dict_key_by_value(submodule_precursor, self.construct_header_x)
md_struct.set_inputs_in_construct_header(local_x, submodule_precursor)
else: # Extra precursor nodes, raise error
raise ValueError("Found external inputs of the submodule but the module does not have it.")

def register_node_output_to_module(self, nd_struct):
"""Register nodes outputs to this module's return."""


Loading…
Cancel
Save