|
|
@@ -22,7 +22,8 @@ from typing import Union |
|
|
from mindinsight.mindconverter.common.log import logger as log |
|
|
from mindinsight.mindconverter.common.log import logger as log |
|
|
|
|
|
|
|
|
from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \ |
|
|
from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \ |
|
|
ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT |
|
|
|
|
|
|
|
|
ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL |
|
|
|
|
|
from ...common.exceptions import GraphInitFail, ModelNotSupport |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=None): |
|
|
def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=None): |
|
|
@@ -271,6 +272,9 @@ class OnnxDataLoader: |
|
|
self.node_name_set = set() # [str] |
|
|
self.node_name_set = set() # [str] |
|
|
self.node_output_shape_dict = OrderedDict() # {node_name: [int]} |
|
|
self.node_output_shape_dict = OrderedDict() # {node_name: [int]} |
|
|
|
|
|
|
|
|
|
|
|
# Key is edge of ONNX ir graph, value is the corresponding precursor node. |
|
|
|
|
|
self.output_name_to_node_name = dict() |
|
|
|
|
|
|
|
|
self.initialize() |
|
|
self.initialize() |
|
|
|
|
|
|
|
|
def _check_initialization(self): |
|
|
def _check_initialization(self): |
|
|
@@ -304,8 +308,24 @@ class OnnxDataLoader: |
|
|
onnx = import_module("onnx") |
|
|
onnx = import_module("onnx") |
|
|
self.inferred_model = onnx.shape_inference.infer_shapes(self.model) |
|
|
self.inferred_model = onnx.shape_inference.infer_shapes(self.model) |
|
|
|
|
|
|
|
|
def _parse_value_info(self): # no input node & output node |
|
|
|
|
|
"""Parse onnx defined value_info class attribtues""" |
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
|
def _parse_value_info_manually(value_info): |
|
|
|
|
|
"""Parse value info from onnx ir edge manually.""" |
|
|
|
|
|
tensor_proto = getattr(import_module("onnx"), "TensorProto") |
|
|
|
|
|
node_name = value_info.name |
|
|
|
|
|
node_dim = [] |
|
|
|
|
|
node_type = tensor_proto.DataType.Name(value_info.type.tensor_type.elem_type) |
|
|
|
|
|
if not value_info.type.tensor_type.shape.dim: |
|
|
|
|
|
return node_name, node_type, "".join(node_dim) |
|
|
|
|
|
|
|
|
|
|
|
for dim in value_info.type.tensor_type.shape.dim: |
|
|
|
|
|
v = dim.dim_value if dim.dim_value != 0 else UNKNOWN_DIM_VAL |
|
|
|
|
|
node_dim.append(f"{v}") |
|
|
|
|
|
|
|
|
|
|
|
return node_name, node_type, "x".join(node_dim) |
|
|
|
|
|
|
|
|
|
|
|
def _parse_value_info(self): |
|
|
|
|
|
"""Parse onnx defined value_info class attributes.""" |
|
|
onnx = import_module("onnx") |
|
|
onnx = import_module("onnx") |
|
|
|
|
|
|
|
|
def _parse_value_info_re(i): |
|
|
def _parse_value_info_re(i): |
|
|
@@ -332,9 +352,12 @@ class OnnxDataLoader: |
|
|
value_info = self.inferred_model.graph.value_info |
|
|
value_info = self.inferred_model.graph.value_info |
|
|
|
|
|
|
|
|
for v in value_info: |
|
|
for v in value_info: |
|
|
readable_info = onnx.helper.printable_value_info(v) |
|
|
|
|
|
(node_name, node_type, node_dim) = _parse_value_info_re( |
|
|
|
|
|
readable_info) |
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
readable_info = onnx.helper.printable_value_info(v) |
|
|
|
|
|
(node_name, node_type, node_dim) = _parse_value_info_re(readable_info) |
|
|
|
|
|
except (AssertionError, ValueError, AttributeError) as _: |
|
|
|
|
|
node_name, node_type, node_dim = self._parse_value_info_manually(v) |
|
|
|
|
|
# `node_dim` could be "" or "scalar". |
|
|
self.value_info_dict[node_name] = (node_type, node_dim) |
|
|
self.value_info_dict[node_name] = (node_type, node_dim) |
|
|
|
|
|
|
|
|
def _parse_nodes(self): |
|
|
def _parse_nodes(self): |
|
|
@@ -343,6 +366,9 @@ class OnnxDataLoader: |
|
|
n = OnnxNode(node) |
|
|
n = OnnxNode(node) |
|
|
self.nodes_dict[n.name] = n |
|
|
self.nodes_dict[n.name] = n |
|
|
self.node_name_set.add(n.name) |
|
|
self.node_name_set.add(n.name) |
|
|
|
|
|
if len(node.output) > 1: |
|
|
|
|
|
raise ModelNotSupport(msg=f"{node.name} has multi-outputs which is not supported now.") |
|
|
|
|
|
self.output_name_to_node_name[node.output[0]] = node.name |
|
|
|
|
|
|
|
|
def _parse_tensors(self): |
|
|
def _parse_tensors(self): |
|
|
"""Parse each onnx tensors in the model.""" |
|
|
"""Parse each onnx tensors in the model.""" |
|
|
@@ -359,22 +385,25 @@ class OnnxDataLoader: |
|
|
Note: |
|
|
Note: |
|
|
This function has a prerequisite of the shape inference. |
|
|
This function has a prerequisite of the shape inference. |
|
|
""" |
|
|
""" |
|
|
for (node_name, (_, shape_str)) in self.value_info_dict.items(): |
|
|
|
|
|
lst = [] |
|
|
|
|
|
|
|
|
for (node_opt_name, (_, shape_str)) in self.value_info_dict.items(): |
|
|
# split shape by 'x' |
|
|
# split shape by 'x' |
|
|
shape_list = shape_str.split('x') |
|
|
|
|
|
|
|
|
shape = shape_str.split('x') |
|
|
# replace unknown shape by '-1' |
|
|
# replace unknown shape by '-1' |
|
|
for s in shape_list: |
|
|
|
|
|
|
|
|
for i, s in enumerate(shape): |
|
|
if 'unk' in s: |
|
|
if 'unk' in s: |
|
|
if self.graph_input_shape is not None: |
|
|
|
|
|
s = self.graph_input_shape[0] |
|
|
|
|
|
else: |
|
|
|
|
|
s = '1' |
|
|
|
|
|
|
|
|
|
|
|
# convert str to int |
|
|
|
|
|
s = int(s) |
|
|
|
|
|
lst.append(s) |
|
|
|
|
|
self.node_output_shape_dict[node_name] = lst |
|
|
|
|
|
|
|
|
shape[i] = int(self.graph_input_shape[0]) if self.graph_input_shape is not None else 1 |
|
|
|
|
|
continue |
|
|
|
|
|
if s == "scalar": |
|
|
|
|
|
shape = SCALAR_WITHOUT_SHAPE |
|
|
|
|
|
continue |
|
|
|
|
|
if s == "": |
|
|
|
|
|
shape = DYNAMIC_SHAPE |
|
|
|
|
|
continue |
|
|
|
|
|
shape[i] = int(shape[i]) |
|
|
|
|
|
node_name = self.output_name_to_node_name[node_opt_name] |
|
|
|
|
|
if not node_name: |
|
|
|
|
|
raise GraphInitFail(user_msg=f"Cannot find where edge {node_opt_name} comes from.") |
|
|
|
|
|
self.node_output_shape_dict[node_name] = shape |
|
|
|
|
|
|
|
|
def get_node(self, node_name): |
|
|
def get_node(self, node_name): |
|
|
"""Get the OnnxNode instance by node name.""" |
|
|
"""Get the OnnxNode instance by node name.""" |
|
|
@@ -405,7 +434,7 @@ class OnnxDataLoader: |
|
|
for node_name, node in self.nodes_dict.items(): |
|
|
for node_name, node in self.nodes_dict.items(): |
|
|
# for each input of a node |
|
|
# for each input of a node |
|
|
for input_name in node.input_name_list: |
|
|
for input_name in node.input_name_list: |
|
|
# remove :0 in the name to ensure consistency in hierarical tree. |
|
|
|
|
|
|
|
|
# remove :0 in the name to ensure consistency in hierarchical tree. |
|
|
input_name = input_name.split(':')[0] |
|
|
input_name = input_name.split(':')[0] |
|
|
if input_name in self.node_name_set: |
|
|
if input_name in self.node_name_set: |
|
|
# input is a node |
|
|
# input is a node |
|
|
@@ -413,10 +442,9 @@ class OnnxDataLoader: |
|
|
node.precursor_onnx_node_dict[input_name] = self.get_node( |
|
|
node.precursor_onnx_node_dict[input_name] = self.get_node( |
|
|
input_name) |
|
|
input_name) |
|
|
|
|
|
|
|
|
# backtracing successor nodes |
|
|
|
|
|
|
|
|
# Back tracing successor nodes |
|
|
back_tracked_node = self.get_node(input_name) |
|
|
back_tracked_node = self.get_node(input_name) |
|
|
back_tracked_node.successor_onnx_node_dict[node_name] = self.get_node( |
|
|
|
|
|
node_name) |
|
|
|
|
|
|
|
|
back_tracked_node.successor_onnx_node_dict[node_name] = self.get_node(node_name) |
|
|
continue |
|
|
continue |
|
|
|
|
|
|
|
|
# check if nodes connected by a tensor |
|
|
# check if nodes connected by a tensor |
|
|
@@ -433,7 +461,7 @@ class OnnxDataLoader: |
|
|
node.precursor_onnx_node_dict[n_name] = self.get_node( |
|
|
node.precursor_onnx_node_dict[n_name] = self.get_node( |
|
|
n_name) |
|
|
n_name) |
|
|
|
|
|
|
|
|
# backtracing successor nodes |
|
|
|
|
|
|
|
|
# Back tracing successor nodes |
|
|
back_tracked_node = self.get_node(n_name) |
|
|
back_tracked_node = self.get_node(n_name) |
|
|
back_tracked_node.successor_onnx_node_dict[n_name] = self.get_node( |
|
|
back_tracked_node.successor_onnx_node_dict[n_name] = self.get_node( |
|
|
n_name) |
|
|
n_name) |
|
|
@@ -446,52 +474,9 @@ class OnnxDataLoader: |
|
|
if out_name == input_name: |
|
|
if out_name == input_name: |
|
|
node.precursor_onnx_node_dict[nm] = n |
|
|
node.precursor_onnx_node_dict[nm] = n |
|
|
|
|
|
|
|
|
# backtracing |
|
|
|
|
|
|
|
|
# Back tracing |
|
|
n.successor_onnx_node_dict[node_name] = node |
|
|
n.successor_onnx_node_dict[node_name] = node |
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
|
def normalize_dict_key(d): |
|
|
|
|
|
""" |
|
|
|
|
|
Normalize dictionary key. |
|
|
|
|
|
|
|
|
|
|
|
Note: |
|
|
|
|
|
The normalization is removing :0 in each node or output name. |
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
|
d (dict): Dictionary where keys are node/output names. |
|
|
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
|
dict, normalized dictionary. |
|
|
|
|
|
""" |
|
|
|
|
|
if not isinstance(d, (dict, OrderedDict)): |
|
|
|
|
|
error_msg = "Error occurs in normalizing dictionary key.\ |
|
|
|
|
|
Object passed in is not a dictionary." |
|
|
|
|
|
error = TypeError(error_msg) |
|
|
|
|
|
log.error(error_msg) |
|
|
|
|
|
log.exception(error) |
|
|
|
|
|
raise error |
|
|
|
|
|
|
|
|
|
|
|
new_d = None |
|
|
|
|
|
if isinstance(d, dict): |
|
|
|
|
|
new_d = {} |
|
|
|
|
|
for key_old in d.keys(): |
|
|
|
|
|
key_new = key_old.split(':')[0] |
|
|
|
|
|
new_d[key_new] = d.get(key_old) |
|
|
|
|
|
|
|
|
|
|
|
if isinstance(d, OrderedDict): |
|
|
|
|
|
new_d = OrderedDict() |
|
|
|
|
|
for key_old in d.keys(): |
|
|
|
|
|
key_new = key_old.split(':')[0] |
|
|
|
|
|
new_d[key_new] = d.get(key_old) |
|
|
|
|
|
|
|
|
|
|
|
if not new_d: |
|
|
|
|
|
error_msg = "Error occurs in normalizing dictionary key." |
|
|
|
|
|
error = ValueError(error_msg) |
|
|
|
|
|
log.error(error_msg) |
|
|
|
|
|
log.exception(error) |
|
|
|
|
|
raise error |
|
|
|
|
|
return new_d |
|
|
|
|
|
|
|
|
|
|
|
def initialize(self): |
|
|
def initialize(self): |
|
|
"""Initialize the OnnxDataLoader.""" |
|
|
"""Initialize the OnnxDataLoader.""" |
|
|
|
|
|
|
|
|
@@ -516,12 +501,6 @@ class OnnxDataLoader: |
|
|
if self.inferred_model: |
|
|
if self.inferred_model: |
|
|
try: |
|
|
try: |
|
|
self._parse_value_info() |
|
|
self._parse_value_info() |
|
|
except Exception as e: |
|
|
|
|
|
log.error(str(e)) |
|
|
|
|
|
log.exception(e) |
|
|
|
|
|
raise e |
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
self._parse_node_output_shape() |
|
|
self._parse_node_output_shape() |
|
|
except Exception as e: |
|
|
except Exception as e: |
|
|
log.error(str(e)) |
|
|
log.error(str(e)) |
|
|
|