@@ -283,6 +283,7 @@ class SourceFilesSaveError(MindConverterException):
         REPORT_GENERATE_FAIL = 3
         CKPT_GENERATE_FAIL = 4
         MAP_GENERATE_FAIL = 5
+        MODEL_SAVE_FAIL = 6

     BASE_ERROR_CODE = ConverterErrors.SOURCE_FILES_SAVE_FAIL.value
     ERROR_CODE = ErrCode.UNKNOWN_ERROR.value
@@ -299,6 +300,7 @@ class SourceFilesSaveError(MindConverterException):
                          ReportGenerationError,
                          CheckPointGenerationError,
                          WeightMapGenerationError,
+                         OnnxModelSaveError,
                          IOError, cls)
         return except_source
@@ -430,6 +432,17 @@ class WeightMapGenerationError(SourceFilesSaveError):
         """Raise from exception below."""
         return cls


+class OnnxModelSaveError(SourceFilesSaveError):
+    """The onnx model save fail error."""
+    ERROR_CODE = SourceFilesSaveError.ErrCode.MODEL_SAVE_FAIL.value
+
+    def __init__(self, msg):
+        super(OnnxModelSaveError, self).__init__(msg=msg)
+
+    @classmethod
+    def raise_from(cls):
+        """Raise from exception below."""
+        return cls
+
+
 class SubGraphSearchingError(MindConverterException):
     """Sub-graph searching exception."""
@@ -16,6 +16,7 @@
 import json
 import os
 import stat
+import uuid
 from importlib import import_module
 from importlib.util import find_spec
 from typing import List, Tuple, Mapping
@@ -23,7 +24,7 @@ from typing import List, Tuple, Mapping
 import numpy as np

 from mindinsight.mindconverter.common.exceptions import ScriptGenerationError, ReportGenerationError, \
-    CheckPointGenerationError, WeightMapGenerationError
+    CheckPointGenerationError, WeightMapGenerationError, ModelLoadingError, OnnxModelSaveError
 from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP, FrameworkType, \
     TENSORFLOW_MODEL_SUFFIX, THIRD_PART_VERSION, ONNX_MODEL_SUFFIX, DTYPE_MAP
@@ -84,7 +85,7 @@ def build_feed_dict(onnx_model, input_nodes: dict):
     return feed_dict


-def fetch_output_from_onnx_model(model, feed_dict: dict, output_nodes: List[str]):
+def fetch_output_from_onnx_model(model, model_path: str, feed_dict: dict, output_nodes: List[str]):
     """
     Fetch specific nodes output from onnx model.
@@ -93,6 +94,7 @@ def fetch_output_from_onnx_model(model, feed_dict: dict, output_nodes: List[str]
     Args:
         model (ModelProto): ONNX model.
+        model_path (str): ONNX model path.
         feed_dict (dict): Feed forward inputs.
         output_nodes (list[str]): Output nodes list.
@@ -105,9 +107,29 @@ def fetch_output_from_onnx_model(model, feed_dict: dict, output_nodes: List[str]
     edit_model = _add_outputs_of_onnx_model(model, output_nodes)

+    onnx = import_module("onnx")
     ort = import_module("onnxruntime")
-    sess = ort.InferenceSession(path_or_bytes=bytes(edit_model.SerializeToString()))
-    fetched_res = sess.run(output_names=output_nodes, input_feed=feed_dict)
+    try:
+        dir_path = os.path.dirname(model_path)
+        stem_name = os.path.splitext(os.path.basename(model_path))[0]
+        filename = ".~{0}_{1}".format(stem_name, str(uuid.uuid4()))
+        tmp_file = os.path.join(dir_path, filename)
+        onnx.save_tensor(edit_model, tmp_file)
+    except (TypeError, IOError) as error:
+        if os.path.exists(tmp_file):
+            os.remove(tmp_file)
+        raise OnnxModelSaveError("Onnx model save failed, {}".format(str(error)))
+
+    try:
+        sess = ort.InferenceSession(path_or_bytes=tmp_file)
+        fetched_res = sess.run(output_names=output_nodes, input_feed=feed_dict)
+    except ModelLoadingError.raise_from() as error:
+        raise ModelLoadingError("OnnxRuntimeError, {}".format(str(error)))
+    finally:
+        if os.path.exists(tmp_file):
+            os.remove(tmp_file)
+
     run_result = dict()
     for idx, opt in enumerate(output_nodes):
         run_result[opt] = fetched_res[idx]
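Why the temp-file round trip: `InferenceSession(path_or_bytes=model.SerializeToString())` needs the whole model in one protobuf blob, which fails for models whose initializers live in external data files and for anything near protobuf's 2GB message cap. Writing the edited model next to the original keeps relative `external_data` paths resolvable. Below is a minimal sketch of the same pattern with two deliberate differences, both my assumptions rather than the diff's code: it uses `onnx.save_model` (byte-identical to `save_tensor` for a `ModelProto`, but the intended API), and it initializes `tmp_file` before the `try`, since the diff's `except` branch reads `tmp_file`, which would be unbound if `os.path.dirname(model_path)` itself raised:

```python
import os
import uuid

import onnx
import onnxruntime as ort

def run_model_via_temp_file(edit_model, model_path, feed_dict, output_nodes):
    """Sketch: save the edited model beside the original, run it from disk."""
    tmp_file = None
    try:
        dir_path = os.path.dirname(model_path)
        stem_name = os.path.splitext(os.path.basename(model_path))[0]
        tmp_file = os.path.join(dir_path, ".~{}_{}".format(stem_name, uuid.uuid4()))
        onnx.save_model(edit_model, tmp_file)  # intended API for a ModelProto
        sess = ort.InferenceSession(path_or_bytes=tmp_file)
        return sess.run(output_names=output_nodes, input_feed=feed_dict)
    finally:
        if tmp_file and os.path.exists(tmp_file):  # never unbound here
            os.remove(tmp_file)
```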
@@ -161,13 +183,17 @@ def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple],
         raise ReportGenerationError(str(error))

     save_checkpoint = getattr(import_module("mindspore.train.serialization"), "save_checkpoint")
-    ckpt_file_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.ckpt"))
-    try:
+    for idx, trainable_weight in enumerate(trainable_weights):
+        if len(trainable_weights) > 1:
+            ckpt_file_path = os.path.realpath(os.path.join(out_folder, f"{model_name}_{idx}.ckpt"))
+        else:
+            ckpt_file_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.ckpt"))
+        if os.path.exists(ckpt_file_path):
+            raise CheckPointGenerationError("Checkpoint file with the same name already exists.")
-        save_checkpoint(trainable_weights, ckpt_file_path)
-    except TypeError as error:
-        raise CheckPointGenerationError(str(error))
+        try:
+            save_checkpoint(trainable_weight, ckpt_file_path)
+        except TypeError as error:
+            raise CheckPointGenerationError(str(error))

     weight_map_path = os.path.realpath(os.path.join(report_folder, f"weight_map_of_{model_name}.json"))
     try:
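A note on the naming scheme above: with checkpoint segmentation, `trainable_weights` is now a list of weight segments, and only the multi-segment case gets an index suffix. A hypothetical helper (not in the diff) illustrating the resulting file names:

```python
def ckpt_file_names(model_name, num_segments):
    """Hypothetical illustration of the naming rule in the loop above."""
    if num_segments > 1:
        return ["{}_{}.ckpt".format(model_name, idx) for idx in range(num_segments)]
    return ["{}.ckpt".format(model_name)]

# ckpt_file_names("resnet50", 1) -> ["resnet50.ckpt"]
# ckpt_file_names("resnet50", 3) -> ["resnet50_0.ckpt", "resnet50_1.ckpt", "resnet50_2.ckpt"]
```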
@@ -46,7 +46,7 @@ TF2ONNX_MIN_VER = "1.7.1"
 ONNXRUNTIME_MIN_VER = "1.5.2"
 ONNXOPTIMIZER_MIN_VER = "0.1.2"
 ONNXOPTIMIZER_MAX_VER = "0.1.2"
+CHECKPOINT_SEGMENT_SIZE = 2040109465  # 1.9GB, no more than 2GB

 DTYPE_MAP = {
     1: np.float32,
@@ -35,7 +35,7 @@ from mindinsight.mindconverter.graph_based_converter.report_generator import Rep
 from mindinsight.mindconverter.graph_based_converter.common.utils import replace_string_in_list
 from mindinsight.mindconverter.graph_based_converter.generator.matcher import MatcherLauncher
 from mindinsight.mindconverter.graph_based_converter.generator.shared_weights import SharedWeightHelper
+from mindinsight.mindconverter.graph_based_converter.constant import CHECKPOINT_SEGMENT_SIZE


 class CodeStruct:
     """
@@ -635,15 +635,25 @@ class Generator:
                 }
             )

+        save_obj_list = list()
         save_obj = list()
+        data_nbytes = 0
         for weight_name, weight_value in trainable_weights_dict.items():
             obj = {
                 'name': weight_name,
                 'data': mindspore.Tensor(weight_value)
             }
+            data_nbytes += obj['data'].nbytes
+            if data_nbytes > CHECKPOINT_SEGMENT_SIZE:
+                save_obj_list.append(save_obj)
+                save_obj = []
+                data_nbytes = obj['data'].nbytes
             save_obj.append(obj)
-        return save_obj, weight_map
+
+        if save_obj:
+            save_obj_list.append(save_obj)
+        return save_obj_list, weight_map

     @GeneratorError.check_except("Generator occurs an error when generating code statements.")
     def generate(self):
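The split exists because a checkpoint is serialized through protobuf, which rejects single messages of 2GB or more; `CHECKPOINT_SEGMENT_SIZE` (2040109465 bytes, about 1.9GB) leaves headroom. The loop above packs weights greedily into segments, closing a segment just before it would overflow. A self-contained sketch of the same greedy split, using numpy arrays in place of `mindspore.Tensor` (an assumption to keep the example dependency-free):

```python
import numpy as np

CHECKPOINT_SEGMENT_SIZE = 2040109465  # ~1.9GB, safely under protobuf's 2GB cap

def split_weights_into_segments(weights):
    """Greedily pack {name: ndarray} weights into size-bounded segments."""
    segments, current, current_nbytes = [], [], 0
    for name, value in weights.items():
        if current and current_nbytes + value.nbytes > CHECKPOINT_SEGMENT_SIZE:
            segments.append(current)           # close the full segment
            current, current_nbytes = [], 0    # start a fresh one
        current.append({'name': name, 'data': value})
        current_nbytes += value.nbytes
    if current:
        segments.append(current)               # flush the tail segment
    return segments

weights = {'fc.weight': np.zeros((4096, 4096), dtype=np.float32),
           'fc.bias': np.zeros(4096, dtype=np.float32)}
assert len(split_weights_into_segments(weights)) == 1  # well under the cap
```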
@@ -90,9 +90,10 @@ class Graph(BaseGraph, abc.ABC):
     sorted = False

-    def __init__(self, model, **kwargs):
+    def __init__(self, model, model_path, **kwargs):
         super(Graph, self).__init__()
         self.model = model
+        self.model_path = model_path
         self._raw_input_nodes = kwargs.get("input_nodes")
         self._raw_output_nodes = kwargs.get("output_nodes")
         self._nodes_collection = OrderedDict()
@@ -246,7 +247,7 @@ class Graph(BaseGraph, abc.ABC):
             cls, graph instance.
         """
         src_graph = cls.load_graph(graph_path=model_path, **kwargs)
-        return cls(src_graph, **kwargs)
+        return cls(src_graph, model_path, **kwargs)


 class GraphNode(abc.ABC):
@@ -56,10 +56,11 @@ class OnnxGraph(Graph):
     Args:
         model (onnx.ModelProto): Onnx defined model proto.
+        model_path (str): Onnx model path.
     """

-    def __init__(self, model, **kwargs):
-        super(OnnxGraph, self).__init__(model=model, **kwargs)
+    def __init__(self, model, model_path, **kwargs):
+        super(OnnxGraph, self).__init__(model=model, model_path=model_path, **kwargs)
         self.build()
@@ -115,6 +116,7 @@ class OnnxGraph(Graph):
     def build(self):
         """Build graph tree."""
         model_data = OnnxDataLoader(self.model,
+                                    self.model_path,
                                     input_nodes=self._raw_input_nodes,
                                     output_nodes=self._raw_output_nodes)
         scope_name_list = generate_scope_name(model_data)
@@ -207,7 +209,7 @@ class OnnxGraph(Graph):
                                            output_nodes=output_nodes)
         else:
             onnx = import_module('onnx')
-            onnx_model = onnx.load(graph_path)
+            onnx_model = onnx.load(graph_path, load_external_data=False)
             onnx_inputs = [onnx_input.name for onnx_input in onnx_model.graph.input]
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Define ONNX related operations."""
+import os
 import itertools
 import re
 import abc
@@ -249,17 +250,19 @@ class OnnxDataLoader:
     Args:
         onnx_model (onnx.ModelProto): Original Onnx defined model.
+        model_path (str): Onnx model path.
         input_nodes (Union[str, list]): Input nodes of ONNX model.
         output_nodes (Union[str, list]): Output nodes of ONNX model.
         infer_shape (bool): Enable the shape inference after conversion.
             Default: True
     """

-    def __init__(self, onnx_model, input_nodes: dict,
+    def __init__(self, onnx_model, model_path: str, input_nodes: dict,
                  output_nodes: list, infer_shape=True):
         onnx_sim = OnnxSimplify()
-        onnx_model_sim = onnx_sim.run_onnx_simplify(onnx_model, input_nodes)
+        onnx_model_sim = onnx_sim.run_onnx_simplify(onnx_model, model_path, input_nodes)
         self.model = onnx_model_sim
+        self.model_path = model_path
         self.graph = onnx_model_sim.graph
         self.nodes = onnx_model_sim.graph.node
         self.input_nodes = input_nodes
@@ -384,7 +387,10 @@ class OnnxDataLoader:
         feed_dict = build_feed_dict(self.inferred_model, self.input_nodes)
-        outputs_infer = fetch_output_from_onnx_model(self.model, feed_dict, output_nodes_name)
+        outputs_infer = fetch_output_from_onnx_model(self.model,
+                                                     self.model_path,
+                                                     feed_dict,
+                                                     output_nodes_name)
         return outputs_infer

     def _parse_nodes(self):
@@ -434,6 +440,9 @@ class OnnxDataLoader:
             t = OnnxTensor(tensor)
             self.tensors_dict[t.name] = t
+        self._global_context.onnx_tensors_collection = self.tensors_dict
+
+    def _remove_extra_graph_output(self):
         idx = 0
         while idx < len(self.model.graph.output):
             cur_opt = self.model.graph.output[idx]

@@ -442,8 +451,6 @@ class OnnxDataLoader:
                 continue
             idx += 1

-        self._global_context.onnx_tensors_collection = self.tensors_dict
-
     def _parse_node_output_shape(self):
         """
         Parse the inferred output shape of each node.
@@ -516,13 +523,13 @@ class OnnxDataLoader:
         # Parse ONNX Graph level info
         self._parse_graph()

-        # 1. parse all tensors
-        self._parse_tensors()
-
-        # 2. parse all nodes, note that parse tensors must be done as nodes require tensor info
+        # 1. parse all nodes, note that parse tensors must be done as nodes require tensor info
         #    to process the node weight sharing.
         self._parse_nodes()

+        # 2. remove extra output from onnx model graph
+        self._remove_extra_graph_output()
+
         # 3. parse value info (incl. node output shape)
         if self._is_infer_shape:
             try:
@@ -534,15 +541,29 @@ class OnnxDataLoader:
                 log.exception(e)
                 raise e

-        # 4. Optimize graph to eliminate some nodes.
+        # 4. load external_data for initializer
+        self.load_external_data()
+
+        # 5. parse all tensors
+        self._parse_tensors()
+
+        # 6. Optimize graph to eliminate some nodes.
         self._find_nodes_to_be_eliminated()

-        # 5. build nodes connections
+        # 7. build nodes connections
         self.build_nodes_connection()

-        # 6. Run onnx model to fetch actual value of eliminated nodes.
+        # 8. Run onnx model to fetch actual value of eliminated nodes.
         self._fetch_eliminated_nodes_value()

+    def load_external_data(self):
+        """Load the external data referenced by the model's initializers."""
+        load_external_data_for_model = getattr(import_module("onnx.external_data_helper"),
+                                               "load_external_data_for_model")
+        model_filepath = os.path.realpath(self.model_path)
+        if model_filepath:
+            base_dir = os.path.dirname(model_filepath)
+            load_external_data_for_model(self.model, base_dir)
+
     def _fetch_eliminated_nodes_value(self):
         """Fetch eliminated nodes values by running onnx inference."""
@@ -556,7 +577,10 @@ class OnnxDataLoader:
                 shape_ref = self._nodes_dict[node].input_name_list[1]
                 output_tensors.append(shape_ref)
         feed_dict = build_feed_dict(self.model, self.input_nodes)
-        fetch_dict = fetch_output_from_onnx_model(self.model, feed_dict=feed_dict, output_nodes=output_tensors)
+        fetch_dict = fetch_output_from_onnx_model(self.model,
+                                                  self.model_path,
+                                                  feed_dict=feed_dict,
+                                                  output_nodes=output_tensors)
         for opt_tensor_name, value in fetch_dict.items():
             self.tensors_dict[opt_tensor_name] = OnnxTensor(value, opt_tensor_name)
@@ -570,7 +594,10 @@ class OnnxDataLoader:
                 shape_ref = self._nodes_dict[node].input_name_list[3]
                 output_tensors.append(shape_ref)
         feed_dict = build_feed_dict(self.model, self.input_nodes)
-        fetch_dict = fetch_output_from_onnx_model(self.model, feed_dict=feed_dict, output_nodes=output_tensors)
+        fetch_dict = fetch_output_from_onnx_model(self.model,
+                                                  self.model_path,
+                                                  feed_dict=feed_dict,
+                                                  output_nodes=output_tensors)
         for opt_tensor_name, value in fetch_dict.items():
             self.tensors_dict[opt_tensor_name] = OnnxTensor(value, opt_tensor_name)
@@ -25,18 +25,21 @@ class OnnxSimplify:

     def __init__(self):
         self._onnx_model = None
+        self.model_path = str()
         self._constant_nodes = list()
         self._outputs_infer = dict()

-    def run_onnx_simplify(self, onnx_model, input_nodes):
+    def run_onnx_simplify(self, onnx_model, model_path, input_nodes):
         """
         Run to simplify onnx model.

         Args:
             onnx_model (onnx.ModelProto): Onnx Model.
+            model_path (str): Onnx model path.
             input_nodes (dict): Input nodes and corresponding sample shape.
         """
         self._onnx_model = onnx_model
+        self.model_path = model_path
         self._optimizer()
         self._get_constant_nodes()
         self._onnx_infer(input_nodes)
@@ -70,13 +73,35 @@ class OnnxSimplify:
             'fuse_transpose_into_gemm'
         ]

+        raw_initializers = dict()
+        for initializer in self._onnx_model.graph.initializer:
+            raw_initializers[initializer.name] = initializer
+
         input_num = len(self._onnx_model.graph.input)
         onnx_model_optimized = onnxoptimizer.optimize(self._onnx_model, optimizers_list, fixed_point=True)
+
+        for initializer in onnx_model_optimized.graph.initializer:
+            if raw_initializers.get(initializer.name, None):
+                raw_initializer = raw_initializers.get(initializer.name)
+                self.copy_initializer_external_data(raw_initializer, initializer)
+
         if self._onnx_model.ir_version > 3:
             del onnx_model_optimized.graph.input[input_num:]

         self._onnx_model = onnx_model_optimized

+    def copy_initializer_external_data(self, src, dst):
+        """
+        Copy external_data from src initializer to dst initializer.
+
+        Args:
+            src (graph.initializer): Source initializer.
+            dst (graph.initializer): Destination initializer.
+        """
+        if src.HasField('data_location'):
+            dst.data_location = src.data_location
+        if src.external_data:
+            dst.external_data.extend(src.external_data)
+
     def _get_constant_nodes(self):
         """Get constant nodes."""
@@ -109,6 +134,7 @@ class OnnxSimplify:
                 output_nodes_name.extend(node.output)

         original_outputs = [nd.name for nd in self._onnx_model.graph.output]
         self._outputs_infer = fetch_output_from_onnx_model(self._onnx_model,
+                                                           self.model_path,
                                                            feed_dict, output_nodes_name)
         idx = 0
         while idx < len(self._onnx_model.graph.output):