
!1220 Add support for transferring LSTM

From: @ghty0625
tags/v1.2.0-rc1
mindspore-ci-bot committed 4 years ago
commit 1dde3dc207
12 changed files with 339 additions and 31 deletions
1.  +24  -2   mindinsight/mindconverter/graph_based_converter/common/code_fragment.py
2.  +1   -0   mindinsight/mindconverter/graph_based_converter/constant.py
3.  +4   -3   mindinsight/mindconverter/graph_based_converter/framework.py
4.  +108 -0   mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/lstm_mapper.py
5.  +49  -5   mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/concat_mapper.py
6.  +32  -0   mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/floor_mapper.py
7.  +9   -6   mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/slice_mapper.py
8.  +16  -6   mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/split_mapper.py
9.  +40  -0   mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/squeeze_mapper.py
10. +4   -1   mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json
11. +15  -5   mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py
12. +37  -3   mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py

+24 -2  mindinsight/mindconverter/graph_based_converter/common/code_fragment.py

@@ -304,8 +304,30 @@ class NewFragment:
"""
rewrite_data = {var: data[ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value]}
if ExchangeMessageKeywords.VariableScope.value.INPUTS.value in data:
rewrite_data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value] = ", ".join(
data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value])
group_inputs = ExchangeMessageKeywords.VariableScope.value.GROUP_INPUTS.value
if group_inputs in data:
input_tuple_list = []
tuple_index = 0
tuple_id = 0
while tuple_index < len(data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value]):
if tuple_id < len(data[group_inputs]) and tuple_index in data[group_inputs][tuple_id]:
tuple_added = ", ".join(data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value]
[data[group_inputs][tuple_id][0]:
data[group_inputs][tuple_id][-1]+1])
tuple_added = f"({tuple_added})"
input_tuple_list.append(tuple_added)
tuple_index = data[group_inputs][tuple_id][-1]+1
tuple_id += 1
continue
input_tuple_list.append(data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value]
[tuple_index])
tuple_index += 1

rewrite_data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value] = \
", ".join(input_tuple_list)
else:
rewrite_data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value] = ", ".join(
data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value])
if ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value in data:
rewrite_data.update(data[ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value])
if ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value in data:
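For orientation, here is a minimal standalone sketch of the grouping rule this hunk adds: consecutive input indices listed under GROUP_INPUTS are joined into one parenthesized tuple argument, and everything else passes through unchanged (plain lists stand in for the exchange-message dict; the index ranges are illustrative).

def group_input_names(input_names, group_inputs):
    """Join input names; wrap each grouped index range in a tuple literal."""
    parts = []
    idx, gid = 0, 0
    while idx < len(input_names):
        if gid < len(group_inputs) and idx in group_inputs[gid]:
            first, last = group_inputs[gid][0], group_inputs[gid][-1]
            parts.append("(" + ", ".join(input_names[first:last + 1]) + ")")
            idx = last + 1
            gid += 1
            continue
        parts.append(input_names[idx])
        idx += 1
    return ", ".join(parts)

# An LSTM call site groups its initial states into one tuple argument:
print(group_input_names(["x", "h0", "c0"], [[1, 2]]))  # -> x, (h0, c0)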


+1 -0  mindinsight/mindconverter/graph_based_converter/constant.py

@@ -108,6 +108,7 @@ class ExchangeMessageKeywords(Enum):
WEIGHTS = "weights"
TRAINABLE_PARAMS = "trainable_params"
PARAMETERS_DECLARED = "parameters"
GROUP_INPUTS = "group_inputs"


BINARY_HEADER_PYTORCH_FILE = \


+4 -3  mindinsight/mindconverter/graph_based_converter/framework.py

@@ -21,7 +21,7 @@ from typing import List
from importlib import import_module
from importlib.util import find_spec
from functools import partial
from google.protobuf.internal import api_implementation
from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, onnx_satisfied, \
save_code_file_and_report, get_framework_type, check_dependency_integrity, get_third_part_lib_validation_error_info
@@ -35,7 +35,7 @@ from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCrea
BadParamError
from mindinsight.mindconverter.graph_based_converter.third_party_graph import GraphFactory

from google.protobuf.internal import api_implementation

check_common_dependency_integrity = partial(check_dependency_integrity,
"onnx", "onnxruntime", "onnxoptimizer")
@@ -267,7 +267,8 @@ def main_graph_base_converter(file_config):

if api_implementation.Type() != 'cpp' or os.getenv('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION') != 'cpp':
log_console.warning("Protobuf is currently implemented in \"Python\". "
"The conversion process may take a long time. Please use the \"C++\" backend version.")
"The conversion process may take a long time. "
"Please use `export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp` to enable cpp backend.")

graph_path = file_config['model_file']
frame_type = get_framework_type(graph_path)
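The warning fires when the slow pure-Python protobuf backend is active. A minimal way to check which backend you are running before starting a conversion (assuming only that protobuf is installed):

import os

from google.protobuf.internal import api_implementation

# 'cpp' is the fast backend; 'python' is the slow pure-Python fallback.
print(api_implementation.Type())

# Must be set in the environment before protobuf is first imported, e.g.
# `export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp` in the shell.
print(os.getenv("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"))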


+108 -0  mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/lstm_mapper.py

@@ -0,0 +1,108 @@
# Copyright 2021 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mapper module."""
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords



class LSTMMapper(ONNXToMindSporeMapper):
"""LSTM mapper."""

@staticmethod
def _operation_name_in_ms(*args, **kwargs):
return "nn.LSTM"

@staticmethod
def _convert_params(**kwargs):
"""convert params"""
weights = kwargs["weights"]
input_weights = LSTMMapper._find_val_by_index(0, weights)
embed_dim = input_weights.shape[2]
params = kwargs['params']
output_shape_list = kwargs.get("params").get("output_shape")
output_shape = output_shape_list[0].node_output_shape
# Here output_shape[1] (the num_directions dimension of the ONNX LSTM output)
# decides whether the LSTM is bidirectional: `1` means unidirectional, `2` means bidirectional.
if output_shape[1] == 2:
return {
"input_size": embed_dim,
"hidden_size": params["hidden_size"],
"bidirectional": True
}

return {
"input_size": embed_dim,
"hidden_size": params["hidden_size"]
}

@staticmethod
def _convert_trained_weights(**kwargs):
return dict()

@staticmethod
def _generate_snippet_template(**kwargs):
"""generate snippet template"""
op = kwargs.get("operation")
args = kwargs.get("converted_params", dict())
weights = kwargs.get("weights")
output_shape_list = kwargs.get("raw_params").get("output_shape")
output_shape = output_shape_list[0].node_output_shape
output_reshape = (output_shape[0], output_shape[2], output_shape[1], output_shape[3])
trainable_params = kwargs.get("trainable_params", dict())
if not op:
raise ValueError("Can not get MindSpore operation name.")
variable_slot = "var_0"
init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})"
init_reshape = f"self.{{{variable_slot}}}_reshape = P.Reshape()"
init_transpose = f"self.{{{variable_slot}}}_transpose = P.Transpose()"
init_cast = f"self.{{{variable_slot}}}_cast = P.Cast()"
construct_template = f"opt_{{{variable_slot}}}, (opt_{{{variable_slot}}}_h, " \
f"opt_{{{variable_slot}}}_c) = self.{{{variable_slot}}}" \
f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}})"
construct_template_cast = f"opt_{{{variable_slot}}} = " \
f"self.{{{variable_slot}}}_cast(" \
f"opt_{{{variable_slot}}}, mindspore.float32)"
construct_template_reshape = f"opt_{{{variable_slot}}} = " \
f"self.{{{variable_slot}}}_reshape(" \
f"opt_{{{variable_slot}}}, {output_reshape})"
construct_template_transpose = f"opt_{{{variable_slot}}} = " \
f"self.{{{variable_slot}}}_transpose(" \
f"opt_{{{variable_slot}}}, (0, 2, 1, 3))"
template = {
variable_slot: {
TemplateKeywords.INIT.value: [init_template, init_cast, init_reshape, init_transpose],
TemplateKeywords.CONSTRUCT.value: [construct_template,
construct_template_cast,
construct_template_reshape,
construct_template_transpose]
}
}
exchange_msg = {
variable_slot: {
ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op,
ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None,
ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value:
ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value,
ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [],
ExchangeMessageKeywords.VariableScope.value.GROUP_INPUTS.value: [(1, 2)],
ExchangeMessageKeywords.VariableScope.value.ARGS.value: args,
ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights,
ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params
}
}
outputs_list = [f"opt_{{{variable_slot}}}", f"opt_{{{variable_slot}}}_h", f"opt_{{{variable_slot}}}_c"]
outputs_mapping = ((0, 0), (1, 1), (2, 2),)
return template, exchange_msg, outputs_list, outputs_mapping
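The cast/reshape/transpose chain exists because MindSpore's nn.LSTM returns output of shape (seq_len, batch, num_directions * hidden_size), while the ONNX LSTM Y output is laid out as (seq_len, num_directions, batch, hidden_size). A NumPy sketch of the layout conversion the template performs (shapes are illustrative):

import numpy as np

seq_len, num_directions, batch, hidden = 5, 2, 3, 4

# MindSpore nn.LSTM output layout: (seq_len, batch, num_directions * hidden)
ms_out = np.zeros((seq_len, batch, num_directions * hidden), dtype=np.float32)

# output_reshape in the template is (shape[0], shape[2], shape[1], shape[3]),
# i.e. (seq_len, batch, num_directions, hidden) for an ONNX output shape
# (seq_len, num_directions, batch, hidden).
reshaped = ms_out.reshape(seq_len, batch, num_directions, hidden)

# transpose(0, 2, 1, 3) then recovers the ONNX Y layout.
onnx_like = reshaped.transpose(0, 2, 1, 3)
assert onnx_like.shape == (seq_len, num_directions, batch, hidden)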

+49 -5  mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/concat_mapper.py

@@ -13,7 +13,6 @@
# limitations under the License.
# ==============================================================================
"""Mapper module."""
from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper

@@ -36,11 +35,56 @@ class ConcatMapper(ONNXToMindSporeMapper):

@staticmethod
def _generate_snippet_template(**kwargs):
template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template(
**kwargs)
weights = kwargs.get("weights")
op = kwargs.get("operation")
args = kwargs.get("converted_params", dict())
trainable_params = kwargs.get("trainable_params", dict())
if not op:
raise ValueError("Can not get MindSpore operation name.")

variable_slot = "var_0"
init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})"
construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \
f"(({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}},))"
template = reset_init_or_construct(template, variable_slot, [construct_template],
TemplateKeywords.CONSTRUCT.value)
if weights:
tensor = ConcatMapper._find_val_by_index(0, weights)
weight_shape = tensor.shape
weight_type = tensor.dtype
args["weight_shape"] = weight_shape
args["weight_type"] = weight_type
init_tensor = f"self.{{{variable_slot}}}_w = " \
f"Parameter(Tensor(np.zeros({{weight_shape}}).astype(np.{{weight_type}})), " \
f"name=None)"
construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \
f"(({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \
f"self.{{{variable_slot}}}_w))"
template = {
variable_slot: {
TemplateKeywords.INIT.value: [init_template, init_tensor],
TemplateKeywords.CONSTRUCT.value: [construct_template]
}
}

else:
template = {
variable_slot: {
TemplateKeywords.INIT.value: [init_template],
TemplateKeywords.CONSTRUCT.value: [construct_template]
}
}

exchange_msg = {
variable_slot: {
ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op,
ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None,
ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value:
ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value,
ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [],
ExchangeMessageKeywords.VariableScope.value.ARGS.value: args,
ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights,
ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params
}
}
outputs_list = [f"opt_{{{variable_slot}}}"]
outputs_mapping = ((0, 0),)
return template, exchange_msg, outputs_list, outputs_mapping
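When one Concat operand is a constant initializer (weights is non-empty), the template declares it as a zero-initialized placeholder Parameter and appends it to the input tuple. The emitted MindSpore code then looks roughly like this hypothetical fragment (names and shapes are illustrative; real weight values come from the converted checkpoint, not from this snippet):

import numpy as np
from mindspore import Parameter, Tensor
from mindspore.ops import operations as P

concat = P.Concat(axis=1)
# Placeholder mirroring the template's zero-initialized Parameter.
concat_w = Parameter(Tensor(np.zeros((2, 3)).astype(np.float32)), name=None)

x = Tensor(np.ones((2, 3)).astype(np.float32))
out = concat((x, concat_w))  # the constant joins the input tuple, shape (2, 6)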

+32 -0  mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/floor_mapper.py

@@ -0,0 +1,32 @@
# Copyright 2021 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mapper module."""
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper


class FloorMapper(ONNXToMindSporeMapper):
"""Floor mapper."""

@staticmethod
def _operation_name_in_ms(*args, **kwargs):
return "P.Floor"

@staticmethod
def _convert_params(**kwargs):
return dict()

@staticmethod
def _convert_trained_weights(**kwargs):
return dict()

+9 -6  mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/slice_mapper.py

@@ -37,6 +37,7 @@ class SliceMapper(ONNXToMindSporeMapper):

@staticmethod
def _generate_snippet_template(**kwargs):
"""Generate snippet template."""
template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template(
**kwargs)
op = kwargs.get("operation")
@@ -47,13 +48,14 @@ class SliceMapper(ONNXToMindSporeMapper):

starts = SliceMapper._find_val_by_index(0, weights)
ends = SliceMapper._find_val_by_index(1, weights)
axes = SliceMapper._find_val_by_index(2, weights, np.array([i for i in range(len(ipt_shape))]))
axes = SliceMapper._find_val_by_index(2, weights, np.array(list(range(len(ipt_shape)))))
steps = SliceMapper._find_val_by_index(3, weights, np.array([1 for _ in range(len(ipt_shape))]))

if not op:
raise ValueError("Can not get MindSpore operation name.")

if not weights:
raise ValueError("Cannot get required params from slice.")
raise ValueError("Can not get required params from slice.")

if axes.shape != (1,):
ordered_begin = sorted(zip(starts.tolist(), axes.tolist()), key=lambda x: x[1], reverse=False)
@@ -71,9 +73,9 @@ class SliceMapper(ONNXToMindSporeMapper):
end[axis] = min(ends.tolist()[0], end[axis])
strides[axis] = steps.tolist()[0]

args["begin"] = tuple(begin)
args["end"] = tuple(end)
args["strides"] = tuple(strides)
args['begin'] = tuple(begin)
args['end'] = tuple(end)
args['strides'] = tuple(strides)

variable_slot = "var_0"
init_template = f"self.{{{variable_slot}}} = {op}()"
@@ -84,7 +86,8 @@ class SliceMapper(ONNXToMindSporeMapper):
f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, " \
f"self.{{{variable_slot}}}_begin, self.{{{variable_slot}}}_end, " \
f"self.{{{variable_slot}}}_strides)"
template = reset_init_or_construct(template, variable_slot, [init_template, init_begin, init_end, init_strides],
template = reset_init_or_construct(template, variable_slot,
[init_template, init_begin, init_end, init_strides],
TemplateKeywords.INIT.value)
template = reset_init_or_construct(template, variable_slot, [construct_template],
TemplateKeywords.CONSTRUCT.value)
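As a standalone illustration, the parameter conversion this mapper performs turns the ONNX Slice initializers (starts, ends, axes, steps) into full-rank begin/end/strides tuples for P.StridedSlice; the shapes below are hypothetical:

import numpy as np

ipt_shape = (4, 10, 8)                       # input tensor shape
starts, ends = np.array([2]), np.array([7])  # slice dim 1 from 2 to 7
axes, steps = np.array([1]), np.array([1])

begin = [0] * len(ipt_shape)
end = list(ipt_shape)
strides = [1] * len(ipt_shape)
for st, en, ax, sp in zip(starts.tolist(), ends.tolist(),
                          axes.tolist(), steps.tolist()):
    begin[ax] = st
    end[ax] = min(en, end[ax])  # clamp to the input extent, as the mapper does
    strides[ax] = sp

print(tuple(begin), tuple(end), tuple(strides))
# -> (0, 2, 0) (4, 7, 8) (1, 1, 1)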


+16 -6  mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/split_mapper.py

@@ -27,8 +27,8 @@ class SplitMapper(ONNXToMindSporeMapper):
@staticmethod
def _convert_params(**kwargs):
axis = kwargs["params"]["axis"]
split = kwargs["params"]["split"]
output_num = len(split)
output_shape_list = kwargs["params"]["output_shape"]
output_num = len(output_shape_list)
return {"axis": axis,
"output_num": output_num}

@@ -43,9 +43,16 @@ class SplitMapper(ONNXToMindSporeMapper):
weights = kwargs.get("weights")
if not op:
raise ValueError("Can not get MindSpore operation name.")
converted_params = kwargs["converted_params"]
output_num = converted_params["output_num"]

variable_slot = "var_0"
init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})"
construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \
slot_list = [f"opt_{{{variable_slot}}}"]
for i in range(1, output_num):  # start at 1: the first output slot carries no index suffix
slot_list.append(f"opt_{{{variable_slot}}}_{i}")
slot_final = ", ".join(slot_list)
construct_template = f"{slot_final} = self.{{{variable_slot}}}" \
f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}})"
template = {
variable_slot: {
@@ -58,13 +65,16 @@ class SplitMapper(ONNXToMindSporeMapper):
ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op,
ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None,
ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value:
ExchangeMessageKeywords.VariableScope.value.ARR_TYPE.value,
ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value,
ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [],
ExchangeMessageKeywords.VariableScope.value.ARGS.value: args,
ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights,
ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {}
}
}
outputs_list = [f"opt_{{{variable_slot}}}"]
outputs_mapping = ((0, 0),)
outputs_list = slot_list
outputs_mapping = []
for i in range(output_num):
outputs_mapping.append((i, i))
outputs_mapping = tuple(outputs_mapping)
return template, exchange_msg, outputs_list, outputs_mapping
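The new multi-output handling can be seen in isolation: the number of return slots follows the node's recorded output shapes, the first slot keeps the unsuffixed name, and the outputs mapping becomes an identity pairing (a minimal sketch; the slot values are illustrative):

variable_slot = "var_0"
output_num = 3  # e.g. len(params["output_shape"]) for a three-way Split

slot_list = [f"opt_{{{variable_slot}}}"]
for i in range(1, output_num):  # slot 0 carries no index suffix
    slot_list.append(f"opt_{{{variable_slot}}}_{i}")

print(", ".join(slot_list))  # -> opt_{var_0}, opt_{var_0}_1, opt_{var_0}_2

outputs_mapping = tuple((i, i) for i in range(output_num))
print(outputs_mapping)       # -> ((0, 0), (1, 1), (2, 2))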

+40 -0  mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/squeeze_mapper.py

@@ -0,0 +1,40 @@
# Copyright 2021 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mapper module."""
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper


class SqueezeMapper(ONNXToMindSporeMapper):
"""Squeeze mapper."""

@staticmethod
def _operation_name_in_ms(*args, **kwargs):
return "P.Squeeze"

@staticmethod
def _convert_params(**kwargs):
params = kwargs["params"]
if len(params['axes']) == 1:
return {
"axis": params['axes'][0]
}

return {
"axis": tuple(params['axes'])
}

@staticmethod
def _convert_trained_weights(**kwargs):
return dict()
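The axes-to-axis conversion is scalar-or-tuple; a one-line standalone equivalent:

def squeeze_axis(axes):
    """ONNX `axes` list -> MindSpore P.Squeeze `axis` (scalar if single)."""
    return axes[0] if len(axes) == 1 else tuple(axes)

print(squeeze_axis([1]))     # -> 1
print(squeeze_axis([0, 2]))  # -> (0, 2)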

+4 -1  mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json

@@ -30,5 +30,8 @@
"onnx::Erf": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.erf_mapper.ErfMapper",
"onnx::Pow": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.pow_mapper.PowMapper",
"onnx::Einsum": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.einsum_mapper.EinSumMapper",
"onnx::Tanh": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.tanh_mapper.TanhMapper"
"onnx::Tanh": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.tanh_mapper.TanhMapper",
"onnx::LSTM": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.lstm_mapper.LSTMMapper",
"onnx::Squeeze": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.squeeze_mapper.SqueezeMapper",
"onnx::Floor": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.floor_mapper.FloorMapper"
}
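Each value in this table is a dotted path that the converter resolves at runtime. The loading code is not part of this diff, but a minimal sketch of the standard importlib pattern such a table implies:

from importlib import import_module

TABLE = {
    "onnx::Floor": "mindinsight.mindconverter.graph_based_converter"
                   ".mapper.impl.ops.floor_mapper.FloorMapper",
}

def load_mapper(op_type):
    """Resolve a dotted 'module.Class' path from the mapping table."""
    module_path, cls_name = TABLE[op_type].rsplit(".", 1)
    return getattr(import_module(module_path), cls_name)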

+15 -5  mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py

@@ -23,7 +23,8 @@ from mindinsight.mindconverter.graph_based_converter.third_party_graph.input_nod
from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph_node import OnnxGraphNode
from mindinsight.mindconverter.graph_based_converter.third_party_graph.pytorch_graph_parser import PyTorchGraphParser
from mindinsight.mindconverter.graph_based_converter.third_party_graph.tf_graph_parser import TFGraphParser
from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import OnnxDataLoader, NodeWeight
from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import OnnxDataLoader, \
NodeWeight, NodeOutputShape
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher import generate_scope_name

NONE_SCOPE_OP = {
@@ -139,6 +140,9 @@ class OnnxGraph(Graph):

super(OnnxGraph, self).build()
self._collect_input_shape_of_each_node()
for node_name in self._shape_dict:
if len(self._shape_dict[node_name]) == 1:
self._shape_dict[node_name] = self._shape_dict[node_name][0].node_output_shape

def _collect_input_shape_of_each_node(self):
"""
@@ -156,11 +160,17 @@ class OnnxGraph(Graph):
input_nodes[ipt].set_scope_name(node.scope_name)
node.precursor_nodes.insert(0, ipt)
input_nodes[ipt].set_successor_nodes(node_name)
self._shape_dict[ipt] = input_nodes[ipt].output_shape
output_shape_single = NodeOutputShape(ipt, None, input_nodes[ipt].output_shape)
if ipt not in self._shape_dict:
self._shape_dict.setdefault(ipt, []).append(output_shape_single)
ipt_shape = []
for p_nd in node.precursor_nodes:
shp = self._shape_dict.get(p_nd)
ipt_shape.append(tuple(shp) if isinstance(shp, list) else shp)
for ipt_nd_name in node.ir_node_inputs:
for p_nd in node.precursor_nodes:
shp_list = self._shape_dict.get(p_nd)
for shp in shp_list:
if ipt_nd_name == shp.node_opt_name:
shp_single = shp.node_output_shape
ipt_shape.append(tuple(shp_single) if isinstance(shp_single, list) else shp_single)

self._input_shape[node_name] = ipt_shape[0] if len(ipt_shape) == 1 else ipt_shape
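A simplified sketch of the new lookup: each IR input edge is matched by name against the NodeOutputShape records of the node's precursors, so a multi-output precursor such as LSTM contributes the shape of the specific edge that is consumed (the class and data below are hypothetical, mirroring the diff):

class NodeOutputShape:
    def __init__(self, node_opt_name, node_name, node_output_shape):
        self.node_opt_name = node_opt_name          # name of the output edge
        self.node_name = node_name                  # producing node
        self.node_output_shape = node_output_shape

shape_dict = {
    "lstm_0": [NodeOutputShape("Y", "lstm_0", [5, 2, 3, 4]),
               NodeOutputShape("Y_h", "lstm_0", [2, 3, 4])],
}

def input_shapes(ir_node_inputs, precursor_nodes):
    ipt_shape = []
    for ipt_nd_name in ir_node_inputs:
        for p_nd in precursor_nodes:
            for shp in shape_dict[p_nd]:
                if ipt_nd_name == shp.node_opt_name:
                    s = shp.node_output_shape
                    ipt_shape.append(tuple(s) if isinstance(s, list) else s)
    return ipt_shape

print(input_shapes(["Y_h"], ["lstm_0"]))  # -> [(2, 3, 4)]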



+37 -3  mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py

@@ -438,6 +438,15 @@ class OnnxDataLoader:
for tensor in tensors:
t = OnnxTensor(tensor)
self.tensors_dict[t.name] = t

idx = 0
while idx < len(self.model.graph.output):
cur_opt = self.model.graph.output[idx]
if cur_opt.name not in self.output_nodes:
self.model.graph.output.remove(cur_opt)
continue
idx += 1

self._global_context.onnx_tensors_collection = self.tensors_dict

def _parse_node_output_shape(self):
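The pruning loop above advances idx only when nothing was removed, because remove() shifts the remaining outputs left; a standalone illustration of the pattern, with a plain list standing in for model.graph.output:

graph_outputs = ["Y", "Y_h", "Y_c", "debug_out"]
kept = {"Y", "Y_h", "Y_c"}  # names present in self.output_nodes

idx = 0
while idx < len(graph_outputs):
    if graph_outputs[idx] not in kept:
        graph_outputs.remove(graph_outputs[idx])
        continue  # the same idx now holds the next element
    idx += 1

print(graph_outputs)  # -> ['Y', 'Y_h', 'Y_c']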
@@ -465,7 +474,8 @@ class OnnxDataLoader:
node_name = self.output_name_to_node_name[node_opt_name]
if not node_name:
raise GraphInitError(msg=f"Cannot find where edge {node_opt_name} comes from.")
self.node_output_shape_dict[node_name] = shape
node_output_shape = NodeOutputShape(node_opt_name, node_name, shape)
self.node_output_shape_dict.setdefault(node_name, []).append(node_output_shape)

def get_node(self, node_name):
"""Get the OnnxNode instance by node name."""
@@ -576,12 +586,16 @@ class OnnxDataLoader:
"""Create a PASS to optimize shape and reshape operations in ONNX ir graph."""
to_be_eliminated_op = {"Cast", "Concat", "Squeeze", "Unsqueeze", "Slice",
"Gather", "Shape"}

def _is_origin_inputs(node):
for ipt in node.input_name_list:
if ipt not in self.input_nodes:
return False
return True
def _traceback_precursor_nodes_until_shape_op(node_ref):
nonlocal self
e_nodes = []
node = self._nodes_dict[self.output_name_to_node_name[node_ref]]
if node.op_type not in to_be_eliminated_op:
if node.op_type not in to_be_eliminated_op or _is_origin_inputs(node):
return e_nodes
e_nodes.append(node.name)
for ipt in node.input_name_list:
@@ -646,3 +660,23 @@ class NodeWeight:
@property
def location(self):
return self._weight_location


class NodeOutputShape:
"""Node output shape and its name."""
def __init__(self, node_opt_name, node_name, node_output_shape):
self._node_opt_name = node_opt_name
self._node_name = node_name
self._node_output_shape = node_output_shape

@property
def node_opt_name(self):
return self._node_opt_name

@property
def node_name(self):
return self._node_name

@property
def node_output_shape(self):
return self._node_output_shape
