Browse Source

!965 Fix ONNX ir parse error

From: @liuchongming74
Reviewed-by: @ouwenchang,@lilongfei15
Signed-off-by: @lilongfei15
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
34c70d3733
6 changed files with 67 additions and 84 deletions
  1. +4
    -0
      mindinsight/mindconverter/graph_based_converter/constant.py
  2. +1
    -1
      mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py
  3. +1
    -1
      mindinsight/mindconverter/graph_based_converter/report_generator.py
  4. +1
    -0
      mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py
  5. +7
    -8
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py
  6. +53
    -74
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py

+ 4
- 0
mindinsight/mindconverter/graph_based_converter/constant.py View File

@@ -33,6 +33,10 @@ ONNX_TYPE_FLOAT = 1
ONNX_TYPE_FLOATS = 6
ONNX_TYPE_STRING = 3

DYNAMIC_SHAPE = -1
SCALAR_WITHOUT_SHAPE = 0
UNKNOWN_DIM_VAL = "unk__001"

BINARY_HEADER_PYTORCH_FILE = \
b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03.\x80\x02}q\x00(X\x10\x00\x00\x00'



+ 1
- 1
mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py View File

@@ -70,7 +70,7 @@ class HierarchicalTreeFactory:
node_inst = graph.get_node(node_name)
node_input = graph.get_input_shape(node_name)
node_output = graph.get_output_shape(node_name)
if not node_input:
if node_input != 0 and not node_input:
err_msg = f"This model is not supported now. " \
f"Cannot find {node_name}'s input shape."
error = NodeInputMissing(err_msg)


+ 1
- 1
mindinsight/mindconverter/graph_based_converter/report_generator.py View File

@@ -134,7 +134,7 @@ class ReportGenerator(metaclass=abc.ABCMeta):
if 'onnx.' in code_line:
num_unconverted_operator += 1
unconverted_operator = SEPARATOR_IN_ONNX_OP.join(
('onnx', re.findall(r".*onnx.(.*)[(]", code_line)[0]))
('onnx', re.findall(r".*onnx.([a-zA-Z]+).*", code_line)[0]))
info_unconverted_line = self._gen_unconverted_operator_content(
[f"{num_line + 1}", f"{code_line.index('onnx.') + 1}"],
unconverted_operator


+ 1
- 0
mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py View File

@@ -44,6 +44,7 @@ class InputNode(GraphNode):
def op_name(self):
return self._op_name

@property
def hash_key(self):
pass



+ 7
- 8
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py View File

@@ -60,7 +60,8 @@ class OnnxGraph(Graph):

self.build(sample_shape)

def _extract_shape(self, shape):
@staticmethod
def _extract_shape(shape):
"""
Extract shape from string-type shape.

@@ -121,8 +122,7 @@ class OnnxGraph(Graph):
from ..sub_graph_searcher import generate_scope_name
scope_name_list = generate_scope_name(model_data)

self._shape_dict = model_data.normalize_dict_key(
model_data.node_output_shape_dict)
self._shape_dict = model_data.node_output_shape_dict
for ind, (node_name, node) in enumerate(model_data.nodes_dict.items()):
node_weight = {}
node.scope_name = scope_name_list[ind]
@@ -138,12 +138,11 @@ class OnnxGraph(Graph):
node, node_weight)
self._nodes_record[node_name] = node_name

for node_input in node.input_name_list:
self._build_connection(node_input, node_name)
for nd_ipt_name in node.precursor_onnx_node_dict:
self._build_connection(nd_ipt_name, node_name)

super(OnnxGraph, self).build(input_shape=input_shape)
self._collect_input_shape_of_each_node(
input_shape) # diff than pyTorch
self._collect_input_shape_of_each_node(input_shape) # diff than pyTorch

def _collect_input_shape_of_each_node(self, input_shape):
"""
@@ -165,7 +164,7 @@ class OnnxGraph(Graph):
ipt_shape = []
for p_nd in node.precursor_nodes:
shp = self._shape_dict.get(p_nd)
ipt_shape.append(tuple(shp))
ipt_shape.append(tuple(shp) if isinstance(shp, list) else shp)

self._input_shape[node_name] = ipt_shape[0] if len(
ipt_shape) == 1 else ipt_shape


+ 53
- 74
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py View File

@@ -22,7 +22,8 @@ from typing import Union
from mindinsight.mindconverter.common.log import logger as log

from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \
ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT
ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL
from ...common.exceptions import GraphInitFail, ModelNotSupport


def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=None):
@@ -271,6 +272,9 @@ class OnnxDataLoader:
self.node_name_set = set() # [str]
self.node_output_shape_dict = OrderedDict() # {node_name: [int]}

# Key is edge of ONNX ir graph, value is the corresponding precursor node.
self.output_name_to_node_name = dict()

self.initialize()

def _check_initialization(self):
@@ -304,8 +308,24 @@ class OnnxDataLoader:
onnx = import_module("onnx")
self.inferred_model = onnx.shape_inference.infer_shapes(self.model)

def _parse_value_info(self): # no input node & output node
"""Parse onnx defined value_info class attribtues"""
@staticmethod
def _parse_value_info_manually(value_info):
"""Parse value info from onnx ir edge manually."""
tensor_proto = getattr(import_module("onnx"), "TensorProto")
node_name = value_info.name
node_dim = []
node_type = tensor_proto.DataType.Name(value_info.type.tensor_type.elem_type)
if not value_info.type.tensor_type.shape.dim:
return node_name, node_type, "".join(node_dim)

for dim in value_info.type.tensor_type.shape.dim:
v = dim.dim_value if dim.dim_value != 0 else UNKNOWN_DIM_VAL
node_dim.append(f"{v}")

return node_name, node_type, "x".join(node_dim)

def _parse_value_info(self):
"""Parse onnx defined value_info class attributes."""
onnx = import_module("onnx")

def _parse_value_info_re(i):
@@ -332,9 +352,12 @@ class OnnxDataLoader:
value_info = self.inferred_model.graph.value_info

for v in value_info:
readable_info = onnx.helper.printable_value_info(v)
(node_name, node_type, node_dim) = _parse_value_info_re(
readable_info)
try:
readable_info = onnx.helper.printable_value_info(v)
(node_name, node_type, node_dim) = _parse_value_info_re(readable_info)
except (AssertionError, ValueError, AttributeError) as _:
node_name, node_type, node_dim = self._parse_value_info_manually(v)
# `node_dim` could be "" or "scalar".
self.value_info_dict[node_name] = (node_type, node_dim)

def _parse_nodes(self):
@@ -343,6 +366,9 @@ class OnnxDataLoader:
n = OnnxNode(node)
self.nodes_dict[n.name] = n
self.node_name_set.add(n.name)
if len(node.output) > 1:
raise ModelNotSupport(msg=f"{node.name} has multi-outputs which is not supported now.")
self.output_name_to_node_name[node.output[0]] = node.name

def _parse_tensors(self):
"""Parse each onnx tensors in the model."""
@@ -359,22 +385,25 @@ class OnnxDataLoader:
Note:
This function has a prerequisite of the shape inference.
"""
for (node_name, (_, shape_str)) in self.value_info_dict.items():
lst = []
for (node_opt_name, (_, shape_str)) in self.value_info_dict.items():
# split shape by 'x'
shape_list = shape_str.split('x')
shape = shape_str.split('x')
# replace unknown shape by '-1'
for s in shape_list:
for i, s in enumerate(shape):
if 'unk' in s:
if self.graph_input_shape is not None:
s = self.graph_input_shape[0]
else:
s = '1'

# convert str to int
s = int(s)
lst.append(s)
self.node_output_shape_dict[node_name] = lst
shape[i] = int(self.graph_input_shape[0]) if self.graph_input_shape is not None else 1
continue
if s == "scalar":
shape = SCALAR_WITHOUT_SHAPE
continue
if s == "":
shape = DYNAMIC_SHAPE
continue
shape[i] = int(shape[i])
node_name = self.output_name_to_node_name[node_opt_name]
if not node_name:
raise GraphInitFail(user_msg=f"Cannot find where edge {node_opt_name} comes from.")
self.node_output_shape_dict[node_name] = shape

def get_node(self, node_name):
"""Get the OnnxNode instance by node name."""
@@ -405,7 +434,7 @@ class OnnxDataLoader:
for node_name, node in self.nodes_dict.items():
# for each input of a node
for input_name in node.input_name_list:
# remove :0 in the name to ensure consistency in hierarical tree.
# remove :0 in the name to ensure consistency in hierarchical tree.
input_name = input_name.split(':')[0]
if input_name in self.node_name_set:
# input is a node
@@ -413,10 +442,9 @@ class OnnxDataLoader:
node.precursor_onnx_node_dict[input_name] = self.get_node(
input_name)

# backtracing successor nodes
# Back tracing successor nodes
back_tracked_node = self.get_node(input_name)
back_tracked_node.successor_onnx_node_dict[node_name] = self.get_node(
node_name)
back_tracked_node.successor_onnx_node_dict[node_name] = self.get_node(node_name)
continue

# check if nodes connected by a tensor
@@ -433,7 +461,7 @@ class OnnxDataLoader:
node.precursor_onnx_node_dict[n_name] = self.get_node(
n_name)

# backtracing successor nodes
# Back tracing successor nodes
back_tracked_node = self.get_node(n_name)
back_tracked_node.successor_onnx_node_dict[n_name] = self.get_node(
n_name)
@@ -446,52 +474,9 @@ class OnnxDataLoader:
if out_name == input_name:
node.precursor_onnx_node_dict[nm] = n

# backtracing
# Back tracing
n.successor_onnx_node_dict[node_name] = node

@staticmethod
def normalize_dict_key(d):
"""
Normalize dictionary key.

Note:
The normalization is removing :0 in each node or output name.

Args:
d (dict): Dictionary where keys are node/output names.

Returns:
dict, normalized dictionary.
"""
if not isinstance(d, (dict, OrderedDict)):
error_msg = "Error occurs in normalizing dictionary key.\
Object passed in is not a dictionary."
error = TypeError(error_msg)
log.error(error_msg)
log.exception(error)
raise error

new_d = None
if isinstance(d, dict):
new_d = {}
for key_old in d.keys():
key_new = key_old.split(':')[0]
new_d[key_new] = d.get(key_old)

if isinstance(d, OrderedDict):
new_d = OrderedDict()
for key_old in d.keys():
key_new = key_old.split(':')[0]
new_d[key_new] = d.get(key_old)

if not new_d:
error_msg = "Error occurs in normalizing dictionary key."
error = ValueError(error_msg)
log.error(error_msg)
log.exception(error)
raise error
return new_d

def initialize(self):
"""Initialize the OnnxDataLoader."""

@@ -516,12 +501,6 @@ class OnnxDataLoader:
if self.inferred_model:
try:
self._parse_value_info()
except Exception as e:
log.error(str(e))
log.exception(e)
raise e

try:
self._parse_node_output_shape()
except Exception as e:
log.error(str(e))


Loading…
Cancel
Save