From: @ghty0625 Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -304,8 +304,30 @@ class NewFragment: | |||
| """ | |||
| rewrite_data = {var: data[ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value]} | |||
| if ExchangeMessageKeywords.VariableScope.value.INPUTS.value in data: | |||
| rewrite_data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value] = ", ".join( | |||
| data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value]) | |||
| group_inputs = ExchangeMessageKeywords.VariableScope.value.GROUP_INPUTS.value | |||
| if group_inputs in data: | |||
| input_tuple_list = [] | |||
| tuple_index = 0 | |||
| tuple_id = 0 | |||
| while tuple_index < len(data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value]): | |||
| if tuple_id < len(data[group_inputs]) and tuple_index in data[group_inputs][tuple_id]: | |||
| tuple_added = ", ".join(data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value] | |||
| [data[group_inputs][tuple_id][0]: | |||
| data[group_inputs][tuple_id][-1]+1]) | |||
| tuple_added = f"({tuple_added})" | |||
| input_tuple_list.append(tuple_added) | |||
| tuple_index = data[group_inputs][tuple_id][-1]+1 | |||
| tuple_id += 1 | |||
| continue | |||
| input_tuple_list.append(data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value] | |||
| [tuple_index]) | |||
| tuple_index += 1 | |||
| rewrite_data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value] = \ | |||
| ", ".join(input_tuple_list) | |||
| else: | |||
| rewrite_data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value] = ", ".join( | |||
| data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value]) | |||
| if ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value in data: | |||
| rewrite_data.update(data[ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value]) | |||
| if ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value in data: | |||
| @@ -108,6 +108,7 @@ class ExchangeMessageKeywords(Enum): | |||
| WEIGHTS = "weights" | |||
| TRAINABLE_PARAMS = "trainable_params" | |||
| PARAMETERS_DECLARED = "parameters" | |||
| GROUP_INPUTS = "group_inputs" | |||
| BINARY_HEADER_PYTORCH_FILE = \ | |||
| @@ -21,7 +21,7 @@ from typing import List | |||
| from importlib import import_module | |||
| from importlib.util import find_spec | |||
| from functools import partial | |||
| from google.protobuf.internal import api_implementation | |||
| from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, onnx_satisfied, \ | |||
| save_code_file_and_report, get_framework_type, check_dependency_integrity, get_third_part_lib_validation_error_info | |||
| @@ -35,7 +35,7 @@ from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCrea | |||
| BadParamError | |||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph import GraphFactory | |||
| from google.protobuf.internal import api_implementation | |||
| check_common_dependency_integrity = partial(check_dependency_integrity, | |||
| "onnx", "onnxruntime", "onnxoptimizer") | |||
| @@ -267,7 +267,8 @@ def main_graph_base_converter(file_config): | |||
| if api_implementation.Type() != 'cpp' or os.getenv('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION') != 'cpp': | |||
| log_console.warning("Protobuf is currently implemented in \"Python\". " | |||
| "The conversion process may take a long time. Please use the \"C++\" backend version.") | |||
| "The conversion process may take a long time. " | |||
| "Please use `export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp` to enable cpp backend.") | |||
| graph_path = file_config['model_file'] | |||
| frame_type = get_framework_type(graph_path) | |||
| @@ -0,0 +1,108 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||
| class LSTMMapper(ONNXToMindSporeMapper): | |||
| """LSTM mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| return "nn.LSTM" | |||
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| """convert params""" | |||
| weights = kwargs["weights"] | |||
| input_weights = LSTMMapper._find_val_by_index(0, weights) | |||
| embed_dim = input_weights.shape[2] | |||
| params = kwargs['params'] | |||
| output_shape_list = kwargs.get("params").get("output_shape") | |||
| output_shape = output_shape_list[0].node_output_shape | |||
| # Here the first element determine if the lstm is bidirectional | |||
| # `1` means unidirectional. `2` means bidirectional | |||
| if output_shape[1] == 2: | |||
| return { | |||
| "input_size": embed_dim, | |||
| "hidden_size": params["hidden_size"], | |||
| "bidirectional": True | |||
| } | |||
| return { | |||
| "input_size": embed_dim, | |||
| "hidden_size": params["hidden_size"] | |||
| } | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| """generate snippet template""" | |||
| op = kwargs.get("operation") | |||
| args = kwargs.get("converted_params", dict()) | |||
| weights = kwargs.get("weights") | |||
| output_shape_list = kwargs.get("raw_params").get("output_shape") | |||
| output_shape = output_shape_list[0].node_output_shape | |||
| output_reshape = (output_shape[0], output_shape[2], output_shape[1], output_shape[3]) | |||
| trainable_params = kwargs.get("trainable_params", dict()) | |||
| if not op: | |||
| raise ValueError("Can not get MindSpore operation name.") | |||
| variable_slot = "var_0" | |||
| init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | |||
| init_reshape = f"self.{{{variable_slot}}}_reshape = P.Reshape()" | |||
| init_transpose = f"self.{{{variable_slot}}}_transpose = P.Transpose()" | |||
| init_cast = f"self.{{{variable_slot}}}_cast = P.Cast()" | |||
| construct_template = f"opt_{{{variable_slot}}}, (opt_{{{variable_slot}}}_h, " \ | |||
| f"opt_{{{variable_slot}}}_c) = self.{{{variable_slot}}}" \ | |||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}})" | |||
| construct_template_cast = f"opt_{{{variable_slot}}} = " \ | |||
| f"self.{{{variable_slot}}}_cast(" \ | |||
| f"opt_{{{variable_slot}}}, mindspore.float32)" | |||
| construct_template_reshape = f"opt_{{{variable_slot}}} = " \ | |||
| f"self.{{{variable_slot}}}_reshape(" \ | |||
| f"opt_{{{variable_slot}}}, {output_reshape})" | |||
| construct_template_transpose = f"opt_{{{variable_slot}}} = " \ | |||
| f"self.{{{variable_slot}}}_transpose(" \ | |||
| f"opt_{{{variable_slot}}}, (0, 2, 1, 3))" | |||
| template = { | |||
| variable_slot: { | |||
| TemplateKeywords.INIT.value: [init_template, init_cast, init_reshape, init_transpose], | |||
| TemplateKeywords.CONSTRUCT.value: [construct_template, | |||
| construct_template_cast, | |||
| construct_template_reshape, | |||
| construct_template_transpose] | |||
| } | |||
| } | |||
| exchange_msg = { | |||
| variable_slot: { | |||
| ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, | |||
| ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, | |||
| ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: | |||
| ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, | |||
| ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | |||
| ExchangeMessageKeywords.VariableScope.value.GROUP_INPUTS.value: [(1, 2)], | |||
| ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | |||
| ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, | |||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params | |||
| } | |||
| } | |||
| outputs_list = [f"opt_{{{variable_slot}}}", f"opt_{{{variable_slot}}}_h", f"opt_{{{variable_slot}}}_c"] | |||
| outputs_mapping = ((0, 0), (1, 1), (2, 2),) | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| @@ -13,7 +13,6 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| @@ -36,11 +35,56 @@ class ConcatMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||
| **kwargs) | |||
| weights = kwargs.get("weights") | |||
| op = kwargs.get("operation") | |||
| args = kwargs.get("converted_params", dict()) | |||
| trainable_params = kwargs.get("trainable_params", dict()) | |||
| if not op: | |||
| raise ValueError("Can not get MindSpore operation name.") | |||
| variable_slot = "var_0" | |||
| init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | |||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||
| f"(({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}},))" | |||
| template = reset_init_or_construct(template, variable_slot, [construct_template], | |||
| TemplateKeywords.CONSTRUCT.value) | |||
| if weights: | |||
| tensor = ConcatMapper._find_val_by_index(0, weights) | |||
| weight_shape = tensor.shape | |||
| weight_type = tensor.dtype | |||
| args["weight_shape"] = weight_shape | |||
| args["weight_type"] = weight_type | |||
| init_tensor = f"self.{{{variable_slot}}}_w = " \ | |||
| f"Parameter(Tensor(np.zeros({{weight_shape}}).astype(np.{{weight_type}})), " \ | |||
| f"name=None)" | |||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||
| f"(({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \ | |||
| f"self.{{{variable_slot}}}_w))" | |||
| template = { | |||
| variable_slot: { | |||
| TemplateKeywords.INIT.value: [init_template, init_tensor], | |||
| TemplateKeywords.CONSTRUCT.value: [construct_template] | |||
| } | |||
| } | |||
| else: | |||
| template = { | |||
| variable_slot: { | |||
| TemplateKeywords.INIT.value: [init_template], | |||
| TemplateKeywords.CONSTRUCT.value: [construct_template] | |||
| } | |||
| } | |||
| exchange_msg = { | |||
| variable_slot: { | |||
| ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, | |||
| ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, | |||
| ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: | |||
| ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, | |||
| ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | |||
| ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | |||
| ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, | |||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params | |||
| } | |||
| } | |||
| outputs_list = [f"opt_{{{variable_slot}}}"] | |||
| outputs_mapping = ((0, 0),) | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| @@ -0,0 +1,32 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| class FloorMapper(ONNXToMindSporeMapper): | |||
| """Floor mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| return "P.Floor" | |||
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @@ -37,6 +37,7 @@ class SliceMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| """Generate snippet template.""" | |||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||
| **kwargs) | |||
| op = kwargs.get("operation") | |||
| @@ -47,13 +48,14 @@ class SliceMapper(ONNXToMindSporeMapper): | |||
| starts = SliceMapper._find_val_by_index(0, weights) | |||
| ends = SliceMapper._find_val_by_index(1, weights) | |||
| axes = SliceMapper._find_val_by_index(2, weights, np.array([i for i in range(len(ipt_shape))])) | |||
| axes = SliceMapper._find_val_by_index(2, weights, np.array(list(range(len(ipt_shape))))) | |||
| steps = SliceMapper._find_val_by_index(3, weights, np.array([1 for _ in range(len(ipt_shape))])) | |||
| if not op: | |||
| raise ValueError("Can not get MindSpore operation name.") | |||
| if not weights: | |||
| raise ValueError("Cannot get required params from slice.") | |||
| raise ValueError("Can not get required params from slice.") | |||
| if axes.shape != (1,): | |||
| ordered_begin = sorted(zip(starts.tolist(), axes.tolist()), key=lambda x: x[1], reverse=False) | |||
| @@ -71,9 +73,9 @@ class SliceMapper(ONNXToMindSporeMapper): | |||
| end[axis] = min(ends.tolist()[0], end[axis]) | |||
| strides[axis] = steps.tolist()[0] | |||
| args["begin"] = tuple(begin) | |||
| args["end"] = tuple(end) | |||
| args["strides"] = tuple(strides) | |||
| args['begin'] = tuple(begin) | |||
| args['end'] = tuple(end) | |||
| args['strides'] = tuple(strides) | |||
| variable_slot = "var_0" | |||
| init_template = f"self.{{{variable_slot}}} = {op}()" | |||
| @@ -84,7 +86,8 @@ class SliceMapper(ONNXToMindSporeMapper): | |||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, " \ | |||
| f"self.{{{variable_slot}}}_begin, self.{{{variable_slot}}}_end, " \ | |||
| f"self.{{{variable_slot}}}_strides)" | |||
| template = reset_init_or_construct(template, variable_slot, [init_template, init_begin, init_end, init_strides], | |||
| template = reset_init_or_construct(template, variable_slot, | |||
| [init_template, init_begin, init_end, init_strides], | |||
| TemplateKeywords.INIT.value) | |||
| template = reset_init_or_construct(template, variable_slot, [construct_template], | |||
| TemplateKeywords.CONSTRUCT.value) | |||
| @@ -27,8 +27,8 @@ class SplitMapper(ONNXToMindSporeMapper): | |||
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| axis = kwargs["params"]["axis"] | |||
| split = kwargs["params"]["split"] | |||
| output_num = len(split) | |||
| output_shape_list = kwargs["params"]["output_shape"] | |||
| output_num = len(output_shape_list) | |||
| return {"axis": axis, | |||
| "output_num": output_num} | |||
| @@ -43,9 +43,16 @@ class SplitMapper(ONNXToMindSporeMapper): | |||
| weights = kwargs.get("weights") | |||
| if not op: | |||
| raise ValueError("Can not get MindSpore operation name.") | |||
| converted_params = kwargs["converted_params"] | |||
| output_num = converted_params["output_num"] | |||
| variable_slot = "var_0" | |||
| init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" | |||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||
| slot_list = [f"opt_{{{variable_slot}}}"] | |||
| for i in range(1, output_num): # Here `1` means the second output of the operator | |||
| slot_list.append(f"opt_{{{variable_slot}}}_{i}") | |||
| slot_final = ", ".join(slot_list) | |||
| construct_template = f"{slot_final} = self.{{{variable_slot}}}" \ | |||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}})" | |||
| template = { | |||
| variable_slot: { | |||
| @@ -58,13 +65,16 @@ class SplitMapper(ONNXToMindSporeMapper): | |||
| ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, | |||
| ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, | |||
| ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: | |||
| ExchangeMessageKeywords.VariableScope.value.ARR_TYPE.value, | |||
| ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, | |||
| ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | |||
| ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | |||
| ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, | |||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {} | |||
| } | |||
| } | |||
| outputs_list = [f"opt_{{{variable_slot}}}"] | |||
| outputs_mapping = ((0, 0),) | |||
| outputs_list = slot_list | |||
| outputs_mapping = [] | |||
| for i in range(output_num): | |||
| outputs_mapping.append((i, i)) | |||
| outputs_mapping = tuple(outputs_mapping) | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| @@ -0,0 +1,40 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| class SqueezeMapper(ONNXToMindSporeMapper): | |||
| """Squeeze mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| return "P.Squeeze" | |||
| @staticmethod | |||
| def _convert_params(**kwargs): | |||
| params = kwargs["params"] | |||
| if len(params['axes']) == 1: | |||
| return { | |||
| "axis": params['axes'][0] | |||
| } | |||
| return { | |||
| "axis": tuple(params['axes']) | |||
| } | |||
| @staticmethod | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @@ -30,5 +30,8 @@ | |||
| "onnx::Erf": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.erf_mapper.ErfMapper", | |||
| "onnx::Pow": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.pow_mapper.PowMapper", | |||
| "onnx::Einsum": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.einsum_mapper.EinSumMapper", | |||
| "onnx::Tanh": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.tanh_mapper.TanhMapper" | |||
| "onnx::Tanh": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.tanh_mapper.TanhMapper", | |||
| "onnx::LSTM": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.lstm_mapper.LSTMMapper", | |||
| "onnx::Squeeze": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.squeeze_mapper.SqueezeMapper", | |||
| "onnx::Floor": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.floor_mapper.FloorMapper" | |||
| } | |||
| @@ -23,7 +23,8 @@ from mindinsight.mindconverter.graph_based_converter.third_party_graph.input_nod | |||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph_node import OnnxGraphNode | |||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.pytorch_graph_parser import PyTorchGraphParser | |||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.tf_graph_parser import TFGraphParser | |||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import OnnxDataLoader, NodeWeight | |||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import OnnxDataLoader, \ | |||
| NodeWeight, NodeOutputShape | |||
| from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher import generate_scope_name | |||
| NONE_SCOPE_OP = { | |||
| @@ -139,6 +140,9 @@ class OnnxGraph(Graph): | |||
| super(OnnxGraph, self).build() | |||
| self._collect_input_shape_of_each_node() | |||
| for node_name in self._shape_dict: | |||
| if len(self._shape_dict[node_name]) == 1: | |||
| self._shape_dict[node_name] = self._shape_dict[node_name][0].node_output_shape | |||
| def _collect_input_shape_of_each_node(self): | |||
| """ | |||
| @@ -156,11 +160,17 @@ class OnnxGraph(Graph): | |||
| input_nodes[ipt].set_scope_name(node.scope_name) | |||
| node.precursor_nodes.insert(0, ipt) | |||
| input_nodes[ipt].set_successor_nodes(node_name) | |||
| self._shape_dict[ipt] = input_nodes[ipt].output_shape | |||
| output_shape_single = NodeOutputShape(ipt, None, input_nodes[ipt].output_shape) | |||
| if ipt not in self._shape_dict: | |||
| self._shape_dict.setdefault(ipt, []).append(output_shape_single) | |||
| ipt_shape = [] | |||
| for p_nd in node.precursor_nodes: | |||
| shp = self._shape_dict.get(p_nd) | |||
| ipt_shape.append(tuple(shp) if isinstance(shp, list) else shp) | |||
| for ipt_nd_name in node.ir_node_inputs: | |||
| for p_nd in node.precursor_nodes: | |||
| shp_list = self._shape_dict.get(p_nd) | |||
| for shp in shp_list: | |||
| if ipt_nd_name == shp.node_opt_name: | |||
| shp_single = shp.node_output_shape | |||
| ipt_shape.append(tuple(shp_single) if isinstance(shp_single, list) else shp_single) | |||
| self._input_shape[node_name] = ipt_shape[0] if len(ipt_shape) == 1 else ipt_shape | |||
| @@ -438,6 +438,15 @@ class OnnxDataLoader: | |||
| for tensor in tensors: | |||
| t = OnnxTensor(tensor) | |||
| self.tensors_dict[t.name] = t | |||
| idx = 0 | |||
| while idx < len(self.model.graph.output): | |||
| cur_opt = self.model.graph.output[idx] | |||
| if cur_opt.name not in self.output_nodes: | |||
| self.model.graph.output.remove(cur_opt) | |||
| continue | |||
| idx += 1 | |||
| self._global_context.onnx_tensors_collection = self.tensors_dict | |||
| def _parse_node_output_shape(self): | |||
| @@ -465,7 +474,8 @@ class OnnxDataLoader: | |||
| node_name = self.output_name_to_node_name[node_opt_name] | |||
| if not node_name: | |||
| raise GraphInitError(msg=f"Cannot find where edge {node_opt_name} comes from.") | |||
| self.node_output_shape_dict[node_name] = shape | |||
| node_output_shape = NodeOutputShape(node_opt_name, node_name, shape) | |||
| self.node_output_shape_dict.setdefault(node_name, []).append(node_output_shape) | |||
| def get_node(self, node_name): | |||
| """Get the OnnxNode instance by node name.""" | |||
| @@ -576,12 +586,16 @@ class OnnxDataLoader: | |||
| """Create a PASS to optimize shape and reshape operations in ONNX ir graph.""" | |||
| to_be_eliminated_op = {"Cast", "Concat", "Squeeze", "Unsqueeze", "Slice", | |||
| "Gather", "Shape"} | |||
| def _is_origin_inputs(node): | |||
| for ipt in node.input_name_list: | |||
| if ipt not in self.input_nodes: | |||
| return False | |||
| return True | |||
| def _traceback_precursor_nodes_until_shape_op(node_ref): | |||
| nonlocal self | |||
| e_nodes = [] | |||
| node = self._nodes_dict[self.output_name_to_node_name[node_ref]] | |||
| if node.op_type not in to_be_eliminated_op: | |||
| if node.op_type not in to_be_eliminated_op or _is_origin_inputs(node): | |||
| return e_nodes | |||
| e_nodes.append(node.name) | |||
| for ipt in node.input_name_list: | |||
| @@ -646,3 +660,23 @@ class NodeWeight: | |||
| @property | |||
| def location(self): | |||
| return self._weight_location | |||
| class NodeOutputShape: | |||
| """Node output shape and its name.""" | |||
| def __init__(self, node_opt_name, node_name, node_output_shape): | |||
| self._node_opt_name = node_opt_name | |||
| self._node_name = node_name | |||
| self._node_output_shape = node_output_shape | |||
| @property | |||
| def node_opt_name(self): | |||
| return self._node_opt_name | |||
| @property | |||
| def node_name(self): | |||
| return self._node_name | |||
| @property | |||
| def node_output_shape(self): | |||
| return self._node_output_shape | |||