diff --git a/mindinsight/mindconverter/graph_based_converter/common/code_fragment.py b/mindinsight/mindconverter/graph_based_converter/common/code_fragment.py index 8c5ac720..1b2accc2 100644 --- a/mindinsight/mindconverter/graph_based_converter/common/code_fragment.py +++ b/mindinsight/mindconverter/graph_based_converter/common/code_fragment.py @@ -304,8 +304,30 @@ class NewFragment: """ rewrite_data = {var: data[ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value]} if ExchangeMessageKeywords.VariableScope.value.INPUTS.value in data: - rewrite_data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value] = ", ".join( - data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value]) + group_inputs = ExchangeMessageKeywords.VariableScope.value.GROUP_INPUTS.value + if group_inputs in data: + input_tuple_list = [] + tuple_index = 0 + tuple_id = 0 + while tuple_index < len(data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value]): + if tuple_id < len(data[group_inputs]) and tuple_index in data[group_inputs][tuple_id]: + tuple_added = ", ".join(data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value] + [data[group_inputs][tuple_id][0]: + data[group_inputs][tuple_id][-1]+1]) + tuple_added = f"({tuple_added})" + input_tuple_list.append(tuple_added) + tuple_index = data[group_inputs][tuple_id][-1]+1 + tuple_id += 1 + continue + input_tuple_list.append(data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value] + [tuple_index]) + tuple_index += 1 + + rewrite_data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value] = \ + ", ".join(input_tuple_list) + else: + rewrite_data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value] = ", ".join( + data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value]) if ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value in data: rewrite_data.update(data[ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value]) if ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value in data: diff --git a/mindinsight/mindconverter/graph_based_converter/constant.py b/mindinsight/mindconverter/graph_based_converter/constant.py index 1e1daad0..28d52a5b 100644 --- a/mindinsight/mindconverter/graph_based_converter/constant.py +++ b/mindinsight/mindconverter/graph_based_converter/constant.py @@ -108,6 +108,7 @@ class ExchangeMessageKeywords(Enum): WEIGHTS = "weights" TRAINABLE_PARAMS = "trainable_params" PARAMETERS_DECLARED = "parameters" + GROUP_INPUTS = "group_inputs" BINARY_HEADER_PYTORCH_FILE = \ diff --git a/mindinsight/mindconverter/graph_based_converter/framework.py b/mindinsight/mindconverter/graph_based_converter/framework.py index 1422b5b3..da4f1e9d 100644 --- a/mindinsight/mindconverter/graph_based_converter/framework.py +++ b/mindinsight/mindconverter/graph_based_converter/framework.py @@ -21,7 +21,7 @@ from typing import List from importlib import import_module from importlib.util import find_spec from functools import partial - +from google.protobuf.internal import api_implementation from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, onnx_satisfied, \ save_code_file_and_report, get_framework_type, check_dependency_integrity, get_third_part_lib_validation_error_info @@ -35,7 +35,7 @@ from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCrea BadParamError from mindinsight.mindconverter.graph_based_converter.third_party_graph import GraphFactory -from google.protobuf.internal import api_implementation + check_common_dependency_integrity = partial(check_dependency_integrity, "onnx", "onnxruntime", "onnxoptimizer") @@ -267,7 +267,8 @@ def main_graph_base_converter(file_config): if api_implementation.Type() != 'cpp' or os.getenv('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION') != 'cpp': log_console.warning("Protobuf is currently implemented in \"Python\". " - "The conversion process may take a long time. Please use the \"C++\" backend version.") + "The conversion process may take a long time. " + "Please use `export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp` to enable cpp backend.") graph_path = file_config['model_file'] frame_type = get_framework_type(graph_path) diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/lstm_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/lstm_mapper.py new file mode 100644 index 00000000..9f69673d --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/lstm_mapper.py @@ -0,0 +1,108 @@ +# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Mapper module.""" +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper +from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords + + + +class LSTMMapper(ONNXToMindSporeMapper): + """LSTM mapper.""" + + @staticmethod + def _operation_name_in_ms(*args, **kwargs): + return "nn.LSTM" + + @staticmethod + def _convert_params(**kwargs): + """convert params""" + weights = kwargs["weights"] + input_weights = LSTMMapper._find_val_by_index(0, weights) + embed_dim = input_weights.shape[2] + params = kwargs['params'] + output_shape_list = kwargs.get("params").get("output_shape") + output_shape = output_shape_list[0].node_output_shape + # Here the first element determine if the lstm is bidirectional + # `1` means unidirectional. `2` means bidirectional + if output_shape[1] == 2: + return { + "input_size": embed_dim, + "hidden_size": params["hidden_size"], + "bidirectional": True + } + + return { + "input_size": embed_dim, + "hidden_size": params["hidden_size"] + } + + @staticmethod + def _convert_trained_weights(**kwargs): + return dict() + + @staticmethod + def _generate_snippet_template(**kwargs): + """generate snippet template""" + op = kwargs.get("operation") + args = kwargs.get("converted_params", dict()) + weights = kwargs.get("weights") + output_shape_list = kwargs.get("raw_params").get("output_shape") + output_shape = output_shape_list[0].node_output_shape + output_reshape = (output_shape[0], output_shape[2], output_shape[1], output_shape[3]) + trainable_params = kwargs.get("trainable_params", dict()) + if not op: + raise ValueError("Can not get MindSpore operation name.") + variable_slot = "var_0" + init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" + init_reshape = f"self.{{{variable_slot}}}_reshape = P.Reshape()" + init_transpose = f"self.{{{variable_slot}}}_transpose = P.Transpose()" + init_cast = f"self.{{{variable_slot}}}_cast = P.Cast()" + construct_template = f"opt_{{{variable_slot}}}, (opt_{{{variable_slot}}}_h, " \ + f"opt_{{{variable_slot}}}_c) = self.{{{variable_slot}}}" \ + f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}})" + construct_template_cast = f"opt_{{{variable_slot}}} = " \ + f"self.{{{variable_slot}}}_cast(" \ + f"opt_{{{variable_slot}}}, mindspore.float32)" + construct_template_reshape = f"opt_{{{variable_slot}}} = " \ + f"self.{{{variable_slot}}}_reshape(" \ + f"opt_{{{variable_slot}}}, {output_reshape})" + construct_template_transpose = f"opt_{{{variable_slot}}} = " \ + f"self.{{{variable_slot}}}_transpose(" \ + f"opt_{{{variable_slot}}}, (0, 2, 1, 3))" + template = { + variable_slot: { + TemplateKeywords.INIT.value: [init_template, init_cast, init_reshape, init_transpose], + TemplateKeywords.CONSTRUCT.value: [construct_template, + construct_template_cast, + construct_template_reshape, + construct_template_transpose] + } + } + exchange_msg = { + variable_slot: { + ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, + ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, + ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: + ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, + ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], + ExchangeMessageKeywords.VariableScope.value.GROUP_INPUTS.value: [(1, 2)], + ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, + ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, + ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params + } + } + outputs_list = [f"opt_{{{variable_slot}}}", f"opt_{{{variable_slot}}}_h", f"opt_{{{variable_slot}}}_c"] + outputs_mapping = ((0, 0), (1, 1), (2, 2),) + return template, exchange_msg, outputs_list, outputs_mapping diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/concat_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/concat_mapper.py index dfa6c779..c75c2e88 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/concat_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/concat_mapper.py @@ -13,7 +13,6 @@ # limitations under the License. # ============================================================================== """Mapper module.""" -from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper @@ -36,11 +35,56 @@ class ConcatMapper(ONNXToMindSporeMapper): @staticmethod def _generate_snippet_template(**kwargs): - template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( - **kwargs) + weights = kwargs.get("weights") + op = kwargs.get("operation") + args = kwargs.get("converted_params", dict()) + trainable_params = kwargs.get("trainable_params", dict()) + if not op: + raise ValueError("Can not get MindSpore operation name.") + variable_slot = "var_0" + init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ f"(({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}},))" - template = reset_init_or_construct(template, variable_slot, [construct_template], - TemplateKeywords.CONSTRUCT.value) + if weights: + tensor = ConcatMapper._find_val_by_index(0, weights) + weight_shape = tensor.shape + weight_type = tensor.dtype + args["weight_shape"] = weight_shape + args["weight_type"] = weight_type + init_tensor = f"self.{{{variable_slot}}}_w = " \ + f"Parameter(Tensor(np.zeros({{weight_shape}}).astype(np.{{weight_type}})), " \ + f"name=None)" + construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ + f"(({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \ + f"self.{{{variable_slot}}}_w))" + template = { + variable_slot: { + TemplateKeywords.INIT.value: [init_template, init_tensor], + TemplateKeywords.CONSTRUCT.value: [construct_template] + } + } + + else: + template = { + variable_slot: { + TemplateKeywords.INIT.value: [init_template], + TemplateKeywords.CONSTRUCT.value: [construct_template] + } + } + + exchange_msg = { + variable_slot: { + ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, + ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, + ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: + ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, + ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], + ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, + ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, + ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params + } + } + outputs_list = [f"opt_{{{variable_slot}}}"] + outputs_mapping = ((0, 0),) return template, exchange_msg, outputs_list, outputs_mapping diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/floor_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/floor_mapper.py new file mode 100644 index 00000000..2039239a --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/floor_mapper.py @@ -0,0 +1,32 @@ +# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Mapper module.""" +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper + + +class FloorMapper(ONNXToMindSporeMapper): + """Floor mapper.""" + + @staticmethod + def _operation_name_in_ms(*args, **kwargs): + return "P.Floor" + + @staticmethod + def _convert_params(**kwargs): + return dict() + + @staticmethod + def _convert_trained_weights(**kwargs): + return dict() diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/slice_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/slice_mapper.py index 6af7378d..b104f2e4 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/slice_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/slice_mapper.py @@ -37,6 +37,7 @@ class SliceMapper(ONNXToMindSporeMapper): @staticmethod def _generate_snippet_template(**kwargs): + """Generate snippet template.""" template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( **kwargs) op = kwargs.get("operation") @@ -47,13 +48,14 @@ class SliceMapper(ONNXToMindSporeMapper): starts = SliceMapper._find_val_by_index(0, weights) ends = SliceMapper._find_val_by_index(1, weights) - axes = SliceMapper._find_val_by_index(2, weights, np.array([i for i in range(len(ipt_shape))])) + axes = SliceMapper._find_val_by_index(2, weights, np.array(list(range(len(ipt_shape))))) steps = SliceMapper._find_val_by_index(3, weights, np.array([1 for _ in range(len(ipt_shape))])) if not op: raise ValueError("Can not get MindSpore operation name.") + if not weights: - raise ValueError("Cannot get required params from slice.") + raise ValueError("Can not get required params from slice.") if axes.shape != (1,): ordered_begin = sorted(zip(starts.tolist(), axes.tolist()), key=lambda x: x[1], reverse=False) @@ -71,9 +73,9 @@ class SliceMapper(ONNXToMindSporeMapper): end[axis] = min(ends.tolist()[0], end[axis]) strides[axis] = steps.tolist()[0] - args["begin"] = tuple(begin) - args["end"] = tuple(end) - args["strides"] = tuple(strides) + args['begin'] = tuple(begin) + args['end'] = tuple(end) + args['strides'] = tuple(strides) variable_slot = "var_0" init_template = f"self.{{{variable_slot}}} = {op}()" @@ -84,7 +86,8 @@ class SliceMapper(ONNXToMindSporeMapper): f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, " \ f"self.{{{variable_slot}}}_begin, self.{{{variable_slot}}}_end, " \ f"self.{{{variable_slot}}}_strides)" - template = reset_init_or_construct(template, variable_slot, [init_template, init_begin, init_end, init_strides], + template = reset_init_or_construct(template, variable_slot, + [init_template, init_begin, init_end, init_strides], TemplateKeywords.INIT.value) template = reset_init_or_construct(template, variable_slot, [construct_template], TemplateKeywords.CONSTRUCT.value) diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/split_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/split_mapper.py index aa682eba..6831bbb8 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/split_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/split_mapper.py @@ -27,8 +27,8 @@ class SplitMapper(ONNXToMindSporeMapper): @staticmethod def _convert_params(**kwargs): axis = kwargs["params"]["axis"] - split = kwargs["params"]["split"] - output_num = len(split) + output_shape_list = kwargs["params"]["output_shape"] + output_num = len(output_shape_list) return {"axis": axis, "output_num": output_num} @@ -43,9 +43,16 @@ class SplitMapper(ONNXToMindSporeMapper): weights = kwargs.get("weights") if not op: raise ValueError("Can not get MindSpore operation name.") + converted_params = kwargs["converted_params"] + output_num = converted_params["output_num"] + variable_slot = "var_0" init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" - construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ + slot_list = [f"opt_{{{variable_slot}}}"] + for i in range(1, output_num): # Here `1` means the second output of the operator + slot_list.append(f"opt_{{{variable_slot}}}_{i}") + slot_final = ", ".join(slot_list) + construct_template = f"{slot_final} = self.{{{variable_slot}}}" \ f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}})" template = { variable_slot: { @@ -58,13 +65,16 @@ class SplitMapper(ONNXToMindSporeMapper): ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: - ExchangeMessageKeywords.VariableScope.value.ARR_TYPE.value, + ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {} } } - outputs_list = [f"opt_{{{variable_slot}}}"] - outputs_mapping = ((0, 0),) + outputs_list = slot_list + outputs_mapping = [] + for i in range(output_num): + outputs_mapping.append((i, i)) + outputs_mapping = tuple(outputs_mapping) return template, exchange_msg, outputs_list, outputs_mapping diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/squeeze_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/squeeze_mapper.py new file mode 100644 index 00000000..80e2cf30 --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/squeeze_mapper.py @@ -0,0 +1,40 @@ +# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Mapper module.""" +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper + + +class SqueezeMapper(ONNXToMindSporeMapper): + """Squeeze mapper.""" + + @staticmethod + def _operation_name_in_ms(*args, **kwargs): + return "P.Squeeze" + + @staticmethod + def _convert_params(**kwargs): + params = kwargs["params"] + if len(params['axes']) == 1: + return { + "axis": params['axes'][0] + } + + return { + "axis": tuple(params['axes']) + } + + @staticmethod + def _convert_trained_weights(**kwargs): + return dict() diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json b/mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json index 3bb9b314..bcec726b 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json +++ b/mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json @@ -30,5 +30,8 @@ "onnx::Erf": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.erf_mapper.ErfMapper", "onnx::Pow": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.pow_mapper.PowMapper", "onnx::Einsum": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.einsum_mapper.EinSumMapper", - "onnx::Tanh": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.tanh_mapper.TanhMapper" + "onnx::Tanh": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.tanh_mapper.TanhMapper", + "onnx::LSTM": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.lstm_mapper.LSTMMapper", + "onnx::Squeeze": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.squeeze_mapper.SqueezeMapper", + "onnx::Floor": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.floor_mapper.FloorMapper" } \ No newline at end of file diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py index 584fb783..45ccedda 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py @@ -23,7 +23,8 @@ from mindinsight.mindconverter.graph_based_converter.third_party_graph.input_nod from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph_node import OnnxGraphNode from mindinsight.mindconverter.graph_based_converter.third_party_graph.pytorch_graph_parser import PyTorchGraphParser from mindinsight.mindconverter.graph_based_converter.third_party_graph.tf_graph_parser import TFGraphParser -from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import OnnxDataLoader, NodeWeight +from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import OnnxDataLoader, \ + NodeWeight, NodeOutputShape from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher import generate_scope_name NONE_SCOPE_OP = { @@ -139,6 +140,9 @@ class OnnxGraph(Graph): super(OnnxGraph, self).build() self._collect_input_shape_of_each_node() + for node_name in self._shape_dict: + if len(self._shape_dict[node_name]) == 1: + self._shape_dict[node_name] = self._shape_dict[node_name][0].node_output_shape def _collect_input_shape_of_each_node(self): """ @@ -156,11 +160,17 @@ class OnnxGraph(Graph): input_nodes[ipt].set_scope_name(node.scope_name) node.precursor_nodes.insert(0, ipt) input_nodes[ipt].set_successor_nodes(node_name) - self._shape_dict[ipt] = input_nodes[ipt].output_shape + output_shape_single = NodeOutputShape(ipt, None, input_nodes[ipt].output_shape) + if ipt not in self._shape_dict: + self._shape_dict.setdefault(ipt, []).append(output_shape_single) ipt_shape = [] - for p_nd in node.precursor_nodes: - shp = self._shape_dict.get(p_nd) - ipt_shape.append(tuple(shp) if isinstance(shp, list) else shp) + for ipt_nd_name in node.ir_node_inputs: + for p_nd in node.precursor_nodes: + shp_list = self._shape_dict.get(p_nd) + for shp in shp_list: + if ipt_nd_name == shp.node_opt_name: + shp_single = shp.node_output_shape + ipt_shape.append(tuple(shp_single) if isinstance(shp_single, list) else shp_single) self._input_shape[node_name] = ipt_shape[0] if len(ipt_shape) == 1 else ipt_shape diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py index 9982870e..4824a594 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py @@ -438,6 +438,15 @@ class OnnxDataLoader: for tensor in tensors: t = OnnxTensor(tensor) self.tensors_dict[t.name] = t + + idx = 0 + while idx < len(self.model.graph.output): + cur_opt = self.model.graph.output[idx] + if cur_opt.name not in self.output_nodes: + self.model.graph.output.remove(cur_opt) + continue + idx += 1 + self._global_context.onnx_tensors_collection = self.tensors_dict def _parse_node_output_shape(self): @@ -465,7 +474,8 @@ class OnnxDataLoader: node_name = self.output_name_to_node_name[node_opt_name] if not node_name: raise GraphInitError(msg=f"Cannot find where edge {node_opt_name} comes from.") - self.node_output_shape_dict[node_name] = shape + node_output_shape = NodeOutputShape(node_opt_name, node_name, shape) + self.node_output_shape_dict.setdefault(node_name, []).append(node_output_shape) def get_node(self, node_name): """Get the OnnxNode instance by node name.""" @@ -576,12 +586,16 @@ class OnnxDataLoader: """Create a PASS to optimize shape and reshape operations in ONNX ir graph.""" to_be_eliminated_op = {"Cast", "Concat", "Squeeze", "Unsqueeze", "Slice", "Gather", "Shape"} - + def _is_origin_inputs(node): + for ipt in node.input_name_list: + if ipt not in self.input_nodes: + return False + return True def _traceback_precursor_nodes_until_shape_op(node_ref): nonlocal self e_nodes = [] node = self._nodes_dict[self.output_name_to_node_name[node_ref]] - if node.op_type not in to_be_eliminated_op: + if node.op_type not in to_be_eliminated_op or _is_origin_inputs(node): return e_nodes e_nodes.append(node.name) for ipt in node.input_name_list: @@ -646,3 +660,23 @@ class NodeWeight: @property def location(self): return self._weight_location + + +class NodeOutputShape: + """Node output shape and its name.""" + def __init__(self, node_opt_name, node_name, node_output_shape): + self._node_opt_name = node_opt_name + self._node_name = node_name + self._node_output_shape = node_output_shape + + @property + def node_opt_name(self): + return self._node_opt_name + + @property + def node_name(self): + return self._node_name + + @property + def node_output_shape(self): + return self._node_output_shape