diff --git a/mindinsight/mindconverter/graph_based_converter/constant.py b/mindinsight/mindconverter/graph_based_converter/constant.py
index 407b4e7c..8a48be18 100644
--- a/mindinsight/mindconverter/graph_based_converter/constant.py
+++ b/mindinsight/mindconverter/graph_based_converter/constant.py
@@ -33,6 +33,10 @@ ONNX_TYPE_FLOAT = 1
 ONNX_TYPE_FLOATS = 6
 ONNX_TYPE_STRING = 3
 
+DYNAMIC_SHAPE = -1
+SCALAR_WITHOUT_SHAPE = 0
+UNKNOWN_DIM_VAL = "unk__001"
+
 BINARY_HEADER_PYTORCH_FILE = \
     b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03.\x80\x02}q\x00(X\x10\x00\x00\x00'
 
diff --git a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py
index e5b8e8e5..b3ec981e 100644
--- a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py
+++ b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py
@@ -70,7 +70,7 @@ class HierarchicalTreeFactory:
             node_inst = graph.get_node(node_name)
             node_input = graph.get_input_shape(node_name)
             node_output = graph.get_output_shape(node_name)
-            if not node_input:
+            if node_input != 0 and not node_input:
                 err_msg = f"This model is not supported now. " \
                           f"Cannot find {node_name}'s input shape."
                 error = NodeInputMissing(err_msg)
diff --git a/mindinsight/mindconverter/graph_based_converter/report_generator.py b/mindinsight/mindconverter/graph_based_converter/report_generator.py
index 705fbf05..1e33c458 100644
--- a/mindinsight/mindconverter/graph_based_converter/report_generator.py
+++ b/mindinsight/mindconverter/graph_based_converter/report_generator.py
@@ -134,7 +134,7 @@ class ReportGenerator(metaclass=abc.ABCMeta):
             if 'onnx.' in code_line:
                 num_unconverted_operator += 1
                 unconverted_operator = SEPARATOR_IN_ONNX_OP.join(
-                    ('onnx', re.findall(r".*onnx.(.*)[(]", code_line)[0]))
+                    ('onnx', re.findall(r".*onnx.([a-zA-Z]+).*", code_line)[0]))
                 info_unconverted_line = self._gen_unconverted_operator_content(
                     [f"{num_line + 1}", f"{code_line.index('onnx.') + 1}"],
                     unconverted_operator
diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py
index e92d7035..8f75451e 100644
--- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py
+++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py
@@ -44,6 +44,7 @@ class InputNode(GraphNode):
     def op_name(self):
         return self._op_name
 
+    @property
     def hash_key(self):
         pass
 
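Illustrative aside (not part of the patch): the report_generator.py hunk above swaps the operator-extraction pattern. The sample code_line values below are invented, but they show the practical difference: the old pattern needs an opening parenthesis on the same line, while the new one only needs the operator name.

import re

OLD_PATTERN = r".*onnx.(.*)[(]"          # requires a '(' on the same line
NEW_PATTERN = r".*onnx.([a-zA-Z]+).*"    # captures the operator name alone

line_with_call = "        self.op = onnx.Gather(data, indices)"
line_without_call = "        self.op = onnx.ReduceMean"  # call continued on another line

print(re.findall(OLD_PATTERN, line_with_call))     # ['Gather']
print(re.findall(NEW_PATTERN, line_with_call))     # ['Gather']
print(re.findall(OLD_PATTERN, line_without_call))  # [] -> indexing [0] would raise IndexError
print(re.findall(NEW_PATTERN, line_without_call))  # ['ReduceMean']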
diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py
index 1114b95c..4848b75d 100644
--- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py
+++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py
@@ -60,7 +60,8 @@ class OnnxGraph(Graph):
 
         self.build(sample_shape)
 
-    def _extract_shape(self, shape):
+    @staticmethod
+    def _extract_shape(shape):
         """
         Extract shape from string-type shape.
 
@@ -121,8 +122,7 @@ class OnnxGraph(Graph):
         from ..sub_graph_searcher import generate_scope_name
 
         scope_name_list = generate_scope_name(model_data)
-        self._shape_dict = model_data.normalize_dict_key(
-            model_data.node_output_shape_dict)
+        self._shape_dict = model_data.node_output_shape_dict
         for ind, (node_name, node) in enumerate(model_data.nodes_dict.items()):
             node_weight = {}
             node.scope_name = scope_name_list[ind]
@@ -138,12 +138,11 @@ class OnnxGraph(Graph):
                                 node, node_weight)
             self._nodes_record[node_name] = node_name
 
-            for node_input in node.input_name_list:
-                self._build_connection(node_input, node_name)
+            for nd_ipt_name in node.precursor_onnx_node_dict:
+                self._build_connection(nd_ipt_name, node_name)
 
         super(OnnxGraph, self).build(input_shape=input_shape)
-        self._collect_input_shape_of_each_node(
-            input_shape)  # diff than pyTorch
+        self._collect_input_shape_of_each_node(input_shape)  # diff than pyTorch
 
     def _collect_input_shape_of_each_node(self, input_shape):
         """
@@ -165,7 +164,7 @@ class OnnxGraph(Graph):
             ipt_shape = []
             for p_nd in node.precursor_nodes:
                 shp = self._shape_dict.get(p_nd)
-                ipt_shape.append(tuple(shp))
+                ipt_shape.append(tuple(shp) if isinstance(shp, list) else shp)
 
             self._input_shape[node_name] = ipt_shape[0] if len(
                 ipt_shape) == 1 else ipt_shape
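Illustrative aside (not part of the patch): with the new sentinels from constant.py, a recorded output shape can be either a list of ints or a bare int (SCALAR_WITHOUT_SHAPE = 0, DYNAMIC_SHAPE = -1), which is why the _collect_input_shape_of_each_node hunk above only converts the list case to a tuple. A minimal sketch of that guard:

def to_input_shape(shp):
    # Mirrors the guarded conversion in _collect_input_shape_of_each_node:
    # only list-typed shapes become tuples; int sentinels pass through.
    return tuple(shp) if isinstance(shp, list) else shp

assert to_input_shape([1, 3, 224, 224]) == (1, 3, 224, 224)
assert to_input_shape(0) == 0    # SCALAR_WITHOUT_SHAPE: scalar output, no shape
assert to_input_shape(-1) == -1  # DYNAMIC_SHAPE: dimensions unknown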
diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py
index 348bc658..57408516 100644
--- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py
+++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py
@@ -22,7 +22,8 @@ from typing import Union
 
 from mindinsight.mindconverter.common.log import logger as log
 from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \
-    ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT
+    ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL
+from ...common.exceptions import GraphInitFail, ModelNotSupport
 
 
 def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=None):
@@ -271,6 +272,9 @@ class OnnxDataLoader:
         self.node_name_set = set()  # [str]
         self.node_output_shape_dict = OrderedDict()  # {node_name: [int]}
 
+        # Key is edge of ONNX ir graph, value is the corresponding precursor node.
+        self.output_name_to_node_name = dict()
+
         self.initialize()
 
     def _check_initialization(self):
@@ -304,8 +308,24 @@ class OnnxDataLoader:
         onnx = import_module("onnx")
         self.inferred_model = onnx.shape_inference.infer_shapes(self.model)
 
-    def _parse_value_info(self):  # no input node & output node
-        """Parse onnx defined value_info class attribtues"""
+    @staticmethod
+    def _parse_value_info_manually(value_info):
+        """Parse value info from onnx ir edge manually."""
+        tensor_proto = getattr(import_module("onnx"), "TensorProto")
+        node_name = value_info.name
+        node_dim = []
+        node_type = tensor_proto.DataType.Name(value_info.type.tensor_type.elem_type)
+        if not value_info.type.tensor_type.shape.dim:
+            return node_name, node_type, "".join(node_dim)
+
+        for dim in value_info.type.tensor_type.shape.dim:
+            v = dim.dim_value if dim.dim_value != 0 else UNKNOWN_DIM_VAL
+            node_dim.append(f"{v}")
+
+        return node_name, node_type, "x".join(node_dim)
+
+    def _parse_value_info(self):
+        """Parse onnx defined value_info class attributes."""
         onnx = import_module("onnx")
 
         def _parse_value_info_re(i):
@@ -332,9 +352,12 @@ class OnnxDataLoader:
 
         value_info = self.inferred_model.graph.value_info
         for v in value_info:
-            readable_info = onnx.helper.printable_value_info(v)
-            (node_name, node_type, node_dim) = _parse_value_info_re(
-                readable_info)
+            try:
+                readable_info = onnx.helper.printable_value_info(v)
+                (node_name, node_type, node_dim) = _parse_value_info_re(readable_info)
+            except (AssertionError, ValueError, AttributeError) as _:
+                node_name, node_type, node_dim = self._parse_value_info_manually(v)
+            # `node_dim` could be "" or "scalar".
             self.value_info_dict[node_name] = (node_type, node_dim)
 
     def _parse_nodes(self):
@@ -343,6 +366,9 @@ class OnnxDataLoader:
             n = OnnxNode(node)
             self.nodes_dict[n.name] = n
             self.node_name_set.add(n.name)
+            if len(node.output) > 1:
+                raise ModelNotSupport(msg=f"{node.name} has multi-outputs which is not supported now.")
+            self.output_name_to_node_name[node.output[0]] = node.name
 
     def _parse_tensors(self):
         """Parse each onnx tensors in the model."""
@@ -359,22 +385,25 @@ class OnnxDataLoader:
         Note:
             This function has a prerequisite of the shape inference.
         """
-        for (node_name, (_, shape_str)) in self.value_info_dict.items():
-            lst = []
+        for (node_opt_name, (_, shape_str)) in self.value_info_dict.items():
             # split shape by 'x'
-            shape_list = shape_str.split('x')
+            shape = shape_str.split('x')
             # replace unknown shape by '-1'
-            for s in shape_list:
+            for i, s in enumerate(shape):
                 if 'unk' in s:
-                    if self.graph_input_shape is not None:
-                        s = self.graph_input_shape[0]
-                    else:
-                        s = '1'
-
-                # convert str to int
-                s = int(s)
-                lst.append(s)
-            self.node_output_shape_dict[node_name] = lst
+                    shape[i] = int(self.graph_input_shape[0]) if self.graph_input_shape is not None else 1
+                    continue
+                if s == "scalar":
+                    shape = SCALAR_WITHOUT_SHAPE
+                    continue
+                if s == "":
+                    shape = DYNAMIC_SHAPE
+                    continue
+                shape[i] = int(shape[i])
+            node_name = self.output_name_to_node_name[node_opt_name]
+            if not node_name:
+                raise GraphInitFail(user_msg=f"Cannot find where edge {node_opt_name} comes from.")
+            self.node_output_shape_dict[node_name] = shape
 
     def get_node(self, node_name):
         """Get the OnnxNode instance by node name."""
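Illustrative aside (not part of the patch): the shape strings consumed by _parse_node_output_shape come from onnx.helper.printable_value_info or the new _parse_value_info_manually, e.g. "1x3x224x224", "unk__001x256", "scalar", or "". The standalone sketch below mirrors the normalization above (it returns the sentinel value directly instead of rebinding shape and continuing, but the mapping is the same):

def normalize_shape_str(shape_str, graph_input_batch=None):
    # Mirrors _parse_node_output_shape: split on 'x', resolve unknown dims,
    # and map "scalar" / "" to the sentinels from constant.py.
    shape = shape_str.split('x')
    for i, s in enumerate(shape):
        if 'unk' in s:                      # unknown dim, e.g. "unk__001"
            shape[i] = int(graph_input_batch) if graph_input_batch is not None else 1
            continue
        if s == "scalar":
            return 0                        # SCALAR_WITHOUT_SHAPE
        if s == "":
            return -1                       # DYNAMIC_SHAPE
        shape[i] = int(s)
    return shape

assert normalize_shape_str("1x3x224x224") == [1, 3, 224, 224]
assert normalize_shape_str("unk__001x256", graph_input_batch=1) == [1, 256]
assert normalize_shape_str("scalar") == 0
assert normalize_shape_str("") == -1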
@@ -405,7 +434,7 @@ class OnnxDataLoader:
         for node_name, node in self.nodes_dict.items():
             # for each input of a node
             for input_name in node.input_name_list:
-                # remove :0 in the name to ensure consistency in hierarical tree.
+                # remove :0 in the name to ensure consistency in hierarchical tree.
                 input_name = input_name.split(':')[0]
                 if input_name in self.node_name_set:
                     # input is a node
@@ -413,10 +442,9 @@ class OnnxDataLoader:
                     node.precursor_onnx_node_dict[input_name] = self.get_node(
                         input_name)
 
-                    # backtracing successor nodes
+                    # Back tracing successor nodes
                     back_tracked_node = self.get_node(input_name)
-                    back_tracked_node.successor_onnx_node_dict[node_name] = self.get_node(
-                        node_name)
+                    back_tracked_node.successor_onnx_node_dict[node_name] = self.get_node(node_name)
                     continue
 
                 # check if nodes connected by a tensor
@@ -433,7 +461,7 @@ class OnnxDataLoader:
                         node.precursor_onnx_node_dict[n_name] = self.get_node(
                             n_name)
 
-                        # backtracing successor nodes
+                        # Back tracing successor nodes
                         back_tracked_node = self.get_node(n_name)
                         back_tracked_node.successor_onnx_node_dict[n_name] = self.get_node(
                             n_name)
@@ -446,52 +474,9 @@ class OnnxDataLoader:
                         if out_name == input_name:
                             node.precursor_onnx_node_dict[nm] = n
 
-                            # backtracing
+                            # Back tracing
                             n.successor_onnx_node_dict[node_name] = node
 
-    @staticmethod
-    def normalize_dict_key(d):
-        """
-        Normalize dictionary key.
-
-        Note:
-            The normalization is removing :0 in each node or output name.
-
-        Args:
-            d (dict): Dictionary where keys are node/output names.
-
-        Returns:
-            dict, normalized dictionary.
-        """
-        if not isinstance(d, (dict, OrderedDict)):
-            error_msg = "Error occurs in normalizing dictionary key.\
-                Object passed in is not a dictionary."
-            error = TypeError(error_msg)
-            log.error(error_msg)
-            log.exception(error)
-            raise error
-
-        new_d = None
-        if isinstance(d, dict):
-            new_d = {}
-            for key_old in d.keys():
-                key_new = key_old.split(':')[0]
-                new_d[key_new] = d.get(key_old)
-
-        if isinstance(d, OrderedDict):
-            new_d = OrderedDict()
-            for key_old in d.keys():
-                key_new = key_old.split(':')[0]
-                new_d[key_new] = d.get(key_old)
-
-        if not new_d:
-            error_msg = "Error occurs in normalizing dictionary key."
-            error = ValueError(error_msg)
-            log.error(error_msg)
-            log.exception(error)
-            raise error
-        return new_d
-
     def initialize(self):
         """Initialize the OnnxDataLoader."""
 
@@ -516,12 +501,6 @@ class OnnxDataLoader:
         if self.inferred_model:
             try:
                 self._parse_value_info()
-            except Exception as e:
-                log.error(str(e))
-                log.exception(e)
-                raise e
-
-            try:
                 self._parse_node_output_shape()
             except Exception as e:
                 log.error(str(e))
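Illustrative aside (not part of the patch): the removed normalize_dict_key stripped ":0" suffixes so value_info keys lined up with node names; the patch instead records, for every single-output node, which node produces each output edge, and keys node_output_shape_dict by that node name. A standalone sketch of the same mapping using the onnx package (the model path here is hypothetical):

import onnx

model = onnx.load("model.onnx")            # hypothetical model file
output_name_to_node_name = {}
for node in model.graph.node:
    if len(node.output) > 1:               # the patch rejects multi-output nodes
        raise ValueError(f"{node.name} has multiple outputs, which is not supported.")
    output_name_to_node_name[node.output[0]] = node.name

# Shapes inferred per value_info edge can now be re-keyed by the producing
# node via output_name_to_node_name, which is what makes normalize_dict_key
# unnecessary.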