diff --git a/mindinsight/mindconverter/graph_based_converter/framework.py b/mindinsight/mindconverter/graph_based_converter/framework.py
index f4a7fec2..70999bea 100644
--- a/mindinsight/mindconverter/graph_based_converter/framework.py
+++ b/mindinsight/mindconverter/graph_based_converter/framework.py
@@ -20,6 +20,7 @@ from importlib.util import find_spec
 import mindinsight
 from mindinsight.mindconverter.common.log import logger as log
 from .mapper import ONNXToMindSporeMapper
+from ..common.exceptions import NodeTypeNotSupport
 
 permissions = os.R_OK | os.W_OK | os.X_OK
 os.umask(permissions << 3 | permissions)
@@ -92,7 +93,13 @@ def graph_based_converter(graph_path: str, sample_shape: tuple,
     graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape,
                                   checkpoint=checkpoint_path)
-    hierarchical_tree = HierarchicalTreeFactory.create(graph_obj)
+    try:
+        hierarchical_tree = HierarchicalTreeFactory.create(graph_obj)
+    except Exception as e:
+        log.exception(e)
+        log.error("Error occurred when creating the hierarchical tree.")
+        raise NodeTypeNotSupport("This model is not supported now.")
+
     hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper,
                                         report_folder=report_folder)
diff --git a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py
index 69c9abb0..71db3fe6 100644
--- a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py
+++ b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/__init__.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Hierarchical tree module."""
+from mindinsight.mindconverter.common.log import logger as log
 from .hierarchical_tree import HierarchicalTree
 
 __all__ = [
@@ -35,23 +36,16 @@ class HierarchicalTreeFactory:
             HierarchicalTree, tree.
         """
         tree = HierarchicalTree()
-        node_input = None
         for _, node_name in enumerate(graph.nodes_in_topological_order):
            node_inst = graph.get_node(node_name)
+            node_input = graph.get_input_shape(node_name)
             node_output = graph.get_output_shape(node_name)
-            if node_inst.in_degree == 0:
-                # If in-degree equals to zero, then it's a input node.
-                continue
-
-            # If the node is on the top, then fetch its input
-            # from input table.
-            if not node_input:
-                node_input = graph.get_input_shape(node_name)
-                if not node_input:
-                    raise ValueError(f"This model is not supported now. "
-                                     f"Cannot find {node_name}'s input shape.")
+            if not node_input:
+                err_msg = f"This model is not supported now. " \
+                          f"Cannot find {node_name}'s input shape."
+                log.error(err_msg)
+                raise ValueError(err_msg)
             tree.insert(node_inst, node_name, node_input, node_output)
-            node_input = node_output
         return tree
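Note: the try/except added in framework.py funnels every internal failure into one converter-level exception instead of letting arbitrary errors escape to the caller. A minimal standalone sketch of the pattern (the logger and exception class here are stand-ins, not the real mindinsight ones):

```python
import logging

log = logging.getLogger("mindconverter")

class NodeTypeNotSupport(Exception):
    """Raised when a model contains a node the converter cannot handle."""

def create_tree_safely(factory, graph):
    try:
        return factory.create(graph)
    except Exception as e:
        # Keep the full traceback in the log, but surface one stable,
        # user-facing exception type to callers.
        log.exception(e)
        raise NodeTypeNotSupport("This model is not supported now.") from e
```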
diff --git a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py
index 1497ee15..12772434 100644
--- a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py
+++ b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/hierarchical_tree.py
@@ -31,6 +31,7 @@ from ..constant import SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVEL_INDENT
 from ..constant import NEW_LINE, SECOND_LEVEL_INDENT
 from ..constant import NodeType
 from ..report_generator import ReportGenerator
+from ...common.exceptions import NodeTypeNotSupport
 
 GLOBAL_VAR_NAME_MGR = GlobalVarNameMgr()
@@ -56,7 +57,7 @@ class HierarchicalTree(Tree):
         # Manage module name to used.
         self._module_mgr = ModuleNameMgr()
         # Manage variable name in a module.
-        self._args_mgr_in_module = dict()
+        self._vars_mgr_in_module = dict()
         self._module_vars = dict()
 
     @property
@@ -86,7 +87,7 @@ class HierarchicalTree(Tree):
             parent = SEPARATOR_IN_SCOPE.join(scopes[:idx])
             identifier = SEPARATOR_IN_SCOPE.join(scopes[:idx + 1])
             try_parent = f"{parent}{SEPARATOR_IN_SCOPE}{scope}" \
-                if not parent else scope
+                if parent else scope
             if self.contains(try_parent):
                 # Whether current node existed.
                 parent = try_parent
@@ -132,6 +133,8 @@ class HierarchicalTree(Tree):
         """
         Shrink sub-tree into one node.
 
+        Use the child node to replace its ancestor.
+
         Args:
             node (Node): List of nodes to be merged.
 
@@ -140,6 +143,8 @@
         parent_node = self[node.predecessor(self.tree_identifier)]
         # Keep successors of parent.
         brothers = deepcopy(parent_node.successors(self.tree_identifier))
+        # Shrink happens only when the node has exactly one child,
+        # so taking index 0 is safe.
         child = node.successors(self.tree_identifier)[0]
         self.move_node(source=child,
                        destination=node.predecessor(self.tree_identifier))
@@ -158,9 +163,13 @@
             out_folder (str): Output folder.
 
         """
-        self._adjust_structure()
-
-        code_fragments = self._generate_codes(mapper)
+        try:
+            self._adjust_structure()
+            code_fragments = self._generate_codes(mapper)
+        except Exception as e:
+            log.exception(e)
+            log.error("Error occurred when generating code from the hierarchical tree.")
+            raise NodeTypeNotSupport("This model is not supported now.")
 
         out_folder = os.path.abspath(out_folder)
         if not report_folder:
@@ -176,9 +185,8 @@
         for file_name in code_fragments:
             code, report = code_fragments[file_name]
             try:
-                with os.fdopen(
-                        os.open(os.path.join(os.path.abspath(out_folder), f"{file_name}.py"),
-                                self.flags, self.modes), 'w') as file:
+                with os.fdopen(os.open(os.path.join(os.path.abspath(out_folder), f"{file_name}.py"),
+                                       self.flags, self.modes), "w") as file:
                     file.write(code)
             except IOError as error:
                 log.error(str(error))
                 raise error
@@ -186,9 +194,8 @@
             try:
-                with os.fdopen(
-                        os.open(os.path.join(report_folder, f"report_of_{file_name}.txt"),
-                                self.flags, stat.S_IRUSR), "w") as rpt_f:
+                with os.fdopen(os.open(os.path.join(report_folder, f"report_of_{file_name}.txt"),
+                                       self.flags, stat.S_IRUSR), "w") as rpt_f:
                     rpt_f.write(report)
             except IOError as error:
                 log.error(str(error))
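Note: save_source_files writes generated files through os.fdopen(os.open(...)) rather than plain open(), so creation flags and file permissions are pinned down explicitly. A sketch of the idiom; the flag and mode values below are illustrative, the real ones come from the tree instance:

```python
import os
import stat

FLAGS = os.O_WRONLY | os.O_CREAT | os.O_TRUNC   # assumed flags
MODES = stat.S_IRUSR | stat.S_IWUSR             # rw for the owner only

def write_restricted(path, content):
    # os.open() creates the file with the requested mode atomically;
    # os.fdopen() then wraps the raw fd in a normal text file object.
    with os.fdopen(os.open(path, FLAGS, MODES), "w") as f:
        f.write(content)
```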
""" - if node.data.node_type == NodeType.MODULE.value: + if node.data.node_type in {NodeType.MODULE.value, NodeType.CLASS.value, + NodeType.FUNC.value}: # If current node is class or function, then # remove unused args in __init__. cur_module_key = node.data.hash_key or self.hash_key(node) @@ -231,13 +239,15 @@ class HierarchicalTree(Tree): node = self._clear_unused_args(node, self._merged_module_args[cur_module_key]) + # `self._merged_module_args` records formal args. + # We need to replace actual args. if precursor_module_key in self._merged_module_args: # If parent node is in `_merged_module_args`, then # replace current node args with arg name declared # in _merged_module_args. for arg in node.data.args_in_code.keys(): if arg in self._merged_module_args[precursor_module_key]: - node.data.replace_with_arg(arg) + node.data.replace_with_arg(arg, arg) return node @staticmethod @@ -254,7 +264,8 @@ class HierarchicalTree(Tree): """ args_in_code = list(node.data.args_in_code.keys()) for arg in args_in_code: - if arg not in used_args: + ori_arg = arg.replace(f"_{node.data.variable_name}", "") + if ori_arg not in used_args: node.data.args_in_code.pop(arg) return node @@ -287,9 +298,11 @@ class HierarchicalTree(Tree): if node.data.node_type == NodeType.MODULE.value: self._create_module_args_and_vars(node, mapper) - # 2. Get nodes can be merged. - self._module_merging(node_collection) + # Module merging based on all nodes. + self._module_merging() + for depth in depths: + node_collection = self._hierarchical_order[depth] snippets = set() for node_name in node_collection: nd_inst = self.get_node(node_name) @@ -297,8 +310,7 @@ class HierarchicalTree(Tree): continue # Generate hash key for node. - module_key = self.hash_key(nd_inst) - + module_key = nd_inst.data.hash_key # Get code generation func. func, node_type = self._fetch_func_and_type(nd_inst) @@ -325,9 +337,8 @@ class HierarchicalTree(Tree): # 3. Pre-process node args. nd_inst = self._preprocess_node_args(nd_inst, module_key) # 4. Post-process child node args. - for scsr_nd_name in nd_inst.successors(self.tree_identifier): - self._postprocess_node_args(self.get_node(scsr_nd_name), - module_key) + for _, scsr_nd_name in enumerate(nd_inst.successors(self.tree_identifier)): + self._postprocess_node_args(self.get_node(scsr_nd_name), module_key) # 5. Generate code. snippets.add(func(nd_inst, nd_inst.data.module_name, module_key)) @@ -335,7 +346,6 @@ class HierarchicalTree(Tree): formatted_code, _ = FormatCode("".join(code_blocks), style_config=CodeFormatConfig.PEP8.value) - report_generator = ReportGenerator() report = report_generator.gen_report(formatted_code) @@ -403,9 +413,9 @@ class HierarchicalTree(Tree): "output_shape": c_nd.data.output_shape}) # Generate code statement. 
- expr = ", ".join([f"{k}={v}" for k, v in args.items()]) + expr = ", ".join([f"{k.replace(f'_{c_nd.data.variable_name}', '')}={v}" + for k, v in args.items()]) code_line = f"{operator}({expr})" - module_list.append(code_line) body = f",{NEW_LINE}{SECOND_LEVEL_INDENT}".join(module_list) @@ -435,6 +445,7 @@ class HierarchicalTree(Tree): if class_key.lower() in self._merged_module_args and \ self._merged_module_args[class_key.lower()]: args = f"{', '.join(self._merged_module_args[class_key.lower()])}" + class_init = f"{FIRST_LEVEL_INDENT}def __init__(self, " \ f"{args}):" \ f"{NEW_LINE}{SECOND_LEVEL_INDENT}" \ @@ -455,7 +466,8 @@ class HierarchicalTree(Tree): construct_block.append(construct) init_block.append(init) - class_construct = f"{NEW_LINE}{FIRST_LEVEL_INDENT}def construct(self, x):{NEW_LINE}{SECOND_LEVEL_INDENT}" + class_construct = f"{NEW_LINE}{FIRST_LEVEL_INDENT}def construct(self, x):" \ + f"{NEW_LINE}{SECOND_LEVEL_INDENT}" init_body = f"{NEW_LINE}{SECOND_LEVEL_INDENT}".join(init_block) csrt_body = f"{NEW_LINE}{SECOND_LEVEL_INDENT}".join(construct_block) csrt_rtn = f"{NEW_LINE}{SECOND_LEVEL_INDENT}return output{NEW_LINE}" @@ -514,7 +526,7 @@ class HierarchicalTree(Tree): def _find_all_previous_opt_var_(self, cur_nd, pre_nd): """ - Find all input varian names. + Find all input variable names. Args: cur_nd (Node): Current node. @@ -585,61 +597,46 @@ class HierarchicalTree(Tree): node.data.hash_key = unique_key return unique_key - def _module_merging(self, nodes): - """ - Generate sub-module and corresponding params. - - Args: - nodes (List[str]): Nodes name. - - """ - merged_module = dict() + def _module_merging(self): + """Generate sub-module and corresponding params.""" merged_module_args = dict() - for node_name in nodes: - nd_inst = self.get_node(node_name) - if nd_inst.data.node_type != NodeType.MODULE.value: - continue - - module_key = self.hash_key(nd_inst) - if module_key not in merged_module: - merged_module[module_key] = [nd_inst.data.args_in_code] - else: - merged_module[module_key].append(nd_inst.data.args_in_code) - - for module_key, module_args in merged_module.items(): + for module_key, module_args in self._merged_module.items(): if module_key not in merged_module_args: merged_module_args[module_key] = [] # Take first element's args as base. keys = module_args[0].keys() for key in keys: for i in range(1, len(module_args)): - if module_args[0][key] != module_args[i][key]: + if key in module_args[i] and module_args[0][key] != module_args[i][key]: + merged_module_args[module_key].append(key) + break + if key not in module_args[i]: merged_module_args[module_key].append(key) break - self._merged_module.update(merged_module) self._merged_module_args.update(merged_module_args) def _create_module_args_and_vars(self, node, mapper): """ - Create module args. + Create module args and variables in current node. Args: node (Node): Node on tree. mapper (Mapper): Mapper of params. """ + # All args and value pair in current node module. module_args = dict() module_key = self.hash_key(node) - created = False - if module_key not in self._args_mgr_in_module: - self._args_mgr_in_module[module_key] = GLOBAL_VAR_NAME_MGR + if module_key not in self._vars_mgr_in_module: + self._vars_mgr_in_module[module_key] = GLOBAL_VAR_NAME_MGR self._module_vars[module_key] = [] else: created = True + # Sub-modules in the module could have arg name conflicts. 
diff --git a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py
index 83605dd6..62d8aab4 100644
--- a/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py
+++ b/mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py
@@ -53,6 +53,9 @@ class ModuleNameMgr(NameMgr):
     """Module name manager."""
 
 
+# Manage variable names across different modules.
+global_var_namespace = set()
+# Manage variable names of each type.
 global_op_namespace = dict()
 START_IDX = 0
 
@@ -81,14 +84,21 @@ class GlobalVarNameMgr:
         Returns:
             str, module name.
         """
-        op_type = op_type.lower()
-        if op_type not in global_op_namespace:
-            global_op_namespace[op_type] = START_IDX
-            suffix = ""
-        else:
-            global_op_namespace[op_type] += 1
-            suffix = f"{global_op_namespace[op_type] - 1}"
-        new_name = f"{self._get_name(op_type)}{suffix}"
+        def _gen(t):
+            t = t.lower()
+            if t not in global_op_namespace:
+                global_op_namespace[t] = START_IDX
+                suffix = ""
+            else:
+                global_op_namespace[t] += 1
+                suffix = f"{global_op_namespace[t] - 1}"
+
+            return f"{self._get_name(t)}{suffix}"
+
+        new_name = _gen(op_type)
+        while new_name in global_var_namespace:
+            new_name = _gen(op_type)
+        global_var_namespace.add(new_name)
         return new_name
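Note: name generation is now two-level. Per-type counters propose candidates, and the new global_var_namespace set rejects any candidate that is already taken, retrying until a free name is found. A runnable sketch (the real manager additionally maps types through self._get_name):

```python
global_op_namespace = {}
global_var_namespace = set()

def _gen(op_type):
    # First occurrence gets a bare name, later ones get 0, 1, 2, ...
    t = op_type.lower()
    if t not in global_op_namespace:
        global_op_namespace[t] = 0
        suffix = ""
    else:
        global_op_namespace[t] += 1
        suffix = str(global_op_namespace[t] - 1)
    return f"{t}{suffix}"

def get_unique_name(op_type):
    name = _gen(op_type)
    while name in global_var_namespace:
        name = _gen(op_type)
    global_var_namespace.add(name)
    return name

print(get_unique_name("Conv2d"), get_unique_name("Conv2d"), get_unique_name("Conv2d"))
# conv2d conv2d0 conv2d1
```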
""" - op_type = op_type.lower() - if op_type not in global_op_namespace: - global_op_namespace[op_type] = START_IDX - suffix = "" - else: - global_op_namespace[op_type] += 1 - suffix = f"{global_op_namespace[op_type] - 1}" - new_name = f"{self._get_name(op_type)}{suffix}" + def _gen(t): + t = t.lower() + if t not in global_op_namespace: + global_op_namespace[t] = START_IDX + suffix = "" + else: + global_op_namespace[t] += 1 + suffix = f"{global_op_namespace[t] - 1}" + + return f"{self._get_name(t)}{suffix}" + + new_name = _gen(op_type) + while new_name in global_var_namespace: + new_name = _gen(op_type) + global_var_namespace.add(new_name) return new_name diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/base.py b/mindinsight/mindconverter/graph_based_converter/mapper/base.py index 45b7af28..6a9fafbd 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/base.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/base.py @@ -18,6 +18,7 @@ import importlib import json import os from typing import Dict +from mindinsight.mindconverter.common.log import logger as log CONFIG_JSON = "onnx_to_ms.json" OPERATION_TABLE = os.path.join( @@ -91,7 +92,8 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): weights_converter = getattr(converter, GET_OP_WEIGHTS) except (ModuleNotFoundError,) as e: # If mapper can not be found, then skip it. - print(f"Converting {op_name} failed, see {e}") + err_msg = f"Converting {op_name} failed, see {str(e)}" + log.error(err_msg) return None, dict() try: @@ -99,8 +101,9 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): converted_params = params_converter(params, weights) converted_weights = weights_converter(weights) if weights else dict() converted_params.update(converted_weights) - except (AttributeError, KeyError, ValueError, TypeError) as _: - print(f"Converting {op_name} failed.") + except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e: + err_msg = f"Converting {op_name} failed, see {str(e)}" + log.error(err_msg) return None, dict() return converter_name, converted_params diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py index 41a7c37b..0fc8357b 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py @@ -39,7 +39,7 @@ class ConvMapper(ONNXToMindSporeMapper): else: stride = params['strides'] kernel_shape = list(weight.shape) - in_channels = kernel_shape[-2] + in_channels = kernel_shape[-2] * params.get("group", 1) out_channels = kernel_shape[-1] kernel_size = kernel_shape[:-2] if len(kernel_size) == 1: diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/global_pool_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/global_pool_mapper.py index ac3919a4..4d8c8cf5 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/global_pool_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/global_pool_mapper.py @@ -31,8 +31,7 @@ class GlobalPoolMapper(ONNXToMindSporeMapper): @staticmethod def _convert_params(params, weights): - dim = 1 if len(params['input_shape']) == 3\ - else 2 + dim = 1 if len(params['input_shape']) == 3 else 2 if dim == 1: kernel_size = params['input_shape'][-1] // params['output_shape'][-1] else: diff --git 
diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py
index c2e6011c..3d48e403 100644
--- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py
+++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py
@@ -100,6 +100,18 @@ class Graph(BaseGraph, abc.ABC):
         self._topological_order = []
         self._input_shape = dict()
 
+    def get_input_shape(self, name):
+        """
+        Get node input shape.
+
+        Args:
+            name (str): Node name.
+
+        Returns:
+            list, shape.
+        """
+        return self._input_shape.get(name)
+
     def get_output_shape(self, name):
         """
         Get node output shape.
@@ -112,7 +124,7 @@
         """
         return self._shape_dict.get(name)
 
-    def get_input_shape(self, name):
+    def get_input_shape_from_input(self, name):
         """
         Get node input shape.
@@ -482,7 +494,7 @@ class GraphNode(abc.ABC):
         """Return op_name."""
 
     @abc.abstractmethod
-    def replace_with_arg(self, arg):
+    def replace_with_arg(self, src_arg, tgt_arg):
         """Replace actual parameter with formal parameter."""
 
     @abc.abstractmethod
diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py
index 110c9f89..5d9f4a47 100644
--- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py
+++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py
@@ -53,7 +53,7 @@
     def hash_key(self):
         pass
 
-    def replace_with_arg(self, arg):
+    def replace_with_arg(self, src_arg, tgt_arg):
         pass
 
     def _get_arg_name(self, arg):
@@ -65,9 +65,30 @@
     def __init__(self, input_shape):
         super(InputNode, self).__init__(node=None)
         self._op_name = 'Input'
-        self._op_params = {'node_shape': input_shape}
+        self._op_params = {'input_shape': input_shape,
+                           'output_shape': input_shape}
         self._node_type = NodeType.INPUT.value
 
+    @property
+    def input_shape(self):
+        """
+        Input tensor shape of current node.
+
+        Returns:
+            tuple, tensor shape of input.
+        """
+        return self._op_params["input_shape"]
+
+    @property
+    def output_shape(self):
+        """
+        Output tensor shape.
+
+        Returns:
+            tuple, output tensor shape.
+        """
+        return self._op_params["output_shape"]
+
     def set_scope_name(self, original_input_scope_name):
         """
         Set scope name.
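Note: after this split, get_output_shape() and get_input_shape() read two separate tables, while the old behavior survives as get_input_shape_from_input(). A minimal sketch of the relationship (toy data, not the real graph class):

```python
class TinyGraph:
    def __init__(self):
        self._shape_dict = {"conv1": (1, 64, 112, 112)}   # per-node outputs
        self._input_shape = {"conv1": (1, 3, 224, 224)}   # per-node inputs

    def get_output_shape(self, name):
        return self._shape_dict.get(name)

    def get_input_shape(self, name):
        return self._input_shape.get(name)

g = TinyGraph()
print(g.get_input_shape("conv1"), g.get_output_shape("conv1"))
# (1, 3, 224, 224) (1, 64, 112, 112)
```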
# ==============================================================================
 """Define PyTorch graph."""
-import warnings
 import re
 from typing import Dict, NoReturn
 
@@ -27,8 +26,11 @@ from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE
 from ..constant import LEFT_BUCKET, RIGHT_BUCKET
 
 NONE_SCOPE_OP = {
-    'onnx::Add': 'Add',
-    'onnx::Flatten': 'Flatten',
+    "onnx::Add": "Add",
+    "onnx::Flatten": "Flatten",
+    "onnx::Concat": "Concat",
+    "onnx::Squeeze": "Squeeze",
+    "onnx::Unsqueeze": "Unsqueeze",
 }
 
@@ -59,6 +61,7 @@ def normalize_scope_name(node):
         scopes.append(segment)
     if node.kind() in NONE_SCOPE_OP.keys():
         scopes.append(NONE_SCOPE_OP[node.kind()])
+    scopes = [s for s in scopes if s]
     return f"{SEPARATOR_IN_SCOPE.join(scopes)}_{PyTorchGraph.get_node_id(node)}"
 
@@ -90,18 +93,16 @@ class PyTorchGraph(Graph):
         """
         if not input_shape:
-            error = ValueError("`input_shape` can not be None.")
-            log.error(str(error))
-            log.exception(error)
-            raise error
+            err_msg = "`input_shape` can not be None."
+            log.error(err_msg)
+            raise ValueError(err_msg)
 
         for item in input_shape:
             if not isinstance(item, int):
-                error = ValueError(f"Only support model with one input now, "
-                                   f"and each shape value in `input_shape` should be int.")
-                log.error(str(error))
-                log.exception(error)
-                raise error
+                err_msg = "Only support model with one input now, " \
+                          "and each shape value in `input_shape` should be int."
+                log.error(err_msg)
+                raise ValueError(err_msg)
 
     @staticmethod
     def _extract_shape(shape):
         """
@@ -116,18 +117,29 @@
         """
         if "," not in shape:
             return []
+
+        shape_arr = []
         for s in shape.split(","):
+            s = s.strip()
             if not s:
                 return []
-        return [int(x.split(":")[0].replace("!", "")) for x in shape.split(',')]
+            if ":" in s:
+                s = s.split(":")[0]
+            s = s.replace("!", "")
+            if not s.isdigit():
+                return []
+            shape_arr.append(int(s))
+        return shape_arr
 
-    def build(self, input_shape):
+    def _trace_torch_graph(self, input_shape):
         """
-        Build graph tree.
+        Trace the torch computational graph.
 
         Args:
-            input_shape (tuple): Input shape of model.
+            input_shape (tuple): Shape.
 
+        Returns:
+            object, pytorch graph.
         """
         import torch
         from torch.onnx import OperatorExportTypes
@@ -135,24 +147,34 @@
         from .torch_utils import create_autograd_variable
         from .torch_utils import onnx_tracer
 
-        self._check_input_shape(input_shape)
-
-        feed_forward_ipt_shape = (1, *input_shape)
-        batched_sample = create_autograd_variable(torch.rand(*feed_forward_ipt_shape))
-
-        # Assign execution mode to eval.
-        self.model.eval()
+        batched_sample = create_autograd_variable(torch.rand(*input_shape))
 
         try:
+            # Assign execution mode to eval.
+            self.model.eval()
+
             with OverloadTorchModuleTemporarily() as _:
                 # In pytorch higher version, trace function has a known issue.
                 graph = onnx_tracer(self.model, batched_sample,
                                     OperatorExportTypes.ONNX)
+            return graph
         except RuntimeError as error:
             log.error(str(error))
             log.exception(error)
             raise error
 
+    def build(self, input_shape):
+        """
+        Build graph tree.
+
+        Args:
+            input_shape (tuple): Input shape of model.
+
+        """
+        self._check_input_shape(input_shape)
+        feed_forward_ipt_shape = (1, *input_shape)
+        graph = self._trace_torch_graph(feed_forward_ipt_shape)
         nodes = list(graph.nodes())
 
         for node in nodes:
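Note: the rewritten _extract_shape is reproduced below as a standalone function to show the hardening: it tolerates whitespace, "value:type" annotations and "!" markers, and bails out with [] instead of raising ValueError on any non-numeric segment:

```python
def extract_shape(shape):
    if "," not in shape:
        return []
    shape_arr = []
    for s in shape.split(","):
        s = s.strip()
        if not s:
            return []
        if ":" in s:
            s = s.split(":")[0]   # drop a trailing type annotation
        s = s.replace("!", "")    # drop mutation markers
        if not s.isdigit():
            return []             # non-numeric segment: give up cleanly
        shape_arr.append(int(s))
    return shape_arr

print(extract_shape("1, 64!, 112:Float, 112"))  # [1, 64, 112, 112]
print(extract_shape("batch, 64"))               # []
```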
@@ -174,24 +196,43 @@
             for node_input in list(node.inputs()):
                 # Connect input node and src node.
-                if PyTorchGraph.get_node_id(node_input.node()) and node_input.node().scopeName():
+                nd_id = PyTorchGraph.get_node_id(node_input.node())
+                nd_scope_name = node_input.node().kind() in NONE_SCOPE_OP or \
+                                node_input.node().scopeName()
+
+                if nd_id and nd_scope_name:
                     node_input_name = normalize_scope_name(
                         node_input.node()
                     )
                     self.build_connection(node_input_name, node_name)
 
         super(PyTorchGraph, self).build(input_shape=input_shape)
+        self._collect_ipt_shape_of_each_node(feed_forward_ipt_shape)
 
-        # Add Input Node
+    def _collect_ipt_shape_of_each_node(self, input_shape):
+        """
+        Collect the input tensor shape of each node.
+
+        Args:
+            input_shape (tuple): Input shape.
+
+        """
         input_node = InputNode(input_shape)
+        input_node_name = "{}InputNode"
         for node_name, node in self._nodes_collection.items():
             if node_name in self._input_nodes:
+                ipt_nd_name = input_node_name.format(input_node.scope_name)
                 input_node.set_scope_name(node.scope_name)
-                node.precursor_nodes.append(input_node.scope_name)
+                node.precursor_nodes.insert(0, ipt_nd_name)
                 input_node.set_successor_nodes(node_name)
-                self._nodes_collection[input_node.scope_name] = input_node
-                self._input_shape[node_name] = feed_forward_ipt_shape
-                break
+                self._shape_dict[ipt_nd_name] = input_node.output_shape
+
+            ipt_shape = []
+            for p_nd in node.precursor_nodes:
+                shp = self._shape_dict.get(p_nd)
+                ipt_shape.append(tuple(shp))
+
+            self._input_shape[node_name] = ipt_shape[0] if len(ipt_shape) == 1 else ipt_shape
 
     def sub_graph_merging(self):
         """
@@ -199,12 +240,6 @@
         """
         raise NotImplementedError()
 
-    def to_ir(self, mapper):
-        """
-        Convert graph to IR graph.
-        """
-        raise NotImplementedError()
-
     def build_connection(self, src, tgt) -> NoReturn:
         """
         Build connection between source node and target node.
 
         """
         # If src and tgt are the same node, src not in node_collection or
-        # tgt not in node_collection,
-        # then skip this edge.
+        # tgt not in node_collection, then skip this edge.
         if src == tgt or src not in self._nodes_collection or tgt not in self._nodes_collection:
             if src.split(':')[0] not in self._nodes_collection:
-                warnings.warn(f"Graph construct a self-loop node {src}. Ignored.")
+                log.warning("Graph constructs a self-loop node %s. Ignored.", src)
             return
-
         if tgt not in self._nodes_collection[src.split(':')[0]].successor_nodes:
             self._nodes_collection[src.split(':')[0]].successor_nodes.append(tgt)
         if src not in self._nodes_collection[tgt].precursor_nodes:
@@ -244,11 +277,10 @@
         """
         Load graph metadata.
         """
-        error = NotImplementedError("class `PyTorchGraph` has not implemented "
-                                    "`load_metadata()`.")
-        log.error(str(error))
-        log.exception(error)
-        raise error
+        err_msg = "class `PyTorchGraph` has not implemented " \
+                  "`load_metadata()`."
+        log.error(err_msg)
+        raise NotImplementedError(err_msg)
 
     @staticmethod
     def load_graph(graph_path: str):
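Note on _collect_ipt_shape_of_each_node: a node's input shape is the output shape of each of its precursors, collapsed to a bare tuple when there is exactly one precursor. A toy illustration of that core idea (names and shapes are made up):

```python
shape_dict = {"InputNode": (1, 3, 224, 224), "conv1": (1, 64, 112, 112)}
precursors = {"conv1": ["InputNode"], "concat1": ["InputNode", "conv1"]}

input_shape = {}
for node, pre in precursors.items():
    shapes = [tuple(shape_dict[p]) for p in pre]
    # Single-input nodes store one tuple, multi-input nodes store a list.
    input_shape[node] = shapes[0] if len(shapes) == 1 else shapes

print(input_shape["conv1"])    # (1, 3, 224, 224)
print(input_shape["concat1"])  # [(1, 3, 224, 224), (1, 64, 112, 112)]
```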
# ==============================================================================
 """Define PyTorch graph node."""
+from copy import deepcopy
+
 from .base import GraphNode
 
 from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \
@@ -140,7 +142,7 @@
         Returns:
             str, op name.
         """
-        return self._op_name  # if self.is_empty() else self.tag
+        return self._op_name
 
     @property
     def real_name(self):
@@ -177,8 +179,14 @@
         args.update({"input_shape": self.input_shape,
                      "output_shape": self.output_shape})
 
-        expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}"
-                          for k, v in args.items()])
+        if self._node_type == NodeType.OPERATION.value:
+            expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}"
+                              for k, v in args.items()])
+        else:
+            # When the node type is module, class or func,
+            # it is not necessary to replace var names.
+            expr = ", ".join([f"{k}={v}" for k, v in args.items()])
 
         declare = f"self.{self._variable_name} = {operator}({expr})"
         call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_in_construct})"
@@ -211,15 +219,16 @@
             raw_params[k] = getitem_of_node(node, k)
         return raw_params
 
-    def replace_with_arg(self, arg):
+    def replace_with_arg(self, src_arg, tgt_arg):
         """
         Replace actual parameter with formal parameter.
 
         Args:
-            arg (str): Arg name.
+            src_arg (str): Original arg name.
+            tgt_arg (str): Target arg name.
 
         """
-        self._args_in_code[arg] = arg
+        self._args_in_code[src_arg] = tgt_arg
 
     @staticmethod
     def _extract_var_name(scope_name: str):
@@ -241,6 +250,13 @@
             mapper (Mapper): Mapper of params.
 
         """
+        if self._node_type != NodeType.OPERATION.value:
+            args = deepcopy(self._args_in_code)
+            self._args_in_code = dict()
+            for arg, value in args.items():
+                self._args_in_code[self._get_arg_name(arg)] = value
+            return None, None
+
         if not self.transformed:
             _, _ = super(PyTorchGraphNode, self).param_transform(mapper)
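Note on the two-argument replace_with_arg: the key stays the actual arg while the value becomes the formal parameter name, so a shared module's generated body reads e.g. conv(kernel_size=kernel_size). A tiny illustration (dict contents are made up):

```python
args_in_code = {"kernel_size": 3, "stride": 1}

def replace_with_arg(src_arg, tgt_arg):
    # Map the actual arg (key) to the formal parameter name (value).
    args_in_code[src_arg] = tgt_arg

replace_with_arg("kernel_size", "kernel_size")
print(args_in_code)  # {'kernel_size': 'kernel_size', 'stride': 1}
```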