Browse Source

Support converting large models (>2GB) to MindSpore

pull/1280/head
范吉斌 4 years ago
parent
commit
868be184cc
8 changed files with 137 additions and 32 deletions
  1. +13
    -0
      mindinsight/mindconverter/common/exceptions.py
  2. +35
    -9
      mindinsight/mindconverter/graph_based_converter/common/utils.py
  3. +1
    -1
      mindinsight/mindconverter/graph_based_converter/constant.py
  4. +12
    -2
      mindinsight/mindconverter/graph_based_converter/generator/generator.py
  5. +3
    -2
      mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py
  6. +5
    -3
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py
  7. +41
    -14
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py
  8. +27
    -1
      mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py

+ 13
- 0
mindinsight/mindconverter/common/exceptions.py View File

@@ -283,6 +283,7 @@ class SourceFilesSaveError(MindConverterException):
REPORT_GENERATE_FAIL = 3
CKPT_GENERATE_FAIL = 4
MAP_GENERATE_FAIL = 5
MODEL_SAVE_FAIL = 6

BASE_ERROR_CODE = ConverterErrors.SOURCE_FILES_SAVE_FAIL.value
ERROR_CODE = ErrCode.UNKNOWN_ERROR.value
@@ -299,6 +300,7 @@ class SourceFilesSaveError(MindConverterException):
ReportGenerationError,
CheckPointGenerationError,
WeightMapGenerationError,
OnnxModelSaveError,
IOError, cls)
return except_source

@@ -430,6 +432,17 @@ class WeightMapGenerationError(SourceFilesSaveError):
"""Raise from exception below."""
return cls

class OnnxModelSaveError(SourceFilesSaveError):
    """Error reported when saving the onnx model fails."""
    ERROR_CODE = SourceFilesSaveError.ErrCode.MODEL_SAVE_FAIL.value

    def __init__(self, msg):
        # Forward the message to the base error by keyword, as the base expects.
        super().__init__(msg=msg)

    @classmethod
    def raise_from(cls):
        """Return the exception type this error is raised from (the class itself)."""
        return cls

class SubGraphSearchingError(MindConverterException):
"""Sub-graph searching exception."""


+ 35
- 9
mindinsight/mindconverter/graph_based_converter/common/utils.py View File

@@ -16,6 +16,7 @@
import json
import os
import stat
import uuid
from importlib import import_module
from importlib.util import find_spec
from typing import List, Tuple, Mapping
@@ -23,7 +24,7 @@ from typing import List, Tuple, Mapping
import numpy as np

from mindinsight.mindconverter.common.exceptions import ScriptGenerationError, ReportGenerationError, \
CheckPointGenerationError, WeightMapGenerationError
CheckPointGenerationError, WeightMapGenerationError, ModelLoadingError, OnnxModelSaveError
from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP, FrameworkType, \
TENSORFLOW_MODEL_SUFFIX, THIRD_PART_VERSION, ONNX_MODEL_SUFFIX, DTYPE_MAP

@@ -84,7 +85,7 @@ def build_feed_dict(onnx_model, input_nodes: dict):
return feed_dict


def fetch_output_from_onnx_model(model, feed_dict: dict, output_nodes: List[str]):
def fetch_output_from_onnx_model(model, model_path: str, feed_dict: dict, output_nodes: List[str]):
"""
Fetch specific nodes output from onnx model.

@@ -93,6 +94,7 @@ def fetch_output_from_onnx_model(model, feed_dict: dict, output_nodes: List[str]

Args:
model (ModelProto): ONNX model.
model_path (str): ONNX model path.
feed_dict (dict): Feed forward inputs.
output_nodes (list[str]): Output nodes list.

@@ -105,9 +107,29 @@ def fetch_output_from_onnx_model(model, feed_dict: dict, output_nodes: List[str]

edit_model = _add_outputs_of_onnx_model(model, output_nodes)

onnx = import_module("onnx")
ort = import_module("onnxruntime")
sess = ort.InferenceSession(path_or_bytes=bytes(edit_model.SerializeToString()))
fetched_res = sess.run(output_names=output_nodes, input_feed=feed_dict)
try:
dir_path = os.path.dirname(model_path)
stem_name = os.path.splitext(os.path.basename(model_path))[0]
filename = ".~{0}_{1}".format(stem_name, str(uuid.uuid4()))
tmp_file = os.path.join(dir_path, filename)
onnx.save_tensor(edit_model, tmp_file)
except (TypeError, IOError) as error:
if os.path.exists(tmp_file):
os.remove(tmp_file)
raise OnnxModelSaveError("Onnx model save failed, {}".format(str(error)))


try:
sess = ort.InferenceSession(path_or_bytes=tmp_file)
fetched_res = sess.run(output_names=output_nodes, input_feed=feed_dict)
except ModelLoadingError.raise_from() as error:
raise ModelLoadingError("OnnxRuntimeError, {}".format(str(error)))
finally:
if os.path.exists(tmp_file):
os.remove(tmp_file)

run_result = dict()
for idx, opt in enumerate(output_nodes):
run_result[opt] = fetched_res[idx]
@@ -161,13 +183,17 @@ def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple],
raise ReportGenerationError(str(error))

save_checkpoint = getattr(import_module("mindspore.train.serialization"), "save_checkpoint")
ckpt_file_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.ckpt"))
try:
for idx, trainable_weight in enumerate(trainable_weights):
if len(trainable_weights) > 1:
ckpt_file_path = os.path.realpath(os.path.join(out_folder, f"{model_name}_{idx}.ckpt"))
else:
ckpt_file_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.ckpt"))
if os.path.exists(ckpt_file_path):
raise CheckPointGenerationError("Checkpoint file with the same name already exists.")
save_checkpoint(trainable_weights, ckpt_file_path)
except TypeError as error:
raise CheckPointGenerationError(str(error))
try:
save_checkpoint(trainable_weight, ckpt_file_path)
except TypeError as error:
raise CheckPointGenerationError(str(error))

weight_map_path = os.path.realpath(os.path.join(report_folder, f"weight_map_of_{model_name}.json"))
try:


+ 1
- 1
mindinsight/mindconverter/graph_based_converter/constant.py View File

@@ -46,7 +46,7 @@ TF2ONNX_MIN_VER = "1.7.1"
ONNXRUNTIME_MIN_VER = "1.5.2"
ONNXOPTIMIZER_MIN_VER = "0.1.2"
ONNXOPTIMIZER_MAX_VER = "0.1.2"
CHECKPOINT_SEGMENT_SIZE = 2040109465 # 1.9GB, no more than 2GB

DTYPE_MAP = {
1: np.float32,


+ 12
- 2
mindinsight/mindconverter/graph_based_converter/generator/generator.py View File

@@ -35,7 +35,7 @@ from mindinsight.mindconverter.graph_based_converter.report_generator import Rep
from mindinsight.mindconverter.graph_based_converter.common.utils import replace_string_in_list
from mindinsight.mindconverter.graph_based_converter.generator.matcher import MatcherLauncher
from mindinsight.mindconverter.graph_based_converter.generator.shared_weights import SharedWeightHelper
from mindinsight.mindconverter.graph_based_converter.constant import CHECKPOINT_SEGMENT_SIZE

class CodeStruct:
"""
@@ -635,15 +635,25 @@ class Generator:
}
)

save_obj_list = list()
save_obj = list()
data_nbytes = 0
for weight_name, weight_value in trainable_weights_dict.items():
obj = {
'name': weight_name,
'data': mindspore.Tensor(weight_value)
}
data_nbytes += obj['data'].nbytes
if data_nbytes > CHECKPOINT_SEGMENT_SIZE:
save_obj_list.append(save_obj)
save_obj = []
data_nbytes = obj['data'].nbytes
save_obj.append(obj)

return save_obj, weight_map
if save_obj:
save_obj_list.append(save_obj)

return save_obj_list, weight_map

@GeneratorError.check_except("Generator occurs an error when generating code statements.")
def generate(self):


+ 3
- 2
mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py View File

@@ -90,9 +90,10 @@ class Graph(BaseGraph, abc.ABC):

sorted = False

def __init__(self, model, **kwargs):
def __init__(self, model, model_path, **kwargs):
super(Graph, self).__init__()
self.model = model
self.model_path = model_path
self._raw_input_nodes = kwargs.get("input_nodes")
self._raw_output_nodes = kwargs.get("output_nodes")
self._nodes_collection = OrderedDict()
@@ -246,7 +247,7 @@ class Graph(BaseGraph, abc.ABC):
cls, graph instance.
"""
src_graph = cls.load_graph(graph_path=model_path, **kwargs)
return cls(src_graph, **kwargs)
return cls(src_graph, model_path, **kwargs)


class GraphNode(abc.ABC):


+ 5
- 3
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py View File

@@ -56,10 +56,11 @@ class OnnxGraph(Graph):

Args:
model (onnx.ModelProto): Onnx defined model proto.
model_path (str): Onnx model path.
"""

def __init__(self, model, **kwargs):
super(OnnxGraph, self).__init__(model=model, **kwargs)
def __init__(self, model, model_path, **kwargs):
super(OnnxGraph, self).__init__(model=model, model_path=model_path, **kwargs)

self.build()

@@ -115,6 +116,7 @@ class OnnxGraph(Graph):
def build(self):
"""Build graph tree."""
model_data = OnnxDataLoader(self.model,
self.model_path,
input_nodes=self._raw_input_nodes,
output_nodes=self._raw_output_nodes)
scope_name_list = generate_scope_name(model_data)
@@ -207,7 +209,7 @@ class OnnxGraph(Graph):
output_nodes=output_nodes)
else:
onnx = import_module('onnx')
onnx_model = onnx.load(graph_path)
onnx_model = onnx.load(graph_path, load_external_data=False)

onnx_inputs = [onnx_input.name for onnx_input in onnx_model.graph.input]



+ 41
- 14
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py View File

@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Define ONNX related operations."""
import os
import itertools
import re
import abc
@@ -249,17 +250,19 @@ class OnnxDataLoader:

Args:
onnx_model (onnx.ModelProto): Original Onnx defined model.
model_path (str): Onnx model path.
input_nodes (Union[str, list]): Input nodes of ONNX model.
output_nodes (Union[str, list]): Output nodes of ONNX model.
infer_shape (bool): Enable the shape inference after conversion.
Default: True
"""

def __init__(self, onnx_model, input_nodes: dict,
def __init__(self, onnx_model, model_path: str, input_nodes: dict,
output_nodes: list, infer_shape=True):
onnx_sim = OnnxSimplify()
onnx_model_sim = onnx_sim.run_onnx_simplify(onnx_model, input_nodes)
onnx_model_sim = onnx_sim.run_onnx_simplify(onnx_model, model_path, input_nodes)
self.model = onnx_model_sim
self.model_path = model_path
self.graph = onnx_model_sim.graph
self.nodes = onnx_model_sim.graph.node
self.input_nodes = input_nodes
@@ -384,7 +387,10 @@ class OnnxDataLoader:

feed_dict = build_feed_dict(self.inferred_model, self.input_nodes)

outputs_infer = fetch_output_from_onnx_model(self.model, feed_dict, output_nodes_name)
outputs_infer = fetch_output_from_onnx_model(self.model,
self.model_path,
feed_dict,
output_nodes_name)
return outputs_infer

def _parse_nodes(self):
@@ -434,6 +440,9 @@ class OnnxDataLoader:
t = OnnxTensor(tensor)
self.tensors_dict[t.name] = t

self._global_context.onnx_tensors_collection = self.tensors_dict

def _remove_extra_graph_output(self):
idx = 0
while idx < len(self.model.graph.output):
cur_opt = self.model.graph.output[idx]
@@ -442,8 +451,6 @@ class OnnxDataLoader:
continue
idx += 1

self._global_context.onnx_tensors_collection = self.tensors_dict

def _parse_node_output_shape(self):
"""
Parse the inferred output shape of each node.
@@ -516,13 +523,13 @@ class OnnxDataLoader:
# Parse ONNX Graph level info
self._parse_graph()

# 1. parse all tensors
self._parse_tensors()

# 2. parse all nodes, note that parse tensors must be done as nodes require tensor info
# 1. parse all nodes, note that parse tensors must be done as nodes require tensor info
# to process the node weight sharing.
self._parse_nodes()

# 2. remove extra output from onnx model graph
self._remove_extra_graph_output()

# 3. parse value info (incl. node output shape)
if self._is_infer_shape:
try:
@@ -534,15 +541,29 @@ class OnnxDataLoader:
log.exception(e)
raise e

# 4. Optimize graph to eliminate some nodes.
# 4. load external_data for initializer
self.load_external_data()

# 5. parse all tensors
self._parse_tensors()

# 6. Optimize graph to eliminate some nodes.
self._find_nodes_to_be_eliminated()

# 5. build nodes connections
# 7. build nodes connections
self.build_nodes_connection()

# 6. Run onnx model to fetch actual value of eliminated nodes.
# 8. Run onnx model to fetch actual value of eliminated nodes.
self._fetch_eliminated_nodes_value()

def load_external_data(self):
    """
    Load externally stored initializer data into the model.

    The directory containing `self.model_path` is passed to onnx as the base
    directory for resolving external data files. Nothing is loaded when no
    model path is available.
    """
    # Check the raw path before resolving it: `os.path.realpath` never returns
    # an empty string ("" resolves to the current working directory) and raises
    # TypeError on None, so testing its result cannot detect a missing path.
    if not self.model_path:
        return

    load_external_data_for_model = getattr(import_module("onnx.external_data_helper"),
                                           "load_external_data_for_model")
    base_dir = os.path.dirname(os.path.realpath(self.model_path))
    load_external_data_for_model(self.model, base_dir)

def _fetch_eliminated_nodes_value(self):
"""Fetch eliminated nodes values by running onnx inference."""

@@ -556,7 +577,10 @@ class OnnxDataLoader:
shape_ref = self._nodes_dict[node].input_name_list[1]
output_tensors.append(shape_ref)
feed_dict = build_feed_dict(self.model, self.input_nodes)
fetch_dict = fetch_output_from_onnx_model(self.model, feed_dict=feed_dict, output_nodes=output_tensors)
fetch_dict = fetch_output_from_onnx_model(self.model,
self.model_path,
feed_dict=feed_dict,
output_nodes=output_tensors)
for opt_tensor_name, value in fetch_dict.items():
self.tensors_dict[opt_tensor_name] = OnnxTensor(value, opt_tensor_name)

@@ -570,7 +594,10 @@ class OnnxDataLoader:
shape_ref = self._nodes_dict[node].input_name_list[3]
output_tensors.append(shape_ref)
feed_dict = build_feed_dict(self.model, self.input_nodes)
fetch_dict = fetch_output_from_onnx_model(self.model, feed_dict=feed_dict, output_nodes=output_tensors)
fetch_dict = fetch_output_from_onnx_model(self.model,
self.model_path,
feed_dict=feed_dict,
output_nodes=output_tensors)
for opt_tensor_name, value in fetch_dict.items():
self.tensors_dict[opt_tensor_name] = OnnxTensor(value, opt_tensor_name)



+ 27
- 1
mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py View File

@@ -25,18 +25,21 @@ class OnnxSimplify:

def __init__(self):
self._onnx_model = None
self.model_path = str()
self._constant_nodes = list()
self._outputs_infer = dict()

def run_onnx_simplify(self, onnx_model, input_nodes):
def run_onnx_simplify(self, onnx_model, model_path, input_nodes):
"""
Run to simplify onnx model.

Args:
onnx_model (onnx.ModelProto): Onnx Model.
model_path (str): Onnx model path.
input_nodes (dict): Input nodes and corresponding sample shape.
"""
self._onnx_model = onnx_model
self.model_path = model_path
self._optimizer()
self._get_constant_nodes()
self._onnx_infer(input_nodes)
@@ -70,13 +73,35 @@ class OnnxSimplify:
'fuse_transpose_into_gemm'
]

raw_initializers = dict()
for initializer in self._onnx_model.graph.initializer:
raw_initializers[initializer.name] = initializer

input_num = len(self._onnx_model.graph.input)
onnx_model_optimized = onnxoptimizer.optimize(self._onnx_model, optimizers_list, fixed_point=True)

for initializer in onnx_model_optimized.graph.initializer:
if raw_initializers.get(initializer.name, None):
raw_initializer = raw_initializers.get(initializer.name)
self.copy_initializer_external_data(raw_initializer, initializer)

if self._onnx_model.ir_version > 3:
del onnx_model_optimized.graph.input[input_num:]
self._onnx_model = onnx_model_optimized

def copy_initializer_external_data(self, src, dst):
    """
    Transfer the external data description of one initializer onto another.

    The `data_location` field is copied only when it is set on `src`, and the
    `external_data` entries of `src` are appended to those of `dst` only when
    `src` has any.

    Args:
        src (graph.initializer): Source initializer.
        dst (graph.initializer): Destination initializer.
    """
    location_is_set = src.HasField('data_location')
    if location_is_set:
        dst.data_location = src.data_location

    src_entries = src.external_data
    if src_entries:
        dst.external_data.extend(src_entries)

def _get_constant_nodes(self):
"""Get constant nodes."""

@@ -109,6 +134,7 @@ class OnnxSimplify:
output_nodes_name.extend(node.output)
original_outputs = [nd.name for nd in self._onnx_model.graph.output]
self._outputs_infer = fetch_output_from_onnx_model(self._onnx_model,
self.model_path,
feed_dict, output_nodes_name)
idx = 0
while idx < len(self._onnx_model.graph.output):


Loading…
Cancel
Save