@@ -20,6 +20,7 @@ from importlib.util import find_spec
 import mindinsight
 from mindinsight.mindconverter.common.log import logger as log
 from .mapper import ONNXToMindSporeMapper
+from ..common.exceptions import NodeTypeNotSupport

 permissions = os.R_OK | os.W_OK | os.X_OK
 os.umask(permissions << 3 | permissions)
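A minimal sketch (not part of the patch) of the permission arithmetic above: `os.R_OK | os.W_OK | os.X_OK` is `0o7`; shifting it left by three bits fills the "group" octal digit and OR-ing the original fills "other", so the process umask becomes `0o077` and newly created files are accessible by the owner only.

    import os

    permissions = os.R_OK | os.W_OK | os.X_OK   # 4 | 2 | 1 == 0o7
    mask = permissions << 3 | permissions       # 0o70 | 0o7 == 0o77
    assert mask == 0o077
    # A file requested with mode 0o666 is created as 0o600 under this umask.
    assert 0o666 & ~mask & 0o777 == 0o600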
@@ -92,7 +93,13 @@ def graph_based_converter(graph_path: str, sample_shape: tuple,
     graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape,
                                   checkpoint=checkpoint_path)
-    hierarchical_tree = HierarchicalTreeFactory.create(graph_obj)
+    try:
+        hierarchical_tree = HierarchicalTreeFactory.create(graph_obj)
+    except Exception as e:
+        log.exception(e)
+        log.error("Error occurred when creating the hierarchical tree.")
+        raise NodeTypeNotSupport("This model is not supported now.")

     hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper,
                                         report_folder=report_folder)
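A minimal sketch (not part of the patch) of the control flow this hunk introduces: any failure during tree creation is logged and re-raised as the converter's own `NodeTypeNotSupport` error. The stand-in exception class and the failing factory below are hypothetical, not MindConverter APIs.

    class NodeTypeNotSupport(Exception):
        """Stand-in for mindinsight.mindconverter.common.exceptions.NodeTypeNotSupport."""

    def create_tree():
        raise ValueError("unsupported node")  # hypothetical low-level failure

    try:
        tree = create_tree()
    except Exception as e:
        # The real code also logs via log.exception(e) / log.error(...).
        raise NodeTypeNotSupport("This model is not supported now.")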
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Hierarchical tree module."""
+from mindinsight.mindconverter.common.log import logger as log
 from .hierarchical_tree import HierarchicalTree

 __all__ = [
@@ -35,23 +36,14 @@ class HierarchicalTreeFactory:
             HierarchicalTree, tree.
         """
         tree = HierarchicalTree()
-        node_input = None
         for _, node_name in enumerate(graph.nodes_in_topological_order):
             node_inst = graph.get_node(node_name)
+            node_input = graph.get_input_shape(node_name)
             node_output = graph.get_output_shape(node_name)
-            if node_inst.in_degree == 0:
-                # If in-degree equals to zero, then it's a input node.
-                continue
-            # If the node is on the top, then fetch its input
-            # from input table.
-            if not node_input:
-                node_input = graph.get_input_shape(node_name)
             if not node_input:
-                raise ValueError(f"This model is not supported now. "
-                                 f"Cannot find {node_name}'s input shape.")
+                err_msg = f"This model is not supported now. " \
+                          f"Cannot find {node_name}'s input shape."
+                log.error(err_msg)
             tree.insert(node_inst, node_name, node_input, node_output)
-            node_input = node_output
         return tree
@@ -31,6 +31,7 @@ from ..constant import SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVE
 from ..constant import NEW_LINE, SECOND_LEVEL_INDENT
 from ..constant import NodeType
 from ..report_generator import ReportGenerator
+from ...common.exceptions import NodeTypeNotSupport

 GLOBAL_VAR_NAME_MGR = GlobalVarNameMgr()
@@ -56,7 +57,7 @@ class HierarchicalTree(Tree):
         # Manage module name to used.
         self._module_mgr = ModuleNameMgr()
         # Manage variable name in a module.
-        self._args_mgr_in_module = dict()
+        self._vars_mgr_in_module = dict()
         self._module_vars = dict()

     @property
@@ -86,7 +87,7 @@ class HierarchicalTree(Tree):
             parent = SEPARATOR_IN_SCOPE.join(scopes[:idx])
             identifier = SEPARATOR_IN_SCOPE.join(scopes[:idx + 1])
             try_parent = f"{parent}{SEPARATOR_IN_SCOPE}{scope}" \
-                if not parent else scope
+                if parent else scope
             if self.contains(try_parent):
                 # Whether current node existed.
                 parent = try_parent
@@ -132,6 +133,8 @@ class HierarchicalTree(Tree):
         """
         Shrink sub-tree into one node.

         Use child node to replace its ancestor.

+        Args:
+            node (Node): Node whose sub-tree is to be merged.
         """
@@ -140,6 +143,8 @@ class HierarchicalTree(Tree):
         parent_node = self[node.predecessor(self.tree_identifier)]
         # Keep successors of parent.
         brothers = deepcopy(parent_node.successors(self.tree_identifier))
+        # Shrink happens only when the node has exactly one child,
+        # so we take index 0.
         child = node.successors(self.tree_identifier)[0]
         self.move_node(source=child,
                        destination=node.predecessor(self.tree_identifier))
@@ -158,9 +163,13 @@ class HierarchicalTree(Tree):
             out_folder (str): Output folder.
         """
-        self._adjust_structure()
-        code_fragments = self._generate_codes(mapper)
+        try:
+            self._adjust_structure()
+            code_fragments = self._generate_codes(mapper)
+        except Exception as e:
+            log.exception(e)
+            log.error("Error occurred when generating source files.")
+            raise NodeTypeNotSupport("This model is not supported now.")

         out_folder = os.path.abspath(out_folder)
         if not report_folder:
@@ -176,9 +185,8 @@ class HierarchicalTree(Tree):
         for file_name in code_fragments:
             code, report = code_fragments[file_name]
             try:
-                with os.fdopen(
-                        os.open(os.path.join(os.path.abspath(out_folder), f"{file_name}.py"),
-                                self.flags, self.modes), 'w') as file:
+                with os.fdopen(os.open(os.path.join(os.path.abspath(out_folder), f"{file_name}.py"),
+                                       self.flags, self.modes), "w") as file:
                     file.write(code)
             except IOError as error:
                 log.error(str(error))
@@ -186,9 +194,8 @@ class HierarchicalTree(Tree):
                 raise error

             try:
-                with os.fdopen(
-                        os.open(os.path.join(report_folder, f"report_of_{file_name}.txt"),
-                                self.flags, stat.S_IRUSR), "w") as rpt_f:
+                with os.fdopen(os.open(os.path.join(report_folder, f"report_of_{file_name}.txt"),
+                                       self.flags, stat.S_IRUSR), "w") as rpt_f:
                     rpt_f.write(report)
             except IOError as error:
                 log.error(str(error))
@@ -223,7 +230,8 @@ class HierarchicalTree(Tree):
         Returns:
             Node, node.
         """
-        if node.data.node_type == NodeType.MODULE.value:
+        if node.data.node_type in {NodeType.MODULE.value, NodeType.CLASS.value,
+                                   NodeType.FUNC.value}:
             # If current node is class or function, then
             # remove unused args in __init__.
             cur_module_key = node.data.hash_key or self.hash_key(node)
@@ -231,13 +239,15 @@ class HierarchicalTree(Tree):
             node = self._clear_unused_args(node,
                                            self._merged_module_args[cur_module_key])
+
+        # `self._merged_module_args` records formal args.
+        # We need to replace actual args.
         if precursor_module_key in self._merged_module_args:
             # If parent node is in `_merged_module_args`, then
             # replace current node args with arg name declared
             # in _merged_module_args.
             for arg in node.data.args_in_code.keys():
                 if arg in self._merged_module_args[precursor_module_key]:
-                    node.data.replace_with_arg(arg)
+                    node.data.replace_with_arg(arg, arg)
         return node

     @staticmethod
@@ -254,7 +264,8 @@ class HierarchicalTree(Tree):
         """
         args_in_code = list(node.data.args_in_code.keys())
         for arg in args_in_code:
-            if arg not in used_args:
+            ori_arg = arg.replace(f"_{node.data.variable_name}", "")
+            if ori_arg not in used_args:
                 node.data.args_in_code.pop(arg)
         return node
@@ -287,9 +298,11 @@ class HierarchicalTree(Tree):
                 if node.data.node_type == NodeType.MODULE.value:
                     self._create_module_args_and_vars(node, mapper)

-            # 2. Get nodes can be merged.
-            self._module_merging(node_collection)
+        # Module merging based on all nodes.
+        self._module_merging()

+        for depth in depths:
+            node_collection = self._hierarchical_order[depth]
             snippets = set()
             for node_name in node_collection:
                 nd_inst = self.get_node(node_name)
@@ -297,8 +310,7 @@ class HierarchicalTree(Tree):
                     continue

                 # Generate hash key for node.
-                module_key = self.hash_key(nd_inst)
+                module_key = nd_inst.data.hash_key
                 # Get code generation func.
                 func, node_type = self._fetch_func_and_type(nd_inst)
@@ -325,9 +337,8 @@ class HierarchicalTree(Tree):
                 # 3. Pre-process node args.
                 nd_inst = self._preprocess_node_args(nd_inst, module_key)
                 # 4. Post-process child node args.
-                for scsr_nd_name in nd_inst.successors(self.tree_identifier):
-                    self._postprocess_node_args(self.get_node(scsr_nd_name),
-                                                module_key)
+                for _, scsr_nd_name in enumerate(nd_inst.successors(self.tree_identifier)):
+                    self._postprocess_node_args(self.get_node(scsr_nd_name), module_key)
                 # 5. Generate code.
                 snippets.add(func(nd_inst, nd_inst.data.module_name, module_key))
@@ -335,7 +346,6 @@ class HierarchicalTree(Tree):
         formatted_code, _ = FormatCode("".join(code_blocks),
                                        style_config=CodeFormatConfig.PEP8.value)
-
         report_generator = ReportGenerator()
         report = report_generator.gen_report(formatted_code)
@@ -403,9 +413,9 @@ class HierarchicalTree(Tree):
                                  "output_shape": c_nd.data.output_shape})
                 # Generate code statement.
-                expr = ", ".join([f"{k}={v}" for k, v in args.items()])
+                expr = ", ".join([f"{k.replace(f'_{c_nd.data.variable_name}', '')}={v}"
+                                  for k, v in args.items()])
                 code_line = f"{operator}({expr})"
                 module_list.append(code_line)

         body = f",{NEW_LINE}{SECOND_LEVEL_INDENT}".join(module_list)
@@ -435,6 +445,7 @@ class HierarchicalTree(Tree):
         if class_key.lower() in self._merged_module_args and \
                 self._merged_module_args[class_key.lower()]:
             args = f"{', '.join(self._merged_module_args[class_key.lower()])}"
+
         class_init = f"{FIRST_LEVEL_INDENT}def __init__(self, " \
                      f"{args}):" \
                      f"{NEW_LINE}{SECOND_LEVEL_INDENT}" \
@@ -455,7 +466,8 @@ class HierarchicalTree(Tree):
             construct_block.append(construct)
             init_block.append(init)

-        class_construct = f"{NEW_LINE}{FIRST_LEVEL_INDENT}def construct(self, x):{NEW_LINE}{SECOND_LEVEL_INDENT}"
+        class_construct = f"{NEW_LINE}{FIRST_LEVEL_INDENT}def construct(self, x):" \
+                          f"{NEW_LINE}{SECOND_LEVEL_INDENT}"
         init_body = f"{NEW_LINE}{SECOND_LEVEL_INDENT}".join(init_block)
         csrt_body = f"{NEW_LINE}{SECOND_LEVEL_INDENT}".join(construct_block)
         csrt_rtn = f"{NEW_LINE}{SECOND_LEVEL_INDENT}return output{NEW_LINE}"
@@ -514,7 +526,7 @@ class HierarchicalTree(Tree):
     def _find_all_previous_opt_var_(self, cur_nd, pre_nd):
         """
-        Find all input varian names.
+        Find all input variable names.

         Args:
             cur_nd (Node): Current node.
@@ -585,61 +597,46 @@ class HierarchicalTree(Tree):
             node.data.hash_key = unique_key
         return unique_key

-    def _module_merging(self, nodes):
-        """
-        Generate sub-module and corresponding params.
-
-        Args:
-            nodes (List[str]): Nodes name.
-        """
-        merged_module = dict()
+    def _module_merging(self):
+        """Generate sub-module and corresponding params."""
         merged_module_args = dict()
-        for node_name in nodes:
-            nd_inst = self.get_node(node_name)
-            if nd_inst.data.node_type != NodeType.MODULE.value:
-                continue
-            module_key = self.hash_key(nd_inst)
-            if module_key not in merged_module:
-                merged_module[module_key] = [nd_inst.data.args_in_code]
-            else:
-                merged_module[module_key].append(nd_inst.data.args_in_code)
-
-        for module_key, module_args in merged_module.items():
+        for module_key, module_args in self._merged_module.items():
             if module_key not in merged_module_args:
                 merged_module_args[module_key] = []
             # Take first element's args as base.
             keys = module_args[0].keys()
             for key in keys:
                 for i in range(1, len(module_args)):
-                    if module_args[0][key] != module_args[i][key]:
+                    if key in module_args[i] and module_args[0][key] != module_args[i][key]:
                         merged_module_args[module_key].append(key)
                         break
+                    if key not in module_args[i]:
+                        merged_module_args[module_key].append(key)
+                        break

-        self._merged_module.update(merged_module)
         self._merged_module_args.update(merged_module_args)
     def _create_module_args_and_vars(self, node, mapper):
         """
-        Create module args.
+        Create module args and variables in current node.

         Args:
             node (Node): Node on tree.
             mapper (Mapper): Mapper of params.
         """
+        # All args and value pairs in the current module.
         module_args = dict()
         module_key = self.hash_key(node)
         created = False

-        if module_key not in self._args_mgr_in_module:
-            self._args_mgr_in_module[module_key] = GLOBAL_VAR_NAME_MGR
+        if module_key not in self._vars_mgr_in_module:
+            self._vars_mgr_in_module[module_key] = GLOBAL_VAR_NAME_MGR
             self._module_vars[module_key] = []
         else:
             created = True

+        # Sub-modules in the module could have arg-name conflicts.
         for idx, successor_name in enumerate(node.successors(self.tree_identifier)):
             nd_inst = self.get_node(successor_name)
             # Generate variable name here, then
@@ -648,12 +645,11 @@ class HierarchicalTree(Tree):
                 nd_inst.data.variable_name = self._module_vars[module_key][idx]
             else:
                 variable_name = nd_inst.data.op_name or nd_inst.data.module_name
-                variable_name = self._args_mgr_in_module[module_key].get_name(variable_name)
+                variable_name = self._vars_mgr_in_module[module_key].get_name(variable_name)
                 nd_inst.data.variable_name = variable_name

-            if nd_inst.data.node_type == NodeType.OPERATION.value:
-                # Generation of params must behind variable assigment.
-                nd_inst.data.param_transform(mapper)
+            # Generation of params must come after variable assignment.
+            nd_inst.data.param_transform(mapper)

             module_args.update(nd_inst.data.args_in_code)
@@ -662,6 +658,12 @@ class HierarchicalTree(Tree):
         node.data.args_in_code = module_args

+        # Collect module args of `module_key`.
+        if module_key not in self._merged_module:
+            self._merged_module[module_key] = [node.data.args_in_code]
+        else:
+            self._merged_module[module_key].append(node.data.args_in_code)
+
     @staticmethod
     def _create_operation_args(node, mapper):
         """
@@ -692,21 +694,20 @@ class HierarchicalTree(Tree):
         self._hierarchical_order = hierarchical_order

     def sub_graph_merging(self) -> NoReturn:
-        """
-        Shrink subtree.
-        """
+        """Shrink modules that have only one child."""
         self.update_hierarchical_order()
         depths = sorted(list(self._hierarchical_order.keys()), reverse=True)
         for depth in depths:
             for node_name in self._hierarchical_order[depth]:
                 node_inst = self[node_name]
-                if not node_inst.data and len(node_inst.successors(self.tree_identifier)) == 1:
+                # If the node type is module and it has only one child,
+                # then merge it with its child.
+                if node_inst.data.node_type == NodeType.MODULE.value and \
+                        len(node_inst.successors(self.tree_identifier)) == 1:
                     self.shrink(node_inst)

     def _adjust_structure(self) -> NoReturn:
-        """
-        Adjust tree structure to generate source code.
-        """
+        """Adjust tree structure to generate source code."""
         self.sub_graph_merging()
         self.update_hierarchical_order()
@@ -53,6 +53,9 @@ class ModuleNameMgr(NameMgr):
     """Module name manager."""

+# Manage variable names of different modules.
 global_var_namespace = set()
+# Manage variable names of different op types.
+global_op_namespace = dict()

 START_IDX = 0
@@ -81,14 +84,21 @@ class GlobalVarNameMgr:
         Returns:
             str, module name.
         """
-        op_type = op_type.lower()
-        if op_type not in global_op_namespace:
-            global_op_namespace[op_type] = START_IDX
-            suffix = ""
-        else:
-            global_op_namespace[op_type] += 1
-            suffix = f"{global_op_namespace[op_type] - 1}"
-
-        new_name = f"{self._get_name(op_type)}{suffix}"
+        def _gen(t):
+            t = t.lower()
+            if t not in global_op_namespace:
+                global_op_namespace[t] = START_IDX
+                suffix = ""
+            else:
+                global_op_namespace[t] += 1
+                suffix = f"{global_op_namespace[t] - 1}"
+            return f"{self._get_name(t)}{suffix}"
+
+        new_name = _gen(op_type)
+        while new_name in global_var_namespace:
+            new_name = _gen(op_type)
+
+        global_var_namespace.add(new_name)
         return new_name
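The resulting naming scheme, sketched standalone (assuming `_get_name` returns the lower-cased type unchanged — an assumption — and ignoring the `global_var_namespace` collision loop): the first occurrence of an op type gets the bare name, later occurrences get a zero-based numeric suffix.

    global_op_namespace = {}
    START_IDX = 0

    def gen(op_type):
        t = op_type.lower()
        if t not in global_op_namespace:
            global_op_namespace[t] = START_IDX
            suffix = ""
        else:
            global_op_namespace[t] += 1
            suffix = f"{global_op_namespace[t] - 1}"
        return f"{t}{suffix}"  # stand-in for f"{self._get_name(t)}{suffix}"

    assert [gen("Conv2d"), gen("Conv2d"), gen("Conv2d"), gen("ReLU")] == \
           ["conv2d", "conv2d0", "conv2d1", "relu"]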
@@ -18,6 +18,7 @@ import importlib
 import json
 import os
 from typing import Dict
+from mindinsight.mindconverter.common.log import logger as log

 CONFIG_JSON = "onnx_to_ms.json"
 OPERATION_TABLE = os.path.join(
@@ -91,7 +92,8 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
             weights_converter = getattr(converter, GET_OP_WEIGHTS)
         except (ModuleNotFoundError,) as e:
             # If mapper can not be found, then skip it.
-            print(f"Converting {op_name} failed, see {e}")
+            err_msg = f"Converting {op_name} failed, see {str(e)}"
+            log.error(err_msg)
             return None, dict()
| try: | |||
| @@ -99,8 +101,9 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||
| converted_params = params_converter(params, weights) | |||
| converted_weights = weights_converter(weights) if weights else dict() | |||
| converted_params.update(converted_weights) | |||
| except (AttributeError, KeyError, ValueError, TypeError) as _: | |||
| print(f"Converting {op_name} failed.") | |||
| except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e: | |||
| err_msg = f"Converting {op_name} failed, see {str(e)}" | |||
| log.error(err_msg) | |||
| return None, dict() | |||
| return converter_name, converted_params | |||
@@ -39,7 +39,7 @@ class ConvMapper(ONNXToMindSporeMapper):
         else:
             stride = params['strides']
         kernel_shape = list(weight.shape)
-        in_channels = kernel_shape[-2]
+        in_channels = kernel_shape[-2] * params.get("group", 1)
         out_channels = kernel_shape[-1]
         kernel_size = kernel_shape[:-2]
         if len(kernel_size) == 1:
@@ -31,8 +31,7 @@ class GlobalPoolMapper(ONNXToMindSporeMapper):
     @staticmethod
     def _convert_params(params, weights):
-        dim = 1 if len(params['input_shape']) == 3\
-            else 2
+        dim = 1 if len(params['input_shape']) == 3 else 2
         if dim == 1:
             kernel_size = params['input_shape'][-1] // params['output_shape'][-1]
         else:
@@ -100,6 +100,18 @@ class Graph(BaseGraph, abc.ABC):
         self._topological_order = []
         self._input_shape = dict()

+    def get_input_shape(self, name):
+        """
+        Get node input shape.
+
+        Args:
+            name (str): Node name.
+
+        Returns:
+            list, shape.
+        """
+        return self._input_shape.get(name)
+
     def get_output_shape(self, name):
         """
         Get node output shape.
@@ -112,7 +124,7 @@ class Graph(BaseGraph, abc.ABC):
         """
         return self._shape_dict.get(name)

-    def get_input_shape(self, name):
+    def get_input_shape_from_input(self, name):
         """
         Get node input shape.
@@ -482,7 +494,7 @@ class GraphNode(abc.ABC):
         """Return op_name."""

     @abc.abstractmethod
-    def replace_with_arg(self, arg):
+    def replace_with_arg(self, src_arg, tgt_arg):
         """Replace actual parameter with formal parameter."""

     @abc.abstractmethod
@@ -53,7 +53,7 @@ class InputNode(GraphNode):
     def hash_key(self):
         pass

-    def replace_with_arg(self, arg):
+    def replace_with_arg(self, src_arg, tgt_arg):
         pass

     def _get_arg_name(self, arg):
@@ -65,9 +65,30 @@ class InputNode(GraphNode):
     def __init__(self, input_shape):
         super(InputNode, self).__init__(node=None)
         self._op_name = 'Input'
-        self._op_params = {'node_shape': input_shape}
+        self._op_params = {'input_shape': input_shape,
+                           "output_shape": input_shape}
         self._node_type = NodeType.INPUT.value

+    @property
+    def input_shape(self):
+        """
+        Input tensor shape of current node.
+
+        Returns:
+            tuple, tensor shape of input.
+        """
+        return self._op_params["input_shape"]
+
+    @property
+    def output_shape(self):
+        """
+        Output tensor shape.
+
+        Returns:
+            tuple, output tensor shape.
+        """
+        return self._op_params["output_shape"]
+
     def set_scope_name(self, original_input_scope_name):
         """
         Set scope name.
@@ -13,7 +13,6 @@
 # limitations under the License.
 # ==============================================================================
 """Define PyTorch graph."""
-import warnings
 import re
 from typing import Dict, NoReturn
@@ -27,8 +26,11 @@ from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE
 from ..constant import LEFT_BUCKET, RIGHT_BUCKET

 NONE_SCOPE_OP = {
-    'onnx::Add': 'Add',
-    'onnx::Flatten': 'Flatten',
+    "onnx::Add": "Add",
+    "onnx::Flatten": "Flatten",
+    "onnx::Concat": "Concat",
+    "onnx::Squeeze": "Squeeze",
+    "onnx::Unsqueeze": "Unsqueeze",
 }
@@ -59,6 +61,7 @@ def normalize_scope_name(node):
             scopes.append(segment)
     if node.kind() in NONE_SCOPE_OP.keys():
         scopes.append(NONE_SCOPE_OP[node.kind()])
+    scopes = [s for s in scopes if s]
     return f"{SEPARATOR_IN_SCOPE.join(scopes)}_{PyTorchGraph.get_node_id(node)}"
@@ -90,18 +93,16 @@ class PyTorchGraph(Graph):
         """
         if not input_shape:
-            error = ValueError("`input_shape` can not be None.")
-            log.error(str(error))
-            log.exception(error)
-            raise error
+            err_msg = "`input_shape` can not be None."
+            log.error(err_msg)
+            raise ValueError(err_msg)

         for item in input_shape:
             if not isinstance(item, int):
-                error = ValueError(f"Only support model with one input now, "
-                                   f"and each shape value in `input_shape` should be int.")
-                log.error(str(error))
-                log.exception(error)
-                raise error
+                err_msg = f"Only support model with one input now, " \
+                          f"and each shape value in `input_shape` should be int."
+                log.error(err_msg)
+                raise ValueError(err_msg)

     @staticmethod
     def _extract_shape(shape):
@@ -116,18 +117,29 @@ class PyTorchGraph(Graph):
         """
         if "," not in shape:
             return []
-        return [int(x.split(":")[0].replace("!", "")) for x in shape.split(',')]
+
+        shape_arr = []
+        for s in shape.split(","):
+            s = s.strip()
+            if not s:
+                return []
+            if ":" in s:
+                s = s.split(":")[0]
+            s = s.replace("!", "")
+            if not s.isdigit():
+                return []
+            shape_arr.append(int(s))
+        return shape_arr
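Behavior of the hardened `_extract_shape` on a few representative inputs (editor's sketch; the shape strings are hypothetical examples of traced-graph annotations): each comma-separated entry is stripped, a `:Type` suffix and `!` markers are dropped, and any non-numeric entry aborts to an empty list.

    assert PyTorchGraph._extract_shape("1, 3, 224, 224") == [1, 3, 224, 224]
    assert PyTorchGraph._extract_shape("1, 64!, 112:Float, 112") == [1, 64, 112, 112]
    assert PyTorchGraph._extract_shape("scalar") == []           # no comma at all
    assert PyTorchGraph._extract_shape("1, Tensor, 224") == []   # non-numeric entry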
-    def build(self, input_shape):
+    def _trace_torch_graph(self, input_shape):
         """
-        Build graph tree.
+        Trace torch computational graph.

         Args:
-            input_shape (tuple): Input shape of model.
+            input_shape (tuple): Shape.
+
+        Returns:
+            object, pytorch graph.
         """
         import torch
         from torch.onnx import OperatorExportTypes
@@ -135,24 +147,34 @@ class PyTorchGraph(Graph):
         from .torch_utils import create_autograd_variable
         from .torch_utils import onnx_tracer

-        self._check_input_shape(input_shape)
-        feed_forward_ipt_shape = (1, *input_shape)
-        batched_sample = create_autograd_variable(torch.rand(*feed_forward_ipt_shape))
-        # Assign execution mode to eval.
-        self.model.eval()
+        batched_sample = create_autograd_variable(torch.rand(*input_shape))
         try:
+            # Assign execution mode to eval.
+            self.model.eval()
             with OverloadTorchModuleTemporarily() as _:
+                # In higher PyTorch versions, the trace function has a known issue.
                 graph = onnx_tracer(self.model, batched_sample,
                                     OperatorExportTypes.ONNX)
+            return graph
         except RuntimeError as error:
             log.error(str(error))
+            log.exception(error)
             raise error

+    def build(self, input_shape):
+        """
+        Build graph tree.
+
+        Args:
+            input_shape (tuple): Input shape of model.
+        """
+        self._check_input_shape(input_shape)
+        feed_forward_ipt_shape = (1, *input_shape)
+        graph = self._trace_torch_graph(feed_forward_ipt_shape)
+
         nodes = list(graph.nodes())
         for node in nodes:
@@ -174,24 +196,43 @@ class PyTorchGraph(Graph):
             for node_input in list(node.inputs()):
                 # Connect input node and src node.
-                if PyTorchGraph.get_node_id(node_input.node()) and node_input.node().scopeName():
+                nd_id = PyTorchGraph.get_node_id(node_input.node())
+                nd_scope_name = node_input.node().kind() in NONE_SCOPE_OP or \
+                                node_input.node().scopeName()
+                if nd_id and nd_scope_name:
                     node_input_name = normalize_scope_name(
                         node_input.node()
                     )
                     self.build_connection(node_input_name, node_name)

         super(PyTorchGraph, self).build(input_shape=input_shape)
+        self._collect_ipt_shape_of_each_node(feed_forward_ipt_shape)

-        # Add Input Node
+    def _collect_ipt_shape_of_each_node(self, input_shape):
+        """
+        Collect input tensor shape of each node.
+
+        Args:
+            input_shape (tuple): Input shape.
+        """
         input_node = InputNode(input_shape)
+        input_node_name = "{}InputNode"
         for node_name, node in self._nodes_collection.items():
             if node_name in self._input_nodes:
+                ipt_nd_name = input_node_name.format(input_node.scope_name)
                 input_node.set_scope_name(node.scope_name)
-                node.precursor_nodes.append(input_node.scope_name)
+                node.precursor_nodes.insert(0, ipt_nd_name)
                 input_node.set_successor_nodes(node_name)
                 self._nodes_collection[input_node.scope_name] = input_node
-                self._input_shape[node_name] = feed_forward_ipt_shape
-                break
+                self._shape_dict[ipt_nd_name] = input_node.output_shape
+
+            ipt_shape = []
+            for p_nd in node.precursor_nodes:
+                shp = self._shape_dict.get(p_nd)
+                ipt_shape.append(tuple(shp))
+
+            self._input_shape[node_name] = ipt_shape[0] if len(ipt_shape) == 1 else ipt_shape
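The core idea of the shape-collection loop above, reduced to plain dicts (editor's sketch with hypothetical node names): each node's input shape is read off the recorded output shapes of its precursor nodes, and singleton lists collapse to a bare tuple.

    shape_dict = {"InputNode": (1, 3, 224, 224), "conv1": (1, 64, 112, 112)}
    precursors = {"conv1": ["InputNode"], "relu1": ["conv1"]}

    input_shape = {}
    for node_name, pre_nodes in precursors.items():
        ipt_shape = [tuple(shape_dict[p]) for p in pre_nodes]
        input_shape[node_name] = ipt_shape[0] if len(ipt_shape) == 1 else ipt_shape

    assert input_shape == {"conv1": (1, 3, 224, 224), "relu1": (1, 64, 112, 112)}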
     def sub_graph_merging(self):
         """
@@ -199,12 +240,6 @@ class PyTorchGraph(Graph):
         """
         raise NotImplementedError()

-    def to_ir(self, mapper):
-        """
-        Convert graph to IR graph.
-        """
-        raise NotImplementedError()
-
     def build_connection(self, src, tgt) -> NoReturn:
         """
         Build connection between source node and target node.
@@ -215,13 +250,11 @@ class PyTorchGraph(Graph):
         """
         # If src and tgt are the same node, src not in node_collection or
-        # tgt not in node_collection,
-        # then skip this edge.
+        # tgt not in node_collection, then skip this edge.
         if src == tgt or src not in self._nodes_collection or tgt not in self._nodes_collection:
             if src.split(':')[0] not in self._nodes_collection:
-                warnings.warn(f"Graph construct a self-loop node {src}. Ignored.")
+                log.warning("Graph constructed a self-loop node %s. Ignored.", src)
             return
-
         if tgt not in self._nodes_collection[src.split(':')[0]].successor_nodes:
             self._nodes_collection[src.split(':')[0]].successor_nodes.append(tgt)
         if src not in self._nodes_collection[tgt].precursor_nodes:
@@ -244,11 +277,10 @@ class PyTorchGraph(Graph):
         """
         Load graph metadata.
         """
-        error = NotImplementedError("class `PyTorchGraph` has not implemented "
-                                    "`load_metadata()`.")
-        log.error(str(error))
-        log.exception(error)
-        raise error
+        err_msg = "class `PyTorchGraph` has not implemented " \
+                  "`load_metadata()`."
+        log.error(err_msg)
+        raise NotImplementedError(err_msg)

     @staticmethod
     def load_graph(graph_path: str):
| @staticmethod | |||
| def load_graph(graph_path: str): | |||
| @@ -13,6 +13,8 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Define PyTorch graph node.""" | |||
| from copy import deepcopy | |||
| from .base import GraphNode | |||
| from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \ | |||
@@ -140,7 +142,7 @@ class PyTorchGraphNode(GraphNode):
         Returns:
             str, op name.
         """
-        return self._op_name  # if self.is_empty() else self.tag
+        return self._op_name

     @property
     def real_name(self):
@@ -177,8 +179,14 @@ class PyTorchGraphNode(GraphNode):
         args.update({"input_shape": self.input_shape,
                      "output_shape": self.output_shape})

-        expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}"
-                          for k, v in args.items()])
+        if self._node_type == NodeType.OPERATION.value:
+            expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}"
+                              for k, v in args.items()])
+        else:
+            # When the node type is module, class or func,
+            # there is no need to replace the variable.
+            expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}"
+                              for k, v in args.items()])

         declare = f"self.{self._variable_name} = {operator}({expr})"
         call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_in_construct})"
@@ -211,15 +219,16 @@ class PyTorchGraphNode(GraphNode):
             raw_params[k] = getitem_of_node(node, k)
         return raw_params

-    def replace_with_arg(self, arg):
+    def replace_with_arg(self, src_arg, tgt_arg):
         """
         Replace actual parameter with formal parameter.

         Args:
-            arg (str): Arg name.
+            src_arg (str): Original arg name.
+            tgt_arg (str): Target arg name.
         """
-        self._args_in_code[arg] = arg
+        self._args_in_code[src_arg] = tgt_arg

     @staticmethod
     def _extract_var_name(scope_name: str):
@@ -241,6 +250,13 @@ class PyTorchGraphNode(GraphNode):
             mapper (Mapper): Mapper of params.
         """
         if self._node_type != NodeType.OPERATION.value:
+            args = deepcopy(self._args_in_code)
+            self._args_in_code = dict()
+            for arg, value in args.items():
+                self._args_in_code[self._get_arg_name(arg)] = value
             return None, None

         if not self.transformed:
             _, _ = super(PyTorchGraphNode, self).param_transform(mapper)