diff --git a/mindinsight/mindconverter/common/exceptions.py b/mindinsight/mindconverter/common/exceptions.py index 51edc5e2..1c44cc5e 100644 --- a/mindinsight/mindconverter/common/exceptions.py +++ b/mindinsight/mindconverter/common/exceptions.py @@ -283,6 +283,7 @@ class SourceFilesSaveError(MindConverterException): REPORT_GENERATE_FAIL = 3 CKPT_GENERATE_FAIL = 4 MAP_GENERATE_FAIL = 5 + MODEL_SAVE_FAIL = 6 BASE_ERROR_CODE = ConverterErrors.SOURCE_FILES_SAVE_FAIL.value ERROR_CODE = ErrCode.UNKNOWN_ERROR.value @@ -299,6 +300,7 @@ class SourceFilesSaveError(MindConverterException): ReportGenerationError, CheckPointGenerationError, WeightMapGenerationError, + OnnxModelSaveError, IOError, cls) return except_source @@ -430,6 +432,17 @@ class WeightMapGenerationError(SourceFilesSaveError): """Raise from exception below.""" return cls +class OnnxModelSaveError(SourceFilesSaveError): + """The onnx model save fail error.""" + ERROR_CODE = SourceFilesSaveError.ErrCode.MODEL_SAVE_FAIL.value + + def __init__(self, msg): + super(OnnxModelSaveError, self).__init__(msg=msg) + + @classmethod + def raise_from(cls): + """Raise from exception below.""" + return cls class SubGraphSearchingError(MindConverterException): """Sub-graph searching exception.""" diff --git a/mindinsight/mindconverter/graph_based_converter/common/utils.py b/mindinsight/mindconverter/graph_based_converter/common/utils.py index 1ded8290..2b0e6a75 100644 --- a/mindinsight/mindconverter/graph_based_converter/common/utils.py +++ b/mindinsight/mindconverter/graph_based_converter/common/utils.py @@ -16,6 +16,7 @@ import json import os import stat +import uuid from importlib import import_module from importlib.util import find_spec from typing import List, Tuple, Mapping @@ -23,7 +24,7 @@ from typing import List, Tuple, Mapping import numpy as np from mindinsight.mindconverter.common.exceptions import ScriptGenerationError, ReportGenerationError, \ - CheckPointGenerationError, WeightMapGenerationError + CheckPointGenerationError, WeightMapGenerationError, ModelLoadingError, OnnxModelSaveError from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP, FrameworkType, \ TENSORFLOW_MODEL_SUFFIX, THIRD_PART_VERSION, ONNX_MODEL_SUFFIX, DTYPE_MAP @@ -84,7 +85,7 @@ def build_feed_dict(onnx_model, input_nodes: dict): return feed_dict -def fetch_output_from_onnx_model(model, feed_dict: dict, output_nodes: List[str]): +def fetch_output_from_onnx_model(model, model_path: str, feed_dict: dict, output_nodes: List[str]): """ Fetch specific nodes output from onnx model. @@ -93,6 +94,7 @@ def fetch_output_from_onnx_model(model, feed_dict: dict, output_nodes: List[str] Args: model (ModelProto): ONNX model. + model_path (str): ONNX model path. feed_dict (dict): Feed forward inputs. output_nodes (list[str]): Output nodes list. @@ -105,9 +107,29 @@ def fetch_output_from_onnx_model(model, feed_dict: dict, output_nodes: List[str] edit_model = _add_outputs_of_onnx_model(model, output_nodes) + onnx = import_module("onnx") ort = import_module("onnxruntime") - sess = ort.InferenceSession(path_or_bytes=bytes(edit_model.SerializeToString())) - fetched_res = sess.run(output_names=output_nodes, input_feed=feed_dict) + try: + dir_path = os.path.dirname(model_path) + stem_name = os.path.splitext(os.path.basename(model_path))[0] + filename = ".~{0}_{1}".format(stem_name, str(uuid.uuid4())) + tmp_file = os.path.join(dir_path, filename) + onnx.save_tensor(edit_model, tmp_file) + except (TypeError, IOError) as error: + if os.path.exists(tmp_file): + os.remove(tmp_file) + raise OnnxModelSaveError("Onnx model save failed, {}".format(str(error))) + + + try: + sess = ort.InferenceSession(path_or_bytes=tmp_file) + fetched_res = sess.run(output_names=output_nodes, input_feed=feed_dict) + except ModelLoadingError.raise_from() as error: + raise ModelLoadingError("OnnxRuntimeError, {}".format(str(error))) + finally: + if os.path.exists(tmp_file): + os.remove(tmp_file) + run_result = dict() for idx, opt in enumerate(output_nodes): run_result[opt] = fetched_res[idx] @@ -161,13 +183,17 @@ def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple], raise ReportGenerationError(str(error)) save_checkpoint = getattr(import_module("mindspore.train.serialization"), "save_checkpoint") - ckpt_file_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.ckpt")) - try: + for idx, trainable_weight in enumerate(trainable_weights): + if len(trainable_weights) > 1: + ckpt_file_path = os.path.realpath(os.path.join(out_folder, f"{model_name}_{idx}.ckpt")) + else: + ckpt_file_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.ckpt")) if os.path.exists(ckpt_file_path): raise CheckPointGenerationError("Checkpoint file with the same name already exists.") - save_checkpoint(trainable_weights, ckpt_file_path) - except TypeError as error: - raise CheckPointGenerationError(str(error)) + try: + save_checkpoint(trainable_weight, ckpt_file_path) + except TypeError as error: + raise CheckPointGenerationError(str(error)) weight_map_path = os.path.realpath(os.path.join(report_folder, f"weight_map_of_{model_name}.json")) try: diff --git a/mindinsight/mindconverter/graph_based_converter/constant.py b/mindinsight/mindconverter/graph_based_converter/constant.py index 2c781d52..3bdba22c 100644 --- a/mindinsight/mindconverter/graph_based_converter/constant.py +++ b/mindinsight/mindconverter/graph_based_converter/constant.py @@ -46,7 +46,7 @@ TF2ONNX_MIN_VER = "1.7.1" ONNXRUNTIME_MIN_VER = "1.5.2" ONNXOPTIMIZER_MIN_VER = "0.1.2" ONNXOPTIMIZER_MAX_VER = "0.1.2" - +CHECKPOINT_SEGMENT_SIZE = 2040109465 # 1.9GB, no more than 2GB DTYPE_MAP = { 1: np.float32, diff --git a/mindinsight/mindconverter/graph_based_converter/generator/generator.py b/mindinsight/mindconverter/graph_based_converter/generator/generator.py index 8f39e0d1..d8bb710f 100644 --- a/mindinsight/mindconverter/graph_based_converter/generator/generator.py +++ b/mindinsight/mindconverter/graph_based_converter/generator/generator.py @@ -35,7 +35,7 @@ from mindinsight.mindconverter.graph_based_converter.report_generator import Rep from mindinsight.mindconverter.graph_based_converter.common.utils import replace_string_in_list from mindinsight.mindconverter.graph_based_converter.generator.matcher import MatcherLauncher from mindinsight.mindconverter.graph_based_converter.generator.shared_weights import SharedWeightHelper - +from mindinsight.mindconverter.graph_based_converter.constant import CHECKPOINT_SEGMENT_SIZE class CodeStruct: """ @@ -635,15 +635,25 @@ class Generator: } ) + save_obj_list = list() save_obj = list() + data_nbytes = 0 for weight_name, weight_value in trainable_weights_dict.items(): obj = { 'name': weight_name, 'data': mindspore.Tensor(weight_value) } + data_nbytes += obj['data'].nbytes + if data_nbytes > CHECKPOINT_SEGMENT_SIZE: + save_obj_list.append(save_obj) + save_obj = [] + data_nbytes = obj['data'].nbytes save_obj.append(obj) - return save_obj, weight_map + if save_obj: + save_obj_list.append(save_obj) + + return save_obj_list, weight_map @GeneratorError.check_except("Generator occurs an error when generating code statements.") def generate(self): diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py index c36f0a0f..770e91db 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py @@ -90,9 +90,10 @@ class Graph(BaseGraph, abc.ABC): sorted = False - def __init__(self, model, **kwargs): + def __init__(self, model, model_path, **kwargs): super(Graph, self).__init__() self.model = model + self.model_path = model_path self._raw_input_nodes = kwargs.get("input_nodes") self._raw_output_nodes = kwargs.get("output_nodes") self._nodes_collection = OrderedDict() @@ -246,7 +247,7 @@ class Graph(BaseGraph, abc.ABC): cls, graph instance. """ src_graph = cls.load_graph(graph_path=model_path, **kwargs) - return cls(src_graph, **kwargs) + return cls(src_graph, model_path, **kwargs) class GraphNode(abc.ABC): diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py index c8bb0393..88ab2ea3 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py @@ -56,10 +56,11 @@ class OnnxGraph(Graph): Args: model (onnx.ModelProto): Onnx defined model proto. + model_path (str): Onnx model path. """ - def __init__(self, model, **kwargs): - super(OnnxGraph, self).__init__(model=model, **kwargs) + def __init__(self, model, model_path, **kwargs): + super(OnnxGraph, self).__init__(model=model, model_path=model_path, **kwargs) self.build() @@ -115,6 +116,7 @@ class OnnxGraph(Graph): def build(self): """Build graph tree.""" model_data = OnnxDataLoader(self.model, + self.model_path, input_nodes=self._raw_input_nodes, output_nodes=self._raw_output_nodes) scope_name_list = generate_scope_name(model_data) @@ -207,7 +209,7 @@ class OnnxGraph(Graph): output_nodes=output_nodes) else: onnx = import_module('onnx') - onnx_model = onnx.load(graph_path) + onnx_model = onnx.load(graph_path, load_external_data=False) onnx_inputs = [onnx_input.name for onnx_input in onnx_model.graph.input] diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py index cd7acaae..595e267d 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== """Define ONNX related operations.""" +import os import itertools import re import abc @@ -249,17 +250,19 @@ class OnnxDataLoader: Args: onnx_model (onnx.ModelProto): Original Onnx defined model. + model_path (str): Onnx model path. input_nodes (Union[str, list]): Input nodes of ONNX model. output_nodes (Union[str, list]): Output nodes of ONNX model. infer_shape (bool): Enable the shape inference after conversion. Default: True """ - def __init__(self, onnx_model, input_nodes: dict, + def __init__(self, onnx_model, model_path: str, input_nodes: dict, output_nodes: list, infer_shape=True): onnx_sim = OnnxSimplify() - onnx_model_sim = onnx_sim.run_onnx_simplify(onnx_model, input_nodes) + onnx_model_sim = onnx_sim.run_onnx_simplify(onnx_model, model_path, input_nodes) self.model = onnx_model_sim + self.model_path = model_path self.graph = onnx_model_sim.graph self.nodes = onnx_model_sim.graph.node self.input_nodes = input_nodes @@ -384,7 +387,10 @@ class OnnxDataLoader: feed_dict = build_feed_dict(self.inferred_model, self.input_nodes) - outputs_infer = fetch_output_from_onnx_model(self.model, feed_dict, output_nodes_name) + outputs_infer = fetch_output_from_onnx_model(self.model, + self.model_path, + feed_dict, + output_nodes_name) return outputs_infer def _parse_nodes(self): @@ -434,6 +440,9 @@ class OnnxDataLoader: t = OnnxTensor(tensor) self.tensors_dict[t.name] = t + self._global_context.onnx_tensors_collection = self.tensors_dict + + def _remove_extra_graph_output(self): idx = 0 while idx < len(self.model.graph.output): cur_opt = self.model.graph.output[idx] @@ -442,8 +451,6 @@ class OnnxDataLoader: continue idx += 1 - self._global_context.onnx_tensors_collection = self.tensors_dict - def _parse_node_output_shape(self): """ Parse the inferred output shape of each node. @@ -516,13 +523,13 @@ class OnnxDataLoader: # Parse ONNX Graph level info self._parse_graph() - # 1. parse all tensors - self._parse_tensors() - - # 2. parse all nodes, note that parse tensors must be done as nodes require tensor info + # 1. parse all nodes, note that parse tensors must be done as nodes require tensor info # to process the node weight sharing. self._parse_nodes() + # 2. remove extra output from onnx model graph + self._remove_extra_graph_output() + # 3. parse value info (incl. node output shape) if self._is_infer_shape: try: @@ -534,15 +541,29 @@ class OnnxDataLoader: log.exception(e) raise e - # 4. Optimize graph to eliminate some nodes. + # 4. load external_data for initializer + self.load_external_data() + + # 5. parse all tensors + self._parse_tensors() + + # 6. Optimize graph to eliminate some nodes. self._find_nodes_to_be_eliminated() - # 5. build nodes connections + # 7. build nodes connections self.build_nodes_connection() - # 6. Run onnx model to fetch actual value of eliminated nodes. + # 8. Run onnx model to fetch actual value of eliminated nodes. self._fetch_eliminated_nodes_value() + def load_external_data(self): + load_external_data_for_model = getattr(import_module("onnx.external_data_helper"), + "load_external_data_for_model") + model_filepath = os.path.realpath(self.model_path) + if model_filepath: + base_dir = os.path.dirname(model_filepath) + load_external_data_for_model(self.model, base_dir) + def _fetch_eliminated_nodes_value(self): """Fetch eliminated nodes values by running onnx inference.""" @@ -556,7 +577,10 @@ class OnnxDataLoader: shape_ref = self._nodes_dict[node].input_name_list[1] output_tensors.append(shape_ref) feed_dict = build_feed_dict(self.model, self.input_nodes) - fetch_dict = fetch_output_from_onnx_model(self.model, feed_dict=feed_dict, output_nodes=output_tensors) + fetch_dict = fetch_output_from_onnx_model(self.model, + self.model_path, + feed_dict=feed_dict, + output_nodes=output_tensors) for opt_tensor_name, value in fetch_dict.items(): self.tensors_dict[opt_tensor_name] = OnnxTensor(value, opt_tensor_name) @@ -570,7 +594,10 @@ class OnnxDataLoader: shape_ref = self._nodes_dict[node].input_name_list[3] output_tensors.append(shape_ref) feed_dict = build_feed_dict(self.model, self.input_nodes) - fetch_dict = fetch_output_from_onnx_model(self.model, feed_dict=feed_dict, output_nodes=output_tensors) + fetch_dict = fetch_output_from_onnx_model(self.model, + self.model_path, + feed_dict=feed_dict, + output_nodes=output_tensors) for opt_tensor_name, value in fetch_dict.items(): self.tensors_dict[opt_tensor_name] = OnnxTensor(value, opt_tensor_name) diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py index 390185f2..a1116c73 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py @@ -25,18 +25,21 @@ class OnnxSimplify: def __init__(self): self._onnx_model = None + self.model_path = str() self._constant_nodes = list() self._outputs_infer = dict() - def run_onnx_simplify(self, onnx_model, input_nodes): + def run_onnx_simplify(self, onnx_model, model_path, input_nodes): """ Run to simplify onnx model. Args: onnx_model (onnx.ModelProto): Onnx Model. + model_path (str): Onnx model path. input_nodes (dict): Input nodes and corresponding sample shape. """ self._onnx_model = onnx_model + self.model_path = model_path self._optimizer() self._get_constant_nodes() self._onnx_infer(input_nodes) @@ -70,13 +73,35 @@ class OnnxSimplify: 'fuse_transpose_into_gemm' ] + raw_initializers = dict() + for initializer in self._onnx_model.graph.initializer: + raw_initializers[initializer.name] = initializer + input_num = len(self._onnx_model.graph.input) onnx_model_optimized = onnxoptimizer.optimize(self._onnx_model, optimizers_list, fixed_point=True) + for initializer in onnx_model_optimized.graph.initializer: + if raw_initializers.get(initializer.name, None): + raw_initializer = raw_initializers.get(initializer.name) + self.copy_initializer_external_data(raw_initializer, initializer) + if self._onnx_model.ir_version > 3: del onnx_model_optimized.graph.input[input_num:] self._onnx_model = onnx_model_optimized + def copy_initializer_external_data(self, src, dst): + """ + Copy external_data from src initializer to dst initializer. + + Args: + src (graph.initializer): Source initializer. + dst (graph.initializer): Destination initializer. + """ + if src.HasField('data_location'): + dst.data_location = src.data_location + if src.external_data: + dst.external_data.extend(src.external_data) + def _get_constant_nodes(self): """Get constant nodes.""" @@ -109,6 +134,7 @@ class OnnxSimplify: output_nodes_name.extend(node.output) original_outputs = [nd.name for nd in self._onnx_model.graph.output] self._outputs_infer = fetch_output_from_onnx_model(self._onnx_model, + self.model_path, feed_dict, output_nodes_name) idx = 0 while idx < len(self._onnx_model.graph.output):