| @@ -283,6 +283,7 @@ class SourceFilesSaveError(MindConverterException): | |||||
| REPORT_GENERATE_FAIL = 3 | REPORT_GENERATE_FAIL = 3 | ||||
| CKPT_GENERATE_FAIL = 4 | CKPT_GENERATE_FAIL = 4 | ||||
| MAP_GENERATE_FAIL = 5 | MAP_GENERATE_FAIL = 5 | ||||
| MODEL_SAVE_FAIL = 6 | |||||
| BASE_ERROR_CODE = ConverterErrors.SOURCE_FILES_SAVE_FAIL.value | BASE_ERROR_CODE = ConverterErrors.SOURCE_FILES_SAVE_FAIL.value | ||||
| ERROR_CODE = ErrCode.UNKNOWN_ERROR.value | ERROR_CODE = ErrCode.UNKNOWN_ERROR.value | ||||
| @@ -299,6 +300,7 @@ class SourceFilesSaveError(MindConverterException): | |||||
| ReportGenerationError, | ReportGenerationError, | ||||
| CheckPointGenerationError, | CheckPointGenerationError, | ||||
| WeightMapGenerationError, | WeightMapGenerationError, | ||||
| OnnxModelSaveError, | |||||
| IOError, cls) | IOError, cls) | ||||
| return except_source | return except_source | ||||
| @@ -430,6 +432,17 @@ class WeightMapGenerationError(SourceFilesSaveError): | |||||
| """Raise from exception below.""" | """Raise from exception below.""" | ||||
| return cls | return cls | ||||
| class OnnxModelSaveError(SourceFilesSaveError): | |||||
| """The onnx model save fail error.""" | |||||
| ERROR_CODE = SourceFilesSaveError.ErrCode.MODEL_SAVE_FAIL.value | |||||
| def __init__(self, msg): | |||||
| super(OnnxModelSaveError, self).__init__(msg=msg) | |||||
| @classmethod | |||||
| def raise_from(cls): | |||||
| """Raise from exception below.""" | |||||
| return cls | |||||
| class SubGraphSearchingError(MindConverterException): | class SubGraphSearchingError(MindConverterException): | ||||
| """Sub-graph searching exception.""" | """Sub-graph searching exception.""" | ||||
| @@ -16,6 +16,7 @@ | |||||
| import json | import json | ||||
| import os | import os | ||||
| import stat | import stat | ||||
| import uuid | |||||
| from importlib import import_module | from importlib import import_module | ||||
| from importlib.util import find_spec | from importlib.util import find_spec | ||||
| from typing import List, Tuple, Mapping | from typing import List, Tuple, Mapping | ||||
| @@ -23,7 +24,7 @@ from typing import List, Tuple, Mapping | |||||
| import numpy as np | import numpy as np | ||||
| from mindinsight.mindconverter.common.exceptions import ScriptGenerationError, ReportGenerationError, \ | from mindinsight.mindconverter.common.exceptions import ScriptGenerationError, ReportGenerationError, \ | ||||
| CheckPointGenerationError, WeightMapGenerationError | |||||
| CheckPointGenerationError, WeightMapGenerationError, ModelLoadingError, OnnxModelSaveError | |||||
| from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP, FrameworkType, \ | from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP, FrameworkType, \ | ||||
| TENSORFLOW_MODEL_SUFFIX, THIRD_PART_VERSION, ONNX_MODEL_SUFFIX, DTYPE_MAP | TENSORFLOW_MODEL_SUFFIX, THIRD_PART_VERSION, ONNX_MODEL_SUFFIX, DTYPE_MAP | ||||
| @@ -84,7 +85,7 @@ def build_feed_dict(onnx_model, input_nodes: dict): | |||||
| return feed_dict | return feed_dict | ||||
| def fetch_output_from_onnx_model(model, feed_dict: dict, output_nodes: List[str]): | |||||
| def fetch_output_from_onnx_model(model, model_path: str, feed_dict: dict, output_nodes: List[str]): | |||||
| """ | """ | ||||
| Fetch specific nodes output from onnx model. | Fetch specific nodes output from onnx model. | ||||
| @@ -93,6 +94,7 @@ def fetch_output_from_onnx_model(model, feed_dict: dict, output_nodes: List[str] | |||||
| Args: | Args: | ||||
| model (ModelProto): ONNX model. | model (ModelProto): ONNX model. | ||||
| model_path (str): ONNX model path. | |||||
| feed_dict (dict): Feed forward inputs. | feed_dict (dict): Feed forward inputs. | ||||
| output_nodes (list[str]): Output nodes list. | output_nodes (list[str]): Output nodes list. | ||||
| @@ -105,9 +107,29 @@ def fetch_output_from_onnx_model(model, feed_dict: dict, output_nodes: List[str] | |||||
| edit_model = _add_outputs_of_onnx_model(model, output_nodes) | edit_model = _add_outputs_of_onnx_model(model, output_nodes) | ||||
| onnx = import_module("onnx") | |||||
| ort = import_module("onnxruntime") | ort = import_module("onnxruntime") | ||||
| sess = ort.InferenceSession(path_or_bytes=bytes(edit_model.SerializeToString())) | |||||
| fetched_res = sess.run(output_names=output_nodes, input_feed=feed_dict) | |||||
| try: | |||||
| dir_path = os.path.dirname(model_path) | |||||
| stem_name = os.path.splitext(os.path.basename(model_path))[0] | |||||
| filename = ".~{0}_{1}".format(stem_name, str(uuid.uuid4())) | |||||
| tmp_file = os.path.join(dir_path, filename) | |||||
| onnx.save_tensor(edit_model, tmp_file) | |||||
| except (TypeError, IOError) as error: | |||||
| if os.path.exists(tmp_file): | |||||
| os.remove(tmp_file) | |||||
| raise OnnxModelSaveError("Onnx model save failed, {}".format(str(error))) | |||||
| try: | |||||
| sess = ort.InferenceSession(path_or_bytes=tmp_file) | |||||
| fetched_res = sess.run(output_names=output_nodes, input_feed=feed_dict) | |||||
| except ModelLoadingError.raise_from() as error: | |||||
| raise ModelLoadingError("OnnxRuntimeError, {}".format(str(error))) | |||||
| finally: | |||||
| if os.path.exists(tmp_file): | |||||
| os.remove(tmp_file) | |||||
| run_result = dict() | run_result = dict() | ||||
| for idx, opt in enumerate(output_nodes): | for idx, opt in enumerate(output_nodes): | ||||
| run_result[opt] = fetched_res[idx] | run_result[opt] = fetched_res[idx] | ||||
| @@ -161,13 +183,17 @@ def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple], | |||||
| raise ReportGenerationError(str(error)) | raise ReportGenerationError(str(error)) | ||||
| save_checkpoint = getattr(import_module("mindspore.train.serialization"), "save_checkpoint") | save_checkpoint = getattr(import_module("mindspore.train.serialization"), "save_checkpoint") | ||||
| ckpt_file_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.ckpt")) | |||||
| try: | |||||
| for idx, trainable_weight in enumerate(trainable_weights): | |||||
| if len(trainable_weights) > 1: | |||||
| ckpt_file_path = os.path.realpath(os.path.join(out_folder, f"{model_name}_{idx}.ckpt")) | |||||
| else: | |||||
| ckpt_file_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.ckpt")) | |||||
| if os.path.exists(ckpt_file_path): | if os.path.exists(ckpt_file_path): | ||||
| raise CheckPointGenerationError("Checkpoint file with the same name already exists.") | raise CheckPointGenerationError("Checkpoint file with the same name already exists.") | ||||
| save_checkpoint(trainable_weights, ckpt_file_path) | |||||
| except TypeError as error: | |||||
| raise CheckPointGenerationError(str(error)) | |||||
| try: | |||||
| save_checkpoint(trainable_weight, ckpt_file_path) | |||||
| except TypeError as error: | |||||
| raise CheckPointGenerationError(str(error)) | |||||
| weight_map_path = os.path.realpath(os.path.join(report_folder, f"weight_map_of_{model_name}.json")) | weight_map_path = os.path.realpath(os.path.join(report_folder, f"weight_map_of_{model_name}.json")) | ||||
| try: | try: | ||||
| @@ -46,7 +46,7 @@ TF2ONNX_MIN_VER = "1.7.1" | |||||
| ONNXRUNTIME_MIN_VER = "1.5.2" | ONNXRUNTIME_MIN_VER = "1.5.2" | ||||
| ONNXOPTIMIZER_MIN_VER = "0.1.2" | ONNXOPTIMIZER_MIN_VER = "0.1.2" | ||||
| ONNXOPTIMIZER_MAX_VER = "0.1.2" | ONNXOPTIMIZER_MAX_VER = "0.1.2" | ||||
| CHECKPOINT_SEGMENT_SIZE = 2040109465 # 1.9GB, no more than 2GB | |||||
| DTYPE_MAP = { | DTYPE_MAP = { | ||||
| 1: np.float32, | 1: np.float32, | ||||
| @@ -35,7 +35,7 @@ from mindinsight.mindconverter.graph_based_converter.report_generator import Rep | |||||
| from mindinsight.mindconverter.graph_based_converter.common.utils import replace_string_in_list | from mindinsight.mindconverter.graph_based_converter.common.utils import replace_string_in_list | ||||
| from mindinsight.mindconverter.graph_based_converter.generator.matcher import MatcherLauncher | from mindinsight.mindconverter.graph_based_converter.generator.matcher import MatcherLauncher | ||||
| from mindinsight.mindconverter.graph_based_converter.generator.shared_weights import SharedWeightHelper | from mindinsight.mindconverter.graph_based_converter.generator.shared_weights import SharedWeightHelper | ||||
| from mindinsight.mindconverter.graph_based_converter.constant import CHECKPOINT_SEGMENT_SIZE | |||||
| class CodeStruct: | class CodeStruct: | ||||
| """ | """ | ||||
| @@ -635,15 +635,25 @@ class Generator: | |||||
| } | } | ||||
| ) | ) | ||||
| save_obj_list = list() | |||||
| save_obj = list() | save_obj = list() | ||||
| data_nbytes = 0 | |||||
| for weight_name, weight_value in trainable_weights_dict.items(): | for weight_name, weight_value in trainable_weights_dict.items(): | ||||
| obj = { | obj = { | ||||
| 'name': weight_name, | 'name': weight_name, | ||||
| 'data': mindspore.Tensor(weight_value) | 'data': mindspore.Tensor(weight_value) | ||||
| } | } | ||||
| data_nbytes += obj['data'].nbytes | |||||
| if data_nbytes > CHECKPOINT_SEGMENT_SIZE: | |||||
| save_obj_list.append(save_obj) | |||||
| save_obj = [] | |||||
| data_nbytes = obj['data'].nbytes | |||||
| save_obj.append(obj) | save_obj.append(obj) | ||||
| return save_obj, weight_map | |||||
| if save_obj: | |||||
| save_obj_list.append(save_obj) | |||||
| return save_obj_list, weight_map | |||||
| @GeneratorError.check_except("Generator occurs an error when generating code statements.") | @GeneratorError.check_except("Generator occurs an error when generating code statements.") | ||||
| def generate(self): | def generate(self): | ||||
| @@ -90,9 +90,10 @@ class Graph(BaseGraph, abc.ABC): | |||||
| sorted = False | sorted = False | ||||
| def __init__(self, model, **kwargs): | |||||
| def __init__(self, model, model_path, **kwargs): | |||||
| super(Graph, self).__init__() | super(Graph, self).__init__() | ||||
| self.model = model | self.model = model | ||||
| self.model_path = model_path | |||||
| self._raw_input_nodes = kwargs.get("input_nodes") | self._raw_input_nodes = kwargs.get("input_nodes") | ||||
| self._raw_output_nodes = kwargs.get("output_nodes") | self._raw_output_nodes = kwargs.get("output_nodes") | ||||
| self._nodes_collection = OrderedDict() | self._nodes_collection = OrderedDict() | ||||
| @@ -246,7 +247,7 @@ class Graph(BaseGraph, abc.ABC): | |||||
| cls, graph instance. | cls, graph instance. | ||||
| """ | """ | ||||
| src_graph = cls.load_graph(graph_path=model_path, **kwargs) | src_graph = cls.load_graph(graph_path=model_path, **kwargs) | ||||
| return cls(src_graph, **kwargs) | |||||
| return cls(src_graph, model_path, **kwargs) | |||||
| class GraphNode(abc.ABC): | class GraphNode(abc.ABC): | ||||
| @@ -56,10 +56,11 @@ class OnnxGraph(Graph): | |||||
| Args: | Args: | ||||
| model (onnx.ModelProto): Onnx defined model proto. | model (onnx.ModelProto): Onnx defined model proto. | ||||
| model_path (str): Onnx model path. | |||||
| """ | """ | ||||
| def __init__(self, model, **kwargs): | |||||
| super(OnnxGraph, self).__init__(model=model, **kwargs) | |||||
| def __init__(self, model, model_path, **kwargs): | |||||
| super(OnnxGraph, self).__init__(model=model, model_path=model_path, **kwargs) | |||||
| self.build() | self.build() | ||||
| @@ -115,6 +116,7 @@ class OnnxGraph(Graph): | |||||
| def build(self): | def build(self): | ||||
| """Build graph tree.""" | """Build graph tree.""" | ||||
| model_data = OnnxDataLoader(self.model, | model_data = OnnxDataLoader(self.model, | ||||
| self.model_path, | |||||
| input_nodes=self._raw_input_nodes, | input_nodes=self._raw_input_nodes, | ||||
| output_nodes=self._raw_output_nodes) | output_nodes=self._raw_output_nodes) | ||||
| scope_name_list = generate_scope_name(model_data) | scope_name_list = generate_scope_name(model_data) | ||||
| @@ -207,7 +209,7 @@ class OnnxGraph(Graph): | |||||
| output_nodes=output_nodes) | output_nodes=output_nodes) | ||||
| else: | else: | ||||
| onnx = import_module('onnx') | onnx = import_module('onnx') | ||||
| onnx_model = onnx.load(graph_path) | |||||
| onnx_model = onnx.load(graph_path, load_external_data=False) | |||||
| onnx_inputs = [onnx_input.name for onnx_input in onnx_model.graph.input] | onnx_inputs = [onnx_input.name for onnx_input in onnx_model.graph.input] | ||||
| @@ -13,6 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Define ONNX related operations.""" | """Define ONNX related operations.""" | ||||
| import os | |||||
| import itertools | import itertools | ||||
| import re | import re | ||||
| import abc | import abc | ||||
| @@ -249,17 +250,19 @@ class OnnxDataLoader: | |||||
| Args: | Args: | ||||
| onnx_model (onnx.ModelProto): Original Onnx defined model. | onnx_model (onnx.ModelProto): Original Onnx defined model. | ||||
| model_path (str): Onnx model path. | |||||
| input_nodes (Union[str, list]): Input nodes of ONNX model. | input_nodes (Union[str, list]): Input nodes of ONNX model. | ||||
| output_nodes (Union[str, list]): Output nodes of ONNX model. | output_nodes (Union[str, list]): Output nodes of ONNX model. | ||||
| infer_shape (bool): Enable the shape inference after conversion. | infer_shape (bool): Enable the shape inference after conversion. | ||||
| Default: True | Default: True | ||||
| """ | """ | ||||
| def __init__(self, onnx_model, input_nodes: dict, | |||||
| def __init__(self, onnx_model, model_path: str, input_nodes: dict, | |||||
| output_nodes: list, infer_shape=True): | output_nodes: list, infer_shape=True): | ||||
| onnx_sim = OnnxSimplify() | onnx_sim = OnnxSimplify() | ||||
| onnx_model_sim = onnx_sim.run_onnx_simplify(onnx_model, input_nodes) | |||||
| onnx_model_sim = onnx_sim.run_onnx_simplify(onnx_model, model_path, input_nodes) | |||||
| self.model = onnx_model_sim | self.model = onnx_model_sim | ||||
| self.model_path = model_path | |||||
| self.graph = onnx_model_sim.graph | self.graph = onnx_model_sim.graph | ||||
| self.nodes = onnx_model_sim.graph.node | self.nodes = onnx_model_sim.graph.node | ||||
| self.input_nodes = input_nodes | self.input_nodes = input_nodes | ||||
| @@ -384,7 +387,10 @@ class OnnxDataLoader: | |||||
| feed_dict = build_feed_dict(self.inferred_model, self.input_nodes) | feed_dict = build_feed_dict(self.inferred_model, self.input_nodes) | ||||
| outputs_infer = fetch_output_from_onnx_model(self.model, feed_dict, output_nodes_name) | |||||
| outputs_infer = fetch_output_from_onnx_model(self.model, | |||||
| self.model_path, | |||||
| feed_dict, | |||||
| output_nodes_name) | |||||
| return outputs_infer | return outputs_infer | ||||
| def _parse_nodes(self): | def _parse_nodes(self): | ||||
| @@ -434,6 +440,9 @@ class OnnxDataLoader: | |||||
| t = OnnxTensor(tensor) | t = OnnxTensor(tensor) | ||||
| self.tensors_dict[t.name] = t | self.tensors_dict[t.name] = t | ||||
| self._global_context.onnx_tensors_collection = self.tensors_dict | |||||
| def _remove_extra_graph_output(self): | |||||
| idx = 0 | idx = 0 | ||||
| while idx < len(self.model.graph.output): | while idx < len(self.model.graph.output): | ||||
| cur_opt = self.model.graph.output[idx] | cur_opt = self.model.graph.output[idx] | ||||
| @@ -442,8 +451,6 @@ class OnnxDataLoader: | |||||
| continue | continue | ||||
| idx += 1 | idx += 1 | ||||
| self._global_context.onnx_tensors_collection = self.tensors_dict | |||||
| def _parse_node_output_shape(self): | def _parse_node_output_shape(self): | ||||
| """ | """ | ||||
| Parse the inferred output shape of each node. | Parse the inferred output shape of each node. | ||||
| @@ -516,13 +523,13 @@ class OnnxDataLoader: | |||||
| # Parse ONNX Graph level info | # Parse ONNX Graph level info | ||||
| self._parse_graph() | self._parse_graph() | ||||
| # 1. parse all tensors | |||||
| self._parse_tensors() | |||||
| # 2. parse all nodes, note that parse tensors must be done as nodes require tensor info | |||||
| # 1. parse all nodes, note that parse tensors must be done as nodes require tensor info | |||||
| # to process the node weight sharing. | # to process the node weight sharing. | ||||
| self._parse_nodes() | self._parse_nodes() | ||||
| # 2. remove extra output from onnx model graph | |||||
| self._remove_extra_graph_output() | |||||
| # 3. parse value info (incl. node output shape) | # 3. parse value info (incl. node output shape) | ||||
| if self._is_infer_shape: | if self._is_infer_shape: | ||||
| try: | try: | ||||
| @@ -534,15 +541,29 @@ class OnnxDataLoader: | |||||
| log.exception(e) | log.exception(e) | ||||
| raise e | raise e | ||||
| # 4. Optimize graph to eliminate some nodes. | |||||
| # 4. load external_data for initializer | |||||
| self.load_external_data() | |||||
| # 5. parse all tensors | |||||
| self._parse_tensors() | |||||
| # 6. Optimize graph to eliminate some nodes. | |||||
| self._find_nodes_to_be_eliminated() | self._find_nodes_to_be_eliminated() | ||||
| # 5. build nodes connections | |||||
| # 7. build nodes connections | |||||
| self.build_nodes_connection() | self.build_nodes_connection() | ||||
| # 6. Run onnx model to fetch actual value of eliminated nodes. | |||||
| # 8. Run onnx model to fetch actual value of eliminated nodes. | |||||
| self._fetch_eliminated_nodes_value() | self._fetch_eliminated_nodes_value() | ||||
| def load_external_data(self): | |||||
| load_external_data_for_model = getattr(import_module("onnx.external_data_helper"), | |||||
| "load_external_data_for_model") | |||||
| model_filepath = os.path.realpath(self.model_path) | |||||
| if model_filepath: | |||||
| base_dir = os.path.dirname(model_filepath) | |||||
| load_external_data_for_model(self.model, base_dir) | |||||
| def _fetch_eliminated_nodes_value(self): | def _fetch_eliminated_nodes_value(self): | ||||
| """Fetch eliminated nodes values by running onnx inference.""" | """Fetch eliminated nodes values by running onnx inference.""" | ||||
| @@ -556,7 +577,10 @@ class OnnxDataLoader: | |||||
| shape_ref = self._nodes_dict[node].input_name_list[1] | shape_ref = self._nodes_dict[node].input_name_list[1] | ||||
| output_tensors.append(shape_ref) | output_tensors.append(shape_ref) | ||||
| feed_dict = build_feed_dict(self.model, self.input_nodes) | feed_dict = build_feed_dict(self.model, self.input_nodes) | ||||
| fetch_dict = fetch_output_from_onnx_model(self.model, feed_dict=feed_dict, output_nodes=output_tensors) | |||||
| fetch_dict = fetch_output_from_onnx_model(self.model, | |||||
| self.model_path, | |||||
| feed_dict=feed_dict, | |||||
| output_nodes=output_tensors) | |||||
| for opt_tensor_name, value in fetch_dict.items(): | for opt_tensor_name, value in fetch_dict.items(): | ||||
| self.tensors_dict[opt_tensor_name] = OnnxTensor(value, opt_tensor_name) | self.tensors_dict[opt_tensor_name] = OnnxTensor(value, opt_tensor_name) | ||||
| @@ -570,7 +594,10 @@ class OnnxDataLoader: | |||||
| shape_ref = self._nodes_dict[node].input_name_list[3] | shape_ref = self._nodes_dict[node].input_name_list[3] | ||||
| output_tensors.append(shape_ref) | output_tensors.append(shape_ref) | ||||
| feed_dict = build_feed_dict(self.model, self.input_nodes) | feed_dict = build_feed_dict(self.model, self.input_nodes) | ||||
| fetch_dict = fetch_output_from_onnx_model(self.model, feed_dict=feed_dict, output_nodes=output_tensors) | |||||
| fetch_dict = fetch_output_from_onnx_model(self.model, | |||||
| self.model_path, | |||||
| feed_dict=feed_dict, | |||||
| output_nodes=output_tensors) | |||||
| for opt_tensor_name, value in fetch_dict.items(): | for opt_tensor_name, value in fetch_dict.items(): | ||||
| self.tensors_dict[opt_tensor_name] = OnnxTensor(value, opt_tensor_name) | self.tensors_dict[opt_tensor_name] = OnnxTensor(value, opt_tensor_name) | ||||
| @@ -25,18 +25,21 @@ class OnnxSimplify: | |||||
| def __init__(self): | def __init__(self): | ||||
| self._onnx_model = None | self._onnx_model = None | ||||
| self.model_path = str() | |||||
| self._constant_nodes = list() | self._constant_nodes = list() | ||||
| self._outputs_infer = dict() | self._outputs_infer = dict() | ||||
| def run_onnx_simplify(self, onnx_model, input_nodes): | |||||
| def run_onnx_simplify(self, onnx_model, model_path, input_nodes): | |||||
| """ | """ | ||||
| Run to simplify onnx model. | Run to simplify onnx model. | ||||
| Args: | Args: | ||||
| onnx_model (onnx.ModelProto): Onnx Model. | onnx_model (onnx.ModelProto): Onnx Model. | ||||
| model_path (str): Onnx model path. | |||||
| input_nodes (dict): Input nodes and corresponding sample shape. | input_nodes (dict): Input nodes and corresponding sample shape. | ||||
| """ | """ | ||||
| self._onnx_model = onnx_model | self._onnx_model = onnx_model | ||||
| self.model_path = model_path | |||||
| self._optimizer() | self._optimizer() | ||||
| self._get_constant_nodes() | self._get_constant_nodes() | ||||
| self._onnx_infer(input_nodes) | self._onnx_infer(input_nodes) | ||||
| @@ -70,13 +73,35 @@ class OnnxSimplify: | |||||
| 'fuse_transpose_into_gemm' | 'fuse_transpose_into_gemm' | ||||
| ] | ] | ||||
| raw_initializers = dict() | |||||
| for initializer in self._onnx_model.graph.initializer: | |||||
| raw_initializers[initializer.name] = initializer | |||||
| input_num = len(self._onnx_model.graph.input) | input_num = len(self._onnx_model.graph.input) | ||||
| onnx_model_optimized = onnxoptimizer.optimize(self._onnx_model, optimizers_list, fixed_point=True) | onnx_model_optimized = onnxoptimizer.optimize(self._onnx_model, optimizers_list, fixed_point=True) | ||||
| for initializer in onnx_model_optimized.graph.initializer: | |||||
| if raw_initializers.get(initializer.name, None): | |||||
| raw_initializer = raw_initializers.get(initializer.name) | |||||
| self.copy_initializer_external_data(raw_initializer, initializer) | |||||
| if self._onnx_model.ir_version > 3: | if self._onnx_model.ir_version > 3: | ||||
| del onnx_model_optimized.graph.input[input_num:] | del onnx_model_optimized.graph.input[input_num:] | ||||
| self._onnx_model = onnx_model_optimized | self._onnx_model = onnx_model_optimized | ||||
| def copy_initializer_external_data(self, src, dst): | |||||
| """ | |||||
| Copy external_data from src initializer to dst initializer. | |||||
| Args: | |||||
| src (graph.initializer): Source initializer. | |||||
| dst (graph.initializer): Destination initializer. | |||||
| """ | |||||
| if src.HasField('data_location'): | |||||
| dst.data_location = src.data_location | |||||
| if src.external_data: | |||||
| dst.external_data.extend(src.external_data) | |||||
| def _get_constant_nodes(self): | def _get_constant_nodes(self): | ||||
| """Get constant nodes.""" | """Get constant nodes.""" | ||||
| @@ -109,6 +134,7 @@ class OnnxSimplify: | |||||
| output_nodes_name.extend(node.output) | output_nodes_name.extend(node.output) | ||||
| original_outputs = [nd.name for nd in self._onnx_model.graph.output] | original_outputs = [nd.name for nd in self._onnx_model.graph.output] | ||||
| self._outputs_infer = fetch_output_from_onnx_model(self._onnx_model, | self._outputs_infer = fetch_output_from_onnx_model(self._onnx_model, | ||||
| self.model_path, | |||||
| feed_dict, output_nodes_name) | feed_dict, output_nodes_name) | ||||
| idx = 0 | idx = 0 | ||||
| while idx < len(self._onnx_model.graph.output): | while idx < len(self._onnx_model.graph.output): | ||||