Browse Source

Pre Merge pull request !1334 from 刘崇鸣/ud_code_struct

pull/1334/MERGE
刘崇鸣 Gitee 4 years ago
parent
commit
2a1ef885aa
10 changed files with 204 additions and 75 deletions
  1. +2
    -2
      mindinsight/mindconverter/__init__.py
  2. +80
    -12
      mindinsight/mindconverter/cli.py
  3. +16
    -14
      mindinsight/mindconverter/common/exceptions.py
  4. +1
    -1
      mindinsight/mindconverter/docs/error_code_definition.md
  5. +1
    -1
      mindinsight/mindconverter/docs/error_code_definition_cn.md
  6. +56
    -12
      mindinsight/mindconverter/graph_based_converter/common/utils.py
  7. +7
    -2
      mindinsight/mindconverter/graph_based_converter/constant.py
  8. +30
    -18
      mindinsight/mindconverter/graph_based_converter/framework.py
  9. +1
    -1
      mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py
  10. +10
    -12
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py

+ 2
- 2
mindinsight/mindconverter/__init__.py View File

@@ -19,7 +19,7 @@ MindConverter is a migration tool to transform the model scripts from PyTorch to
Users can migrate their PyTorch models to Mindspore rapidly with minor changes according to the conversion report.
"""

__all__ = ["user_defined_pattern", "main_entry"]
__all__ = ["user_defined_pattern", "convert", "query_graph"]

from mindinsight.mindconverter.cli import run as main_entry
from mindinsight.mindconverter.cli import convert, query_graph
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import user_defined_pattern

+ 80
- 12
mindinsight/mindconverter/cli.py View File

@@ -387,22 +387,22 @@ def cli_entry():
if args.report is None:
args.report = args.output
os.makedirs(args.report, mode=mode, exist_ok=True)
run(args.in_file, args.model_file,
args.shape,
args.input_nodes, args.output_nodes,
args.output, args.report)
_run(args.in_file, args.model_file,
args.shape,
args.input_nodes, args.output_nodes,
args.output, args.report)


def run(in_files, model_file, shape, input_nodes, output_nodes, out_dir, report):
def _run(in_files, model_file, shape, input_nodes, output_nodes, out_dir, report):
"""
Run converter command.

Args:
in_files (str): The file path or directory to convert.
model_file(str): The model to convert on graph based schema.
shape(list): The input tensor shape of module_file.
input_nodes(str): The input node(s) name of model.
output_nodes(str): The output node(s) name of model.
model_file (str): The model to convert on graph based schema.
shape (Sequence[tuple]): The input tensor shape of the model.
input_nodes (Sequence[str]): The input node(s) name of model.
output_nodes (Sequence[str]): The output node(s) name of model.
out_dir (str): The output directory to save converted file.
report (str): The report file path.
"""
@@ -413,7 +413,6 @@ def run(in_files, model_file, shape, input_nodes, output_nodes, out_dir, report)
'outfile_dir': out_dir,
'report_dir': report if report else out_dir
}

if os.path.isfile(in_files):
files_config['root_path'] = os.path.dirname(in_files)
files_config['in_files'] = [in_files]
@@ -433,12 +432,81 @@ def run(in_files, model_file, shape, input_nodes, output_nodes, out_dir, report)
'outfile_dir': out_dir,
'report_dir': report if report else out_dir
}

main_graph_base_converter(file_config)
log_console.info("MindConverter: conversion is completed.")
else:
error_msg = "`--in_file` and `--model_file` should be set at least one."
error = FileNotFoundError(error_msg)
log.error(str(error))
log_console.error(f"mindconverter: error: {str(error)}")
log_console.error(str(error))
sys.exit(-1)


def convert(model_file, shape, input_nodes, output_nodes, out_dir: str = "output",
report: str = None):
"""
Run converter command.

Args:
model_file (str): The model to convert on graph based schema.
shape (Sequence[tuple]): The input tensor shape of the model.
input_nodes (Sequence[str]): The input node(s) name of model.
output_nodes (Sequence[str]): The output node(s) name of model.
out_dir (str): The output directory to save converted file.
report (str): The report file path.

Examples:
>>> from mindinsight.mindconverter import convert
>>> model_file = "resnet50.onnx"
>>> input_nodes = ["img"]
>>> shape = [(1, 3, 224, 224)]
>>> output_nodes = ["logits"]
>>> convert(model_file, shape, input_nodes, output_nodes)
"""
file_config = {
'model_file': model_file,
'shape': shape if shape else [],
'input_nodes': input_nodes,
'output_nodes': output_nodes,
'outfile_dir': out_dir,
'report_dir': report if report else out_dir
}
main_graph_base_converter(file_config)
log_console.info("MindConverter: conversion is completed.")


def query_graph(model_file, shape, input_nodes, output_nodes, query_result_folder):
"""
Run converter command.

Args:
model_file (str): The model to convert on graph based schema.
shape (Sequence[tuple]): The input tensor shape of the model.
input_nodes (Sequence[str]): The input node(s) name of model.
output_nodes (Sequence[str]): The output node(s) name of model.
query_result_folder (str): Save the optimized graph and its topological order to disk.

Examples:
>>> from mindinsight.mindconverter import query_graph
>>> model_file = "resnet50.onnx"
>>> input_nodes = ["img"]
>>> shape = [(1, 3, 224, 224)]
>>> output_nodes = ["logits"]
>>> query_result_folder = "result"
>>> query_graph(model_file, shape, input_nodes, output_nodes, query_result_folder)
"""
if not query_result_folder:
err_msg = "`query_result_folder` is required, when query the optimized graph and its topological order."
log.error(err_msg)
raise ValueError(err_msg)
file_config = {
"model_file": model_file,
"shape": shape if shape else [],
"input_nodes": input_nodes,
"output_nodes": output_nodes,
"outfile_dir": "",
"report_dir": "",
"query_result_folder": query_result_folder
}
main_graph_base_converter(file_config)
log_console.info("MindConverter: query is completed.")

+ 16
- 14
mindinsight/mindconverter/common/exceptions.py View File

@@ -271,7 +271,7 @@ class GraphInitError(MindConverterException):
return except_source


class SourceFilesSaveError(MindConverterException):
class FileSaveError(MindConverterException):
"""The source files save fail error."""

@unique
@@ -290,7 +290,7 @@ class SourceFilesSaveError(MindConverterException):
DEFAULT_MSG = "Error occurred when save source files."

def __init__(self, msg=DEFAULT_MSG):
super(SourceFilesSaveError, self).__init__(user_msg=msg)
super(FileSaveError, self).__init__(user_msg=msg)

@classmethod
def raise_from(cls):
@@ -366,9 +366,9 @@ class RuntimeIntegrityError(GraphInitError):
return RuntimeError, AttributeError, ImportError, ModuleNotFoundError, cls


class NodeInputTypeNotSupportError(SourceFilesSaveError):
class NodeInputTypeNotSupportError(FileSaveError):
"""The node input type NOT support error."""
ERROR_CODE = SourceFilesSaveError.ErrCode.NODE_INPUT_TYPE_NOT_SUPPORT.value
ERROR_CODE = FileSaveError.ErrCode.NODE_INPUT_TYPE_NOT_SUPPORT.value

def __init__(self, msg):
super(NodeInputTypeNotSupportError, self).__init__(msg=msg)
@@ -378,9 +378,9 @@ class NodeInputTypeNotSupportError(SourceFilesSaveError):
return ValueError, TypeError, IndexError, cls


class ScriptGenerationError(SourceFilesSaveError):
class ScriptGenerationError(FileSaveError):
"""The script generate fail error."""
ERROR_CODE = SourceFilesSaveError.ErrCode.SCRIPT_GENERATE_FAIL.value
ERROR_CODE = FileSaveError.ErrCode.SCRIPT_GENERATE_FAIL.value

def __init__(self, msg):
super(ScriptGenerationError, self).__init__(msg=msg)
@@ -394,9 +394,9 @@ class ScriptGenerationError(SourceFilesSaveError):
return except_source


class ReportGenerationError(SourceFilesSaveError):
class ReportGenerationError(FileSaveError):
"""The report generate fail error."""
ERROR_CODE = SourceFilesSaveError.ErrCode.REPORT_GENERATE_FAIL.value
ERROR_CODE = FileSaveError.ErrCode.REPORT_GENERATE_FAIL.value

def __init__(self, msg):
super(ReportGenerationError, self).__init__(msg=msg)
@@ -407,9 +407,9 @@ class ReportGenerationError(SourceFilesSaveError):
return ZeroDivisionError, cls


class CheckPointGenerationError(SourceFilesSaveError):
class CheckPointGenerationError(FileSaveError):
"""The checkpoint generate fail error."""
ERROR_CODE = SourceFilesSaveError.ErrCode.CKPT_GENERATE_FAIL.value
ERROR_CODE = FileSaveError.ErrCode.CKPT_GENERATE_FAIL.value

def __init__(self, msg):
super(CheckPointGenerationError, self).__init__(msg=msg)
@@ -420,9 +420,9 @@ class CheckPointGenerationError(SourceFilesSaveError):
return cls


class WeightMapGenerationError(SourceFilesSaveError):
class WeightMapGenerationError(FileSaveError):
"""The weight names map generate fail error."""
ERROR_CODE = SourceFilesSaveError.ErrCode.MAP_GENERATE_FAIL.value
ERROR_CODE = FileSaveError.ErrCode.MAP_GENERATE_FAIL.value

def __init__(self, msg):
super(WeightMapGenerationError, self).__init__(msg=msg)
@@ -432,9 +432,10 @@ class WeightMapGenerationError(SourceFilesSaveError):
"""Raise from exception below."""
return cls

class OnnxModelSaveError(SourceFilesSaveError):

class OnnxModelSaveError(FileSaveError):
"""The onnx model save fail error."""
ERROR_CODE = SourceFilesSaveError.ErrCode.MODEL_SAVE_FAIL.value
ERROR_CODE = FileSaveError.ErrCode.MODEL_SAVE_FAIL.value

def __init__(self, msg):
super(OnnxModelSaveError, self).__init__(msg=msg)
@@ -444,6 +445,7 @@ class OnnxModelSaveError(SourceFilesSaveError):
"""Raise from exception below."""
return cls


class SubGraphSearchingError(MindConverterException):
"""Sub-graph searching exception."""



+ 1
- 1
mindinsight/mindconverter/docs/error_code_definition.md View File

@@ -13,7 +13,7 @@
| ModelLoadingError | Fail to load the model | 1000001 | Given `--input_nodes`, `--output_nodes`, `--shape` don't match the input model; Meanwhile, the model file can not be loaded also can cause this error |
| TfRuntimeError | Fail to initialize the TF runtime | 1000002 | Resources required by TensorFlow are not available |
| RuntimeIntegrityError | Fail to locate required third party dependency | 1000003 | Caused by required third party packages are not installed |
| SourceFilesSaveError | Fail to generate or save converted script | 2000000 | Exception caused by 2000001~2000005 |
| FileSaveError | Fail to generate or save converted script | 2000000 | Exception caused by 2000001~2000005 |
| NodeInputTypeNotSupportError | Fail to recognize the input type of converted operator | 2000001 | Wrong input type set in mapper |
| ScriptGenerationError | Fail to generate converted script | 2000002 | No left space on hard disk; Converted code is not legal; A file with the same name already exists in `--output` |
| ReportGenerationError | Fail to generate converted script | 2000003 | No left space on hard disk; No available operator to be converted;A file with the same name already exists in `--report` |


+ 1
- 1
mindinsight/mindconverter/docs/error_code_definition_cn.md View File

@@ -13,7 +13,7 @@
| ModelLoadingError | 模型加载失败 | 1000001 | 给定的`--input_nodes`, `--output_nodes`, `--shape`与实际模型不符;<br />或模型文件存在问题导致模型无法加载 |
| TfRuntimeError | TensorFlow库执行出错 | 1000002 | TensorFlow启动申请所需资源失败导致无法正常启动,<br />请检查系统资源(进程数、内存、显存占用、CPU占用)是否充足 |
| RuntimeIntegrityError | 三方依赖库不完整 | 1000003 | MindConverter运行时所需的三方依赖库未安装 | |
| SourceFilesSaveError | 生成和保存转换后的脚本文件失败 | 2000000 | 由200000至2000005导致的脚本生成保存失败 |
| FileSaveError | 生成和保存转换后的脚本文件失败 | 2000000 | 由200000至2000005导致的脚本生成保存失败 |
| NodeInputTypeNotSupportError | 网络节点输入类型未知 | 2000001 | 映射关系中设置节点输入类型错误 |
| ScriptGenerationError | 转换脚本生成失败 | 2000002 | 空间不足;生成的脚本不符合PEP-8规范;`--output`目录下已有同名文件存在 |
| ReportGenerationError | 转换报告生成失败 | 2000003 | 空间不足;脚本中没有需要转换的算子;`--report`目录下已有同名文件存在 |


+ 56
- 12
mindinsight/mindconverter/graph_based_converter/common/utils.py View File

@@ -13,9 +13,9 @@
# limitations under the License.
# ============================================================================
"""Define common utils."""
import json
import os
import stat
import json
import uuid
from importlib import import_module
from importlib.util import find_spec
@@ -27,7 +27,7 @@ from mindinsight.mindconverter.common.log import logger as log
from mindinsight.mindconverter.common.exceptions import ScriptGenerationError, ReportGenerationError, \
CheckPointGenerationError, WeightMapGenerationError, ModelLoadingError, OnnxModelSaveError
from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP, FrameworkType, \
TENSORFLOW_MODEL_SUFFIX, THIRD_PART_VERSION, ONNX_MODEL_SUFFIX, DTYPE_MAP
TENSORFLOW_MODEL_SUFFIX, THIRD_PART_VERSION, ONNX_MODEL_SUFFIX, DTYPE_MAP, WRITE_FLAGS, WRITE_MODES, WRITE_MODES_USR


def is_converted(operation: str):
@@ -148,10 +148,6 @@ def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple],
out_folder (str): Output folder.
report_folder (str): Report output folder.
"""
flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
modes = stat.S_IRUSR | stat.S_IWUSR
modes_usr = stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR

out_folder = os.path.realpath(out_folder)
if not report_folder:
report_folder = out_folder
@@ -159,9 +155,9 @@ def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple],
report_folder = os.path.realpath(report_folder)

if not os.path.exists(out_folder):
os.makedirs(out_folder, modes_usr)
os.makedirs(out_folder, WRITE_MODES_USR)
if not os.path.exists(report_folder):
os.makedirs(report_folder, modes_usr)
os.makedirs(report_folder, WRITE_MODES_USR)

for file_name in code_lines:
code, report, trainable_weights, weight_map = code_lines[file_name]
@@ -170,7 +166,7 @@ def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple],
try:
if os.path.exists(code_file_path):
raise ScriptGenerationError("Code file with the same name already exists.")
with os.fdopen(os.open(code_file_path, flags, modes), 'w') as file:
with os.fdopen(os.open(code_file_path, WRITE_FLAGS, WRITE_MODES), 'w') as file:
file.write(code)
except (IOError, FileExistsError) as error:
raise ScriptGenerationError(str(error))
@@ -178,7 +174,7 @@ def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple],
try:
if os.path.exists(report_file_path):
raise ReportGenerationError("Report file with the same name already exists.")
with os.fdopen(os.open(report_file_path, flags, stat.S_IRUSR), "w") as rpt_f:
with os.fdopen(os.open(report_file_path, WRITE_FLAGS, stat.S_IRUSR), "w") as rpt_f:
rpt_f.write(report)
except (IOError, FileExistsError) as error:
raise ReportGenerationError(str(error))
@@ -200,7 +196,7 @@ def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple],
try:
if os.path.exists(weight_map_path):
raise WeightMapGenerationError("Weight map file with the same name already exists.")
with os.fdopen(os.open(weight_map_path, flags, stat.S_IRUSR), 'w') as map_f:
with os.fdopen(os.open(weight_map_path, WRITE_FLAGS, stat.S_IRUSR), 'w') as map_f:
weight_map_json = {f"{model_name}": weight_map}
json.dump(weight_map_json, map_f)
except (IOError, FileExistsError) as error:
@@ -208,7 +204,7 @@ def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple],


def onnx_satisfied():
"""Validate ONNX , ONNXRUNTIME, ONNXOPTIMIZER installation."""
"""Validate ONNX, ONNXRUNTIME, ONNXOPTIMIZER installation."""
if not find_spec("onnx") or not find_spec("onnxruntime") or not find_spec("onnxoptimizer"):
return False
return True
@@ -332,3 +328,51 @@ def get_third_part_lib_validation_error_info(lib_list):
else:
error_info = link_str.join((error_info, info))
return error_info


def save_intermediate_graph(dataloader, output_folder):
"""
Save intermediate graph and topological order into output_folder.

Args:
dataloader (OnnxDataLoader): Dataloader inst.
output_folder (str): Output folder path.
"""
node_topo_order = []
placeholder_width = 30
for node_name, node in dataloader.nodes_dict.items():
row = f"{node.op_type.ljust(placeholder_width)} {node_name}\n"
node_topo_order.append(row)

# Import onnx lib.
onnx = import_module("onnx")

out_folder = os.path.realpath(output_folder)
if not os.path.exists(out_folder):
os.makedirs(out_folder, WRITE_MODES_USR)

graph_file = os.path.join(out_folder, "graph.onnx")
topological_order_file = os.path.join(out_folder, "topological_order.txt")

if os.path.exists(topological_order_file):
err_msg = f"{os.path.basename(topological_order_file)} already exists."
log.error(err_msg)
raise FileExistsError(err_msg)
if os.path.exists(graph_file):
err_msg = f"{os.path.basename(graph_file)} already exists."
log.error(err_msg)
raise FileExistsError(err_msg)

# Write topological order to disk.
with os.fdopen(os.open(topological_order_file, WRITE_FLAGS, stat.S_IRUSR), "w") as topo_file:
topo_file.writelines(node_topo_order)

try:
# Save graph to disk.
onnx.save_model(dataloader.inferred_model, graph_file)
except (IOError, OSError, FileExistsError) as e:
if os.path.exists(topological_order_file):
os.remove(topological_order_file)
if os.path.exists(graph_file):
os.remove(graph_file)
raise e

+ 7
- 2
mindinsight/mindconverter/graph_based_converter/constant.py View File

@@ -14,6 +14,8 @@
# ==============================================================================
"""Constant definition."""
from enum import Enum, unique
import os
import stat

import numpy as np

@@ -46,7 +48,11 @@ TF2ONNX_MIN_VER = "1.7.1"
ONNXRUNTIME_MIN_VER = "1.5.2"
ONNXOPTIMIZER_MIN_VER = "0.1.2"
ONNXOPTIMIZER_MAX_VER = "0.1.2"
CHECKPOINT_SEGMENT_SIZE = 2040109465 # 1.9GB, no more than 2GB
CHECKPOINT_SEGMENT_SIZE = 2040109465 # 1.9GB, no more than 2GB

WRITE_FLAGS = os.O_WRONLY | os.O_CREAT | os.O_EXCL
WRITE_MODES = stat.S_IRUSR | stat.S_IWUSR
WRITE_MODES_USR = stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR

DTYPE_MAP = {
1: np.float32,
@@ -126,7 +132,6 @@ MIN_SCOPE_LENGTH = 2

ONNX_OPSET_VERSION = 11


NO_CONVERTED_OPERATORS = [
"onnx::Constant",
"Constant"


+ 30
- 18
mindinsight/mindconverter/graph_based_converter/framework.py View File

@@ -22,13 +22,14 @@ from functools import partial
from google.protobuf.internal import api_implementation
from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, onnx_satisfied, \
save_code_file_and_report, get_framework_type, check_dependency_integrity, get_third_part_lib_validation_error_info
save_code_file_and_report, get_framework_type, check_dependency_integrity, \
get_third_part_lib_validation_error_info, save_intermediate_graph
from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \
ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER, ONNXOPTIMIZER_MIN_VER
from mindinsight.mindconverter.graph_based_converter.generator import batch_add_nodes
from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper
from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console
from mindinsight.mindconverter.common.exceptions import GraphInitError, SourceFilesSaveError, \
from mindinsight.mindconverter.common.exceptions import GraphInitError, FileSaveError, \
BaseConverterError, UnknownModelError, GeneratorError, TfRuntimeError, RuntimeIntegrityError, ParamMissingError, \
BadParamError
from mindinsight.mindconverter.graph_based_converter.third_party_graph import GraphFactory
@@ -69,8 +70,7 @@ def onnx_installation_validation(func):
type, inner function.
"""

def _f(graph_path: str, input_nodes: dict, output_nodes: List[str],
output_folder: str, report_folder: str = None):
def _f(*args, **kwargs):
# Check whether onnx is installed.
error_info = f"{get_third_part_lib_validation_error_info(['onnx', 'onnxruntime', 'onnxoptimizer'])} " \
f"are required when using graph based scripts converter or ONNX conversion."
@@ -83,9 +83,7 @@ def onnx_installation_validation(func):
_print_error(RuntimeIntegrityError(error_info))
sys.exit(0)

func(graph_path=graph_path,
input_nodes=input_nodes, output_nodes=output_nodes,
output_folder=output_folder, report_folder=report_folder)
func(*args, **kwargs)

return _f

@@ -111,8 +109,7 @@ def tf_installation_validation(func):
type, inner function.
"""

def _f(graph_path: str, input_nodes: dict, output_nodes: List[str],
output_folder: str, report_folder: str):
def _f(*args, **kwargs):
not_integral_error = RuntimeIntegrityError(
f"TensorFlow, "
f"{get_third_part_lib_validation_error_info(['tf2onnx', 'onnx', 'onnxruntime', 'onnxoptimizer'])} "
@@ -135,9 +132,7 @@ def tf_installation_validation(func):
_print_error(not_integral_error)
sys.exit(0)

func(graph_path=graph_path,
input_nodes=input_nodes, output_nodes=output_nodes,
output_folder=output_folder, report_folder=report_folder)
func(*args, **kwargs)

return _f

@@ -160,11 +155,12 @@ def _extract_model_name(model_path):

@onnx_installation_validation
@GraphInitError.uniform_catcher()
@SourceFilesSaveError.uniform_catcher()
@FileSaveError.uniform_catcher()
@GeneratorError.uniform_catcher()
def graph_based_converter_onnx_to_ms(graph_path: str,
input_nodes: dict, output_nodes: List[str],
output_folder: str, report_folder: str = None):
output_folder: str, report_folder: str = None,
query_result_folder: str = None):
"""
ONNX to MindSpore based on Graph.

@@ -174,8 +170,14 @@ def graph_based_converter_onnx_to_ms(graph_path: str,
output_nodes (list[str]): Output node(s) of the model.
output_folder (str): Output folder.
report_folder (str): Report output folder path.
query_result_folder (str): Save the optimized graph and its topological order to disk.
"""
graph_obj = GraphFactory.init(graph_path, input_nodes=input_nodes, output_nodes=output_nodes)
if query_result_folder:
save_intermediate_graph(graph_obj.dataloader, query_result_folder)
GlobalContext.release()
return
graph_obj.build()
generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper)
model_name = _extract_model_name(graph_path)
code_fragments = generator_inst.generate()
@@ -187,11 +189,12 @@ def graph_based_converter_onnx_to_ms(graph_path: str,
@tf_installation_validation
@GraphInitError.uniform_catcher()
@TfRuntimeError.uniform_catcher()
@SourceFilesSaveError.uniform_catcher()
@FileSaveError.uniform_catcher()
@GeneratorError.uniform_catcher()
def graph_based_converter_tf_to_ms(graph_path: str,
input_nodes: dict, output_nodes: List[str],
output_folder: str, report_folder: str = None):
output_folder: str, report_folder: str = None,
query_result_folder: str = None):
"""
Tensorflow to MindSpore based on Graph.

@@ -201,11 +204,17 @@ def graph_based_converter_tf_to_ms(graph_path: str,
output_nodes (list[str]): Output node(s) of the model.
output_folder (str): Output folder.
report_folder (str): Report output folder path.
query_result_folder (str): Save the optimized graph and its topological order to disk.
"""
# Close unnecessary log.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

graph_obj = GraphFactory.init(graph_path, input_nodes=input_nodes, output_nodes=output_nodes)
if query_result_folder:
save_intermediate_graph(graph_obj.dataloader, query_result_folder)
GlobalContext.release()
return
graph_obj.build()
generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper)
model_name = _extract_model_name(graph_path)
code_fragments = generator_inst.generate()
@@ -249,14 +258,17 @@ def main_graph_base_converter(file_config):
input_nodes=input_nodes,
output_nodes=file_config['output_nodes'],
output_folder=file_config['outfile_dir'],
report_folder=file_config['report_dir'])
report_folder=file_config['report_dir'],
query_result_folder=file_config.get("query_result_folder"))

elif frame_type == FrameworkType.TENSORFLOW.value:
graph_based_converter_tf_to_ms(graph_path=graph_path,
input_nodes=input_nodes,
output_nodes=file_config['output_nodes'],
output_folder=file_config['outfile_dir'],
report_folder=file_config['report_dir'])
report_folder=file_config['report_dir'],
query_result_folder=file_config.get("query_result_folder"))

else:
error_msg = "Get UNSUPPORTED model."
error = UnknownModelError(error_msg)


+ 1
- 1
mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py View File

@@ -41,7 +41,7 @@ class GraphFactory:
"""
if not isinstance(input_nodes, dict):
raise TypeError("`input_nodes` must be type of dict.")
if not isinstance(output_nodes, list):
if not isinstance(output_nodes, (list, tuple)):
raise TypeError("`output_nodes` must be type of list.")
return OnnxGraph.load(model_path=graph_path, input_nodes=input_nodes,
output_nodes=output_nodes)

+ 10
- 12
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py View File

@@ -61,8 +61,10 @@ class OnnxGraph(Graph):

def __init__(self, model, model_path, **kwargs):
super(OnnxGraph, self).__init__(model=model, model_path=model_path, **kwargs)

self.build()
self.dataloader = OnnxDataLoader(self.model,
self.model_path,
input_nodes=self._raw_input_nodes,
output_nodes=self._raw_output_nodes)

@staticmethod
def _extract_shape(shape):
@@ -115,21 +117,17 @@ class OnnxGraph(Graph):

def build(self):
"""Build graph tree."""
model_data = OnnxDataLoader(self.model,
self.model_path,
input_nodes=self._raw_input_nodes,
output_nodes=self._raw_output_nodes)
scope_name_list = generate_scope_name(model_data)

self._shape_dict = model_data.node_output_shape_dict
for ind, (node_name, node) in enumerate(model_data.nodes_dict.items()):
scope_name_list = generate_scope_name(self.dataloader)

self._shape_dict = self.dataloader.node_output_shape_dict
for ind, (node_name, node) in enumerate(self.dataloader.nodes_dict.items()):
node_weights = list()
node.scope_name = scope_name_list[ind]
inputs = node.input_name_list
# check each input from node or tensors
for idx, i in enumerate(inputs):
if i in model_data.tensors_dict:
tensor = model_data.tensors_dict[i]
if i in self.dataloader.tensors_dict:
tensor = self.dataloader.tensors_dict[i]
t_name = tensor.name
t_value = tensor.to_array()
node_weights.append(NodeWeight(t_name, t_value, idx))


Loading…
Cancel
Save