Merge pull request !640 from 刘崇鸣/generalize_mindconvertertags/v1.0.0
| @@ -20,6 +20,7 @@ from importlib.util import find_spec | |||||
| import mindinsight | import mindinsight | ||||
| from mindinsight.mindconverter.common.log import logger as log | from mindinsight.mindconverter.common.log import logger as log | ||||
| from .mapper import ONNXToMindSporeMapper | from .mapper import ONNXToMindSporeMapper | ||||
| from ..common.exceptions import NodeTypeNotSupport | |||||
| permissions = os.R_OK | os.W_OK | os.X_OK | permissions = os.R_OK | os.W_OK | os.X_OK | ||||
| os.umask(permissions << 3 | permissions) | os.umask(permissions << 3 | permissions) | ||||
| @@ -92,7 +93,13 @@ def graph_based_converter(graph_path: str, sample_shape: tuple, | |||||
| graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape, | graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape, | ||||
| checkpoint=checkpoint_path) | checkpoint=checkpoint_path) | ||||
| hierarchical_tree = HierarchicalTreeFactory.create(graph_obj) | |||||
| try: | |||||
| hierarchical_tree = HierarchicalTreeFactory.create(graph_obj) | |||||
| except Exception as e: | |||||
| log.exception(e) | |||||
| log.error("Error occur when create hierarchical tree.") | |||||
| raise NodeTypeNotSupport("This model is not supported now.") | |||||
| hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper, | hierarchical_tree.save_source_files(output_folder, mapper=ONNXToMindSporeMapper, | ||||
| report_folder=report_folder) | report_folder=report_folder) | ||||
| @@ -13,6 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Hierarchical tree module.""" | """Hierarchical tree module.""" | ||||
| from mindinsight.mindconverter.common.log import logger as log | |||||
| from .hierarchical_tree import HierarchicalTree | from .hierarchical_tree import HierarchicalTree | ||||
| __all__ = [ | __all__ = [ | ||||
| @@ -35,23 +36,14 @@ class HierarchicalTreeFactory: | |||||
| HierarchicalTree, tree. | HierarchicalTree, tree. | ||||
| """ | """ | ||||
| tree = HierarchicalTree() | tree = HierarchicalTree() | ||||
| node_input = None | |||||
| for _, node_name in enumerate(graph.nodes_in_topological_order): | for _, node_name in enumerate(graph.nodes_in_topological_order): | ||||
| node_inst = graph.get_node(node_name) | node_inst = graph.get_node(node_name) | ||||
| node_input = graph.get_input_shape(node_name) | |||||
| node_output = graph.get_output_shape(node_name) | node_output = graph.get_output_shape(node_name) | ||||
| if node_inst.in_degree == 0: | |||||
| # If in-degree equals to zero, then it's a input node. | |||||
| continue | |||||
| # If the node is on the top, then fetch its input | |||||
| # from input table. | |||||
| if not node_input: | |||||
| node_input = graph.get_input_shape(node_name) | |||||
| if not node_input: | if not node_input: | ||||
| raise ValueError(f"This model is not supported now. " | |||||
| f"Cannot find {node_name}'s input shape.") | |||||
| err_msg = f"This model is not supported now. " \ | |||||
| f"Cannot find {node_name}'s input shape." | |||||
| log.error(err_msg) | |||||
| tree.insert(node_inst, node_name, node_input, node_output) | tree.insert(node_inst, node_name, node_input, node_output) | ||||
| node_input = node_output | |||||
| return tree | return tree | ||||
| @@ -31,6 +31,7 @@ from ..constant import SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, FIRST_LEVE | |||||
| from ..constant import NEW_LINE, SECOND_LEVEL_INDENT | from ..constant import NEW_LINE, SECOND_LEVEL_INDENT | ||||
| from ..constant import NodeType | from ..constant import NodeType | ||||
| from ..report_generator import ReportGenerator | from ..report_generator import ReportGenerator | ||||
| from ...common.exceptions import NodeTypeNotSupport | |||||
| GLOBAL_VAR_NAME_MGR = GlobalVarNameMgr() | GLOBAL_VAR_NAME_MGR = GlobalVarNameMgr() | ||||
| @@ -56,7 +57,7 @@ class HierarchicalTree(Tree): | |||||
| # Manage module name to used. | # Manage module name to used. | ||||
| self._module_mgr = ModuleNameMgr() | self._module_mgr = ModuleNameMgr() | ||||
| # Manage variable name in a module. | # Manage variable name in a module. | ||||
| self._args_mgr_in_module = dict() | |||||
| self._vars_mgr_in_module = dict() | |||||
| self._module_vars = dict() | self._module_vars = dict() | ||||
| @property | @property | ||||
| @@ -86,7 +87,7 @@ class HierarchicalTree(Tree): | |||||
| parent = SEPARATOR_IN_SCOPE.join(scopes[:idx]) | parent = SEPARATOR_IN_SCOPE.join(scopes[:idx]) | ||||
| identifier = SEPARATOR_IN_SCOPE.join(scopes[:idx + 1]) | identifier = SEPARATOR_IN_SCOPE.join(scopes[:idx + 1]) | ||||
| try_parent = f"{parent}{SEPARATOR_IN_SCOPE}{scope}" \ | try_parent = f"{parent}{SEPARATOR_IN_SCOPE}{scope}" \ | ||||
| if not parent else scope | |||||
| if parent else scope | |||||
| if self.contains(try_parent): | if self.contains(try_parent): | ||||
| # Whether current node existed. | # Whether current node existed. | ||||
| parent = try_parent | parent = try_parent | ||||
| @@ -132,6 +133,8 @@ class HierarchicalTree(Tree): | |||||
| """ | """ | ||||
| Shrink sub-tree into one node. | Shrink sub-tree into one node. | ||||
| Use child node to replace its ancestor. | |||||
| Args: | Args: | ||||
| node (Node): List of nodes to be merged. | node (Node): List of nodes to be merged. | ||||
| @@ -140,6 +143,8 @@ class HierarchicalTree(Tree): | |||||
| parent_node = self[node.predecessor(self.tree_identifier)] | parent_node = self[node.predecessor(self.tree_identifier)] | ||||
| # Keep successors of parent. | # Keep successors of parent. | ||||
| brothers = deepcopy(parent_node.successors(self.tree_identifier)) | brothers = deepcopy(parent_node.successors(self.tree_identifier)) | ||||
| # Because shrink occurs when node has only one child, | |||||
| # so we take index-0. | |||||
| child = node.successors(self.tree_identifier)[0] | child = node.successors(self.tree_identifier)[0] | ||||
| self.move_node(source=child, | self.move_node(source=child, | ||||
| destination=node.predecessor(self.tree_identifier)) | destination=node.predecessor(self.tree_identifier)) | ||||
| @@ -158,9 +163,13 @@ class HierarchicalTree(Tree): | |||||
| out_folder (str): Output folder. | out_folder (str): Output folder. | ||||
| """ | """ | ||||
| self._adjust_structure() | |||||
| code_fragments = self._generate_codes(mapper) | |||||
| try: | |||||
| self._adjust_structure() | |||||
| code_fragments = self._generate_codes(mapper) | |||||
| except Exception as e: | |||||
| log.exception(e) | |||||
| log.error("Error occur when create hierarchical tree.") | |||||
| raise NodeTypeNotSupport("This model is not supported now.") | |||||
| out_folder = os.path.abspath(out_folder) | out_folder = os.path.abspath(out_folder) | ||||
| if not report_folder: | if not report_folder: | ||||
| @@ -176,9 +185,8 @@ class HierarchicalTree(Tree): | |||||
| for file_name in code_fragments: | for file_name in code_fragments: | ||||
| code, report = code_fragments[file_name] | code, report = code_fragments[file_name] | ||||
| try: | try: | ||||
| with os.fdopen( | |||||
| os.open(os.path.join(os.path.abspath(out_folder), f"{file_name}.py"), | |||||
| self.flags, self.modes), 'w') as file: | |||||
| with os.fdopen(os.open(os.path.join(os.path.abspath(out_folder), f"{file_name}.py"), | |||||
| self.flags, self.modes), "w") as file: | |||||
| file.write(code) | file.write(code) | ||||
| except IOError as error: | except IOError as error: | ||||
| log.error(str(error)) | log.error(str(error)) | ||||
| @@ -186,9 +194,8 @@ class HierarchicalTree(Tree): | |||||
| raise error | raise error | ||||
| try: | try: | ||||
| with os.fdopen( | |||||
| os.open(os.path.join(report_folder, f"report_of_{file_name}.txt"), | |||||
| self.flags, stat.S_IRUSR), "w") as rpt_f: | |||||
| with os.fdopen(os.open(os.path.join(report_folder, f"report_of_{file_name}.txt"), | |||||
| self.flags, stat.S_IRUSR), "w") as rpt_f: | |||||
| rpt_f.write(report) | rpt_f.write(report) | ||||
| except IOError as error: | except IOError as error: | ||||
| log.error(str(error)) | log.error(str(error)) | ||||
| @@ -223,7 +230,8 @@ class HierarchicalTree(Tree): | |||||
| Returns: | Returns: | ||||
| Node, node. | Node, node. | ||||
| """ | """ | ||||
| if node.data.node_type == NodeType.MODULE.value: | |||||
| if node.data.node_type in {NodeType.MODULE.value, NodeType.CLASS.value, | |||||
| NodeType.FUNC.value}: | |||||
| # If current node is class or function, then | # If current node is class or function, then | ||||
| # remove unused args in __init__. | # remove unused args in __init__. | ||||
| cur_module_key = node.data.hash_key or self.hash_key(node) | cur_module_key = node.data.hash_key or self.hash_key(node) | ||||
| @@ -231,13 +239,15 @@ class HierarchicalTree(Tree): | |||||
| node = self._clear_unused_args(node, | node = self._clear_unused_args(node, | ||||
| self._merged_module_args[cur_module_key]) | self._merged_module_args[cur_module_key]) | ||||
| # `self._merged_module_args` records formal args. | |||||
| # We need to replace actual args. | |||||
| if precursor_module_key in self._merged_module_args: | if precursor_module_key in self._merged_module_args: | ||||
| # If parent node is in `_merged_module_args`, then | # If parent node is in `_merged_module_args`, then | ||||
| # replace current node args with arg name declared | # replace current node args with arg name declared | ||||
| # in _merged_module_args. | # in _merged_module_args. | ||||
| for arg in node.data.args_in_code.keys(): | for arg in node.data.args_in_code.keys(): | ||||
| if arg in self._merged_module_args[precursor_module_key]: | if arg in self._merged_module_args[precursor_module_key]: | ||||
| node.data.replace_with_arg(arg) | |||||
| node.data.replace_with_arg(arg, arg) | |||||
| return node | return node | ||||
| @staticmethod | @staticmethod | ||||
| @@ -254,7 +264,8 @@ class HierarchicalTree(Tree): | |||||
| """ | """ | ||||
| args_in_code = list(node.data.args_in_code.keys()) | args_in_code = list(node.data.args_in_code.keys()) | ||||
| for arg in args_in_code: | for arg in args_in_code: | ||||
| if arg not in used_args: | |||||
| ori_arg = arg.replace(f"_{node.data.variable_name}", "") | |||||
| if ori_arg not in used_args: | |||||
| node.data.args_in_code.pop(arg) | node.data.args_in_code.pop(arg) | ||||
| return node | return node | ||||
| @@ -287,9 +298,11 @@ class HierarchicalTree(Tree): | |||||
| if node.data.node_type == NodeType.MODULE.value: | if node.data.node_type == NodeType.MODULE.value: | ||||
| self._create_module_args_and_vars(node, mapper) | self._create_module_args_and_vars(node, mapper) | ||||
| # 2. Get nodes can be merged. | |||||
| self._module_merging(node_collection) | |||||
| # Module merging based on all nodes. | |||||
| self._module_merging() | |||||
| for depth in depths: | |||||
| node_collection = self._hierarchical_order[depth] | |||||
| snippets = set() | snippets = set() | ||||
| for node_name in node_collection: | for node_name in node_collection: | ||||
| nd_inst = self.get_node(node_name) | nd_inst = self.get_node(node_name) | ||||
| @@ -297,8 +310,7 @@ class HierarchicalTree(Tree): | |||||
| continue | continue | ||||
| # Generate hash key for node. | # Generate hash key for node. | ||||
| module_key = self.hash_key(nd_inst) | |||||
| module_key = nd_inst.data.hash_key | |||||
| # Get code generation func. | # Get code generation func. | ||||
| func, node_type = self._fetch_func_and_type(nd_inst) | func, node_type = self._fetch_func_and_type(nd_inst) | ||||
| @@ -325,9 +337,8 @@ class HierarchicalTree(Tree): | |||||
| # 3. Pre-process node args. | # 3. Pre-process node args. | ||||
| nd_inst = self._preprocess_node_args(nd_inst, module_key) | nd_inst = self._preprocess_node_args(nd_inst, module_key) | ||||
| # 4. Post-process child node args. | # 4. Post-process child node args. | ||||
| for scsr_nd_name in nd_inst.successors(self.tree_identifier): | |||||
| self._postprocess_node_args(self.get_node(scsr_nd_name), | |||||
| module_key) | |||||
| for _, scsr_nd_name in enumerate(nd_inst.successors(self.tree_identifier)): | |||||
| self._postprocess_node_args(self.get_node(scsr_nd_name), module_key) | |||||
| # 5. Generate code. | # 5. Generate code. | ||||
| snippets.add(func(nd_inst, nd_inst.data.module_name, module_key)) | snippets.add(func(nd_inst, nd_inst.data.module_name, module_key)) | ||||
| @@ -335,7 +346,6 @@ class HierarchicalTree(Tree): | |||||
| formatted_code, _ = FormatCode("".join(code_blocks), | formatted_code, _ = FormatCode("".join(code_blocks), | ||||
| style_config=CodeFormatConfig.PEP8.value) | style_config=CodeFormatConfig.PEP8.value) | ||||
| report_generator = ReportGenerator() | report_generator = ReportGenerator() | ||||
| report = report_generator.gen_report(formatted_code) | report = report_generator.gen_report(formatted_code) | ||||
| @@ -403,9 +413,9 @@ class HierarchicalTree(Tree): | |||||
| "output_shape": c_nd.data.output_shape}) | "output_shape": c_nd.data.output_shape}) | ||||
| # Generate code statement. | # Generate code statement. | ||||
| expr = ", ".join([f"{k}={v}" for k, v in args.items()]) | |||||
| expr = ", ".join([f"{k.replace(f'_{c_nd.data.variable_name}', '')}={v}" | |||||
| for k, v in args.items()]) | |||||
| code_line = f"{operator}({expr})" | code_line = f"{operator}({expr})" | ||||
| module_list.append(code_line) | module_list.append(code_line) | ||||
| body = f",{NEW_LINE}{SECOND_LEVEL_INDENT}".join(module_list) | body = f",{NEW_LINE}{SECOND_LEVEL_INDENT}".join(module_list) | ||||
| @@ -435,6 +445,7 @@ class HierarchicalTree(Tree): | |||||
| if class_key.lower() in self._merged_module_args and \ | if class_key.lower() in self._merged_module_args and \ | ||||
| self._merged_module_args[class_key.lower()]: | self._merged_module_args[class_key.lower()]: | ||||
| args = f"{', '.join(self._merged_module_args[class_key.lower()])}" | args = f"{', '.join(self._merged_module_args[class_key.lower()])}" | ||||
| class_init = f"{FIRST_LEVEL_INDENT}def __init__(self, " \ | class_init = f"{FIRST_LEVEL_INDENT}def __init__(self, " \ | ||||
| f"{args}):" \ | f"{args}):" \ | ||||
| f"{NEW_LINE}{SECOND_LEVEL_INDENT}" \ | f"{NEW_LINE}{SECOND_LEVEL_INDENT}" \ | ||||
| @@ -455,7 +466,8 @@ class HierarchicalTree(Tree): | |||||
| construct_block.append(construct) | construct_block.append(construct) | ||||
| init_block.append(init) | init_block.append(init) | ||||
| class_construct = f"{NEW_LINE}{FIRST_LEVEL_INDENT}def construct(self, x):{NEW_LINE}{SECOND_LEVEL_INDENT}" | |||||
| class_construct = f"{NEW_LINE}{FIRST_LEVEL_INDENT}def construct(self, x):" \ | |||||
| f"{NEW_LINE}{SECOND_LEVEL_INDENT}" | |||||
| init_body = f"{NEW_LINE}{SECOND_LEVEL_INDENT}".join(init_block) | init_body = f"{NEW_LINE}{SECOND_LEVEL_INDENT}".join(init_block) | ||||
| csrt_body = f"{NEW_LINE}{SECOND_LEVEL_INDENT}".join(construct_block) | csrt_body = f"{NEW_LINE}{SECOND_LEVEL_INDENT}".join(construct_block) | ||||
| csrt_rtn = f"{NEW_LINE}{SECOND_LEVEL_INDENT}return output{NEW_LINE}" | csrt_rtn = f"{NEW_LINE}{SECOND_LEVEL_INDENT}return output{NEW_LINE}" | ||||
| @@ -514,7 +526,7 @@ class HierarchicalTree(Tree): | |||||
| def _find_all_previous_opt_var_(self, cur_nd, pre_nd): | def _find_all_previous_opt_var_(self, cur_nd, pre_nd): | ||||
| """ | """ | ||||
| Find all input varian names. | |||||
| Find all input variable names. | |||||
| Args: | Args: | ||||
| cur_nd (Node): Current node. | cur_nd (Node): Current node. | ||||
| @@ -585,61 +597,46 @@ class HierarchicalTree(Tree): | |||||
| node.data.hash_key = unique_key | node.data.hash_key = unique_key | ||||
| return unique_key | return unique_key | ||||
| def _module_merging(self, nodes): | |||||
| """ | |||||
| Generate sub-module and corresponding params. | |||||
| Args: | |||||
| nodes (List[str]): Nodes name. | |||||
| """ | |||||
| merged_module = dict() | |||||
| def _module_merging(self): | |||||
| """Generate sub-module and corresponding params.""" | |||||
| merged_module_args = dict() | merged_module_args = dict() | ||||
| for node_name in nodes: | |||||
| nd_inst = self.get_node(node_name) | |||||
| if nd_inst.data.node_type != NodeType.MODULE.value: | |||||
| continue | |||||
| module_key = self.hash_key(nd_inst) | |||||
| if module_key not in merged_module: | |||||
| merged_module[module_key] = [nd_inst.data.args_in_code] | |||||
| else: | |||||
| merged_module[module_key].append(nd_inst.data.args_in_code) | |||||
| for module_key, module_args in merged_module.items(): | |||||
| for module_key, module_args in self._merged_module.items(): | |||||
| if module_key not in merged_module_args: | if module_key not in merged_module_args: | ||||
| merged_module_args[module_key] = [] | merged_module_args[module_key] = [] | ||||
| # Take first element's args as base. | # Take first element's args as base. | ||||
| keys = module_args[0].keys() | keys = module_args[0].keys() | ||||
| for key in keys: | for key in keys: | ||||
| for i in range(1, len(module_args)): | for i in range(1, len(module_args)): | ||||
| if module_args[0][key] != module_args[i][key]: | |||||
| if key in module_args[i] and module_args[0][key] != module_args[i][key]: | |||||
| merged_module_args[module_key].append(key) | |||||
| break | |||||
| if key not in module_args[i]: | |||||
| merged_module_args[module_key].append(key) | merged_module_args[module_key].append(key) | ||||
| break | break | ||||
| self._merged_module.update(merged_module) | |||||
| self._merged_module_args.update(merged_module_args) | self._merged_module_args.update(merged_module_args) | ||||
| def _create_module_args_and_vars(self, node, mapper): | def _create_module_args_and_vars(self, node, mapper): | ||||
| """ | """ | ||||
| Create module args. | |||||
| Create module args and variables in current node. | |||||
| Args: | Args: | ||||
| node (Node): Node on tree. | node (Node): Node on tree. | ||||
| mapper (Mapper): Mapper of params. | mapper (Mapper): Mapper of params. | ||||
| """ | """ | ||||
| # All args and value pair in current node module. | |||||
| module_args = dict() | module_args = dict() | ||||
| module_key = self.hash_key(node) | module_key = self.hash_key(node) | ||||
| created = False | created = False | ||||
| if module_key not in self._args_mgr_in_module: | |||||
| self._args_mgr_in_module[module_key] = GLOBAL_VAR_NAME_MGR | |||||
| if module_key not in self._vars_mgr_in_module: | |||||
| self._vars_mgr_in_module[module_key] = GLOBAL_VAR_NAME_MGR | |||||
| self._module_vars[module_key] = [] | self._module_vars[module_key] = [] | ||||
| else: | else: | ||||
| created = True | created = True | ||||
| # Sub-modules in the module could have arg name conflicts. | |||||
| for idx, successor_name in enumerate(node.successors(self.tree_identifier)): | for idx, successor_name in enumerate(node.successors(self.tree_identifier)): | ||||
| nd_inst = self.get_node(successor_name) | nd_inst = self.get_node(successor_name) | ||||
| # Generate variable name here, then | # Generate variable name here, then | ||||
| @@ -648,12 +645,11 @@ class HierarchicalTree(Tree): | |||||
| nd_inst.data.variable_name = self._module_vars[module_key][idx] | nd_inst.data.variable_name = self._module_vars[module_key][idx] | ||||
| else: | else: | ||||
| variable_name = nd_inst.data.op_name or nd_inst.data.module_name | variable_name = nd_inst.data.op_name or nd_inst.data.module_name | ||||
| variable_name = self._args_mgr_in_module[module_key].get_name(variable_name) | |||||
| variable_name = self._vars_mgr_in_module[module_key].get_name(variable_name) | |||||
| nd_inst.data.variable_name = variable_name | nd_inst.data.variable_name = variable_name | ||||
| if nd_inst.data.node_type == NodeType.OPERATION.value: | |||||
| # Generation of params must behind variable assigment. | |||||
| nd_inst.data.param_transform(mapper) | |||||
| # Generation of params must behind variable assigment. | |||||
| nd_inst.data.param_transform(mapper) | |||||
| module_args.update(nd_inst.data.args_in_code) | module_args.update(nd_inst.data.args_in_code) | ||||
| @@ -662,6 +658,12 @@ class HierarchicalTree(Tree): | |||||
| node.data.args_in_code = module_args | node.data.args_in_code = module_args | ||||
| # Collect module args of `module_key`. | |||||
| if module_key not in self._merged_module: | |||||
| self._merged_module[module_key] = [node.data.args_in_code] | |||||
| else: | |||||
| self._merged_module[module_key].append(node.data.args_in_code) | |||||
| @staticmethod | @staticmethod | ||||
| def _create_operation_args(node, mapper): | def _create_operation_args(node, mapper): | ||||
| """ | """ | ||||
| @@ -692,21 +694,20 @@ class HierarchicalTree(Tree): | |||||
| self._hierarchical_order = hierarchical_order | self._hierarchical_order = hierarchical_order | ||||
| def sub_graph_merging(self) -> NoReturn: | def sub_graph_merging(self) -> NoReturn: | ||||
| """ | |||||
| Shrink subtree. | |||||
| """ | |||||
| """Shrink the module has only one child.""" | |||||
| self.update_hierarchical_order() | self.update_hierarchical_order() | ||||
| depths = sorted(list(self._hierarchical_order.keys()), reverse=True) | depths = sorted(list(self._hierarchical_order.keys()), reverse=True) | ||||
| for depth in depths: | for depth in depths: | ||||
| for node_name in self._hierarchical_order[depth]: | for node_name in self._hierarchical_order[depth]: | ||||
| node_inst = self[node_name] | node_inst = self[node_name] | ||||
| if not node_inst.data and len(node_inst.successors(self.tree_identifier)) == 1: | |||||
| # If the node type is module and has only one child, | |||||
| # then merge it with its child. | |||||
| if node_inst.data.node_type == NodeType.MODULE.value and \ | |||||
| len(node_inst.successors(self.tree_identifier)) == 1: | |||||
| self.shrink(node_inst) | self.shrink(node_inst) | ||||
| def _adjust_structure(self) -> NoReturn: | def _adjust_structure(self) -> NoReturn: | ||||
| """ | |||||
| Adjust tree structure to generate source code. | |||||
| """ | |||||
| """Adjust tree structure to generate source code.""" | |||||
| self.sub_graph_merging() | self.sub_graph_merging() | ||||
| self.update_hierarchical_order() | self.update_hierarchical_order() | ||||
| @@ -53,6 +53,9 @@ class ModuleNameMgr(NameMgr): | |||||
| """Module name manager.""" | """Module name manager.""" | ||||
| # Manage variable name of different modules. | |||||
| global_var_namespace = set() | |||||
| # Manage variable name of different type. | |||||
| global_op_namespace = dict() | global_op_namespace = dict() | ||||
| START_IDX = 0 | START_IDX = 0 | ||||
| @@ -81,14 +84,21 @@ class GlobalVarNameMgr: | |||||
| Returns: | Returns: | ||||
| str, module name. | str, module name. | ||||
| """ | """ | ||||
| op_type = op_type.lower() | |||||
| if op_type not in global_op_namespace: | |||||
| global_op_namespace[op_type] = START_IDX | |||||
| suffix = "" | |||||
| else: | |||||
| global_op_namespace[op_type] += 1 | |||||
| suffix = f"{global_op_namespace[op_type] - 1}" | |||||
| new_name = f"{self._get_name(op_type)}{suffix}" | |||||
| def _gen(t): | |||||
| t = t.lower() | |||||
| if t not in global_op_namespace: | |||||
| global_op_namespace[t] = START_IDX | |||||
| suffix = "" | |||||
| else: | |||||
| global_op_namespace[t] += 1 | |||||
| suffix = f"{global_op_namespace[t] - 1}" | |||||
| return f"{self._get_name(t)}{suffix}" | |||||
| new_name = _gen(op_type) | |||||
| while new_name in global_var_namespace: | |||||
| new_name = _gen(op_type) | |||||
| global_var_namespace.add(new_name) | |||||
| return new_name | return new_name | ||||
| @@ -18,6 +18,7 @@ import importlib | |||||
| import json | import json | ||||
| import os | import os | ||||
| from typing import Dict | from typing import Dict | ||||
| from mindinsight.mindconverter.common.log import logger as log | |||||
| CONFIG_JSON = "onnx_to_ms.json" | CONFIG_JSON = "onnx_to_ms.json" | ||||
| OPERATION_TABLE = os.path.join( | OPERATION_TABLE = os.path.join( | ||||
| @@ -91,7 +92,8 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||||
| weights_converter = getattr(converter, GET_OP_WEIGHTS) | weights_converter = getattr(converter, GET_OP_WEIGHTS) | ||||
| except (ModuleNotFoundError,) as e: | except (ModuleNotFoundError,) as e: | ||||
| # If mapper can not be found, then skip it. | # If mapper can not be found, then skip it. | ||||
| print(f"Converting {op_name} failed, see {e}") | |||||
| err_msg = f"Converting {op_name} failed, see {str(e)}" | |||||
| log.error(err_msg) | |||||
| return None, dict() | return None, dict() | ||||
| try: | try: | ||||
| @@ -99,8 +101,9 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||||
| converted_params = params_converter(params, weights) | converted_params = params_converter(params, weights) | ||||
| converted_weights = weights_converter(weights) if weights else dict() | converted_weights = weights_converter(weights) if weights else dict() | ||||
| converted_params.update(converted_weights) | converted_params.update(converted_weights) | ||||
| except (AttributeError, KeyError, ValueError, TypeError) as _: | |||||
| print(f"Converting {op_name} failed.") | |||||
| except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e: | |||||
| err_msg = f"Converting {op_name} failed, see {str(e)}" | |||||
| log.error(err_msg) | |||||
| return None, dict() | return None, dict() | ||||
| return converter_name, converted_params | return converter_name, converted_params | ||||
| @@ -39,7 +39,7 @@ class ConvMapper(ONNXToMindSporeMapper): | |||||
| else: | else: | ||||
| stride = params['strides'] | stride = params['strides'] | ||||
| kernel_shape = list(weight.shape) | kernel_shape = list(weight.shape) | ||||
| in_channels = kernel_shape[-2] | |||||
| in_channels = kernel_shape[-2] * params.get("group", 1) | |||||
| out_channels = kernel_shape[-1] | out_channels = kernel_shape[-1] | ||||
| kernel_size = kernel_shape[:-2] | kernel_size = kernel_shape[:-2] | ||||
| if len(kernel_size) == 1: | if len(kernel_size) == 1: | ||||
| @@ -31,8 +31,7 @@ class GlobalPoolMapper(ONNXToMindSporeMapper): | |||||
| @staticmethod | @staticmethod | ||||
| def _convert_params(params, weights): | def _convert_params(params, weights): | ||||
| dim = 1 if len(params['input_shape']) == 3\ | |||||
| else 2 | |||||
| dim = 1 if len(params['input_shape']) == 3 else 2 | |||||
| if dim == 1: | if dim == 1: | ||||
| kernel_size = params['input_shape'][-1] // params['output_shape'][-1] | kernel_size = params['input_shape'][-1] // params['output_shape'][-1] | ||||
| else: | else: | ||||
| @@ -100,6 +100,18 @@ class Graph(BaseGraph, abc.ABC): | |||||
| self._topological_order = [] | self._topological_order = [] | ||||
| self._input_shape = dict() | self._input_shape = dict() | ||||
| def get_input_shape(self, name): | |||||
| """ | |||||
| Get node input shape. | |||||
| Args: | |||||
| name (str): Node name. | |||||
| Returns: | |||||
| list, shape. | |||||
| """ | |||||
| return self._input_shape.get(name) | |||||
| def get_output_shape(self, name): | def get_output_shape(self, name): | ||||
| """ | """ | ||||
| Get node output shape. | Get node output shape. | ||||
| @@ -112,7 +124,7 @@ class Graph(BaseGraph, abc.ABC): | |||||
| """ | """ | ||||
| return self._shape_dict.get(name) | return self._shape_dict.get(name) | ||||
| def get_input_shape(self, name): | |||||
| def get_input_shape_from_input(self, name): | |||||
| """ | """ | ||||
| Get node input shape. | Get node input shape. | ||||
| @@ -482,7 +494,7 @@ class GraphNode(abc.ABC): | |||||
| """Return op_name.""" | """Return op_name.""" | ||||
| @abc.abstractmethod | @abc.abstractmethod | ||||
| def replace_with_arg(self, arg): | |||||
| def replace_with_arg(self, src_arg, tgt_arg): | |||||
| """Replace actual parameter with formal parameter.""" | """Replace actual parameter with formal parameter.""" | ||||
| @abc.abstractmethod | @abc.abstractmethod | ||||
| @@ -53,7 +53,7 @@ class InputNode(GraphNode): | |||||
| def hash_key(self): | def hash_key(self): | ||||
| pass | pass | ||||
| def replace_with_arg(self, arg): | |||||
| def replace_with_arg(self, src_arg, tgt_arg): | |||||
| pass | pass | ||||
| def _get_arg_name(self, arg): | def _get_arg_name(self, arg): | ||||
| @@ -65,9 +65,30 @@ class InputNode(GraphNode): | |||||
| def __init__(self, input_shape): | def __init__(self, input_shape): | ||||
| super(InputNode, self).__init__(node=None) | super(InputNode, self).__init__(node=None) | ||||
| self._op_name = 'Input' | self._op_name = 'Input' | ||||
| self._op_params = {'node_shape': input_shape} | |||||
| self._op_params = {'input_shape': input_shape, | |||||
| "output_shape": input_shape} | |||||
| self._node_type = NodeType.INPUT.value | self._node_type = NodeType.INPUT.value | ||||
| @property | |||||
| def input_shape(self): | |||||
| """ | |||||
| Input tensor shape of current node. | |||||
| Returns: | |||||
| tuple, tensor shape of input. | |||||
| """ | |||||
| return self._op_params["input_shape"] | |||||
| @property | |||||
| def output_shape(self): | |||||
| """ | |||||
| Output tensor shape. | |||||
| Returns: | |||||
| tuple, output tensor shape. | |||||
| """ | |||||
| return self._op_params["output_shape"] | |||||
| def set_scope_name(self, original_input_scope_name): | def set_scope_name(self, original_input_scope_name): | ||||
| """ | """ | ||||
| Set scope name. | Set scope name. | ||||
| @@ -13,7 +13,6 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Define PyTorch graph.""" | """Define PyTorch graph.""" | ||||
| import warnings | |||||
| import re | import re | ||||
| from typing import Dict, NoReturn | from typing import Dict, NoReturn | ||||
| @@ -27,8 +26,11 @@ from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE | |||||
| from ..constant import LEFT_BUCKET, RIGHT_BUCKET | from ..constant import LEFT_BUCKET, RIGHT_BUCKET | ||||
| NONE_SCOPE_OP = { | NONE_SCOPE_OP = { | ||||
| 'onnx::Add': 'Add', | |||||
| 'onnx::Flatten': 'Flatten', | |||||
| "onnx::Add": "Add", | |||||
| "onnx::Flatten": "Flatten", | |||||
| "onnx::Concat": "Concat", | |||||
| "onnx::Squeeze": "Squeeze", | |||||
| "onnx::Unsqueeze": "Unsqueeze", | |||||
| } | } | ||||
| @@ -59,6 +61,7 @@ def normalize_scope_name(node): | |||||
| scopes.append(segment) | scopes.append(segment) | ||||
| if node.kind() in NONE_SCOPE_OP.keys(): | if node.kind() in NONE_SCOPE_OP.keys(): | ||||
| scopes.append(NONE_SCOPE_OP[node.kind()]) | scopes.append(NONE_SCOPE_OP[node.kind()]) | ||||
| scopes = [s for s in scopes if s] | |||||
| return f"{SEPARATOR_IN_SCOPE.join(scopes)}_{PyTorchGraph.get_node_id(node)}" | return f"{SEPARATOR_IN_SCOPE.join(scopes)}_{PyTorchGraph.get_node_id(node)}" | ||||
| @@ -90,18 +93,16 @@ class PyTorchGraph(Graph): | |||||
| """ | """ | ||||
| if not input_shape: | if not input_shape: | ||||
| error = ValueError("`input_shape` can not be None.") | |||||
| log.error(str(error)) | |||||
| log.exception(error) | |||||
| raise error | |||||
| err_msg = "`input_shape` can not be None." | |||||
| log.error(err_msg) | |||||
| raise ValueError(err_msg) | |||||
| for item in input_shape: | for item in input_shape: | ||||
| if not isinstance(item, int): | if not isinstance(item, int): | ||||
| error = ValueError(f"Only support model with one input now, " | |||||
| f"and each shape value in `input_shape` should be int.") | |||||
| log.error(str(error)) | |||||
| log.exception(error) | |||||
| raise error | |||||
| err_msg = f"Only support model with one input now, " \ | |||||
| f"and each shape value in `input_shape` should be int." | |||||
| log.error(err_msg) | |||||
| raise ValueError(err_msg) | |||||
| @staticmethod | @staticmethod | ||||
| def _extract_shape(shape): | def _extract_shape(shape): | ||||
| @@ -116,18 +117,29 @@ class PyTorchGraph(Graph): | |||||
| """ | """ | ||||
| if "," not in shape: | if "," not in shape: | ||||
| return [] | return [] | ||||
| shape_arr = [] | |||||
| for s in shape.split(","): | for s in shape.split(","): | ||||
| s = s.strip() | |||||
| if not s: | if not s: | ||||
| return [] | return [] | ||||
| return [int(x.split(":")[0].replace("!", "")) for x in shape.split(',')] | |||||
| if ":" in s: | |||||
| s = s.split(":")[0] | |||||
| s = s.replace("!", "") | |||||
| if not s.isdigit(): | |||||
| return [] | |||||
| shape_arr.append(int(s)) | |||||
| return shape_arr | |||||
| def build(self, input_shape): | |||||
| def _trace_torch_graph(self, input_shape): | |||||
| """ | """ | ||||
| Build graph tree. | |||||
| Trace torch computational graph. | |||||
| Args: | Args: | ||||
| input_shape (tuple): Input shape of model. | |||||
| input_shape (tuple): Shape. | |||||
| Returns: | |||||
| object, pytorch graph. | |||||
| """ | """ | ||||
| import torch | import torch | ||||
| from torch.onnx import OperatorExportTypes | from torch.onnx import OperatorExportTypes | ||||
| @@ -135,24 +147,34 @@ class PyTorchGraph(Graph): | |||||
| from .torch_utils import create_autograd_variable | from .torch_utils import create_autograd_variable | ||||
| from .torch_utils import onnx_tracer | from .torch_utils import onnx_tracer | ||||
| self._check_input_shape(input_shape) | |||||
| feed_forward_ipt_shape = (1, *input_shape) | |||||
| batched_sample = create_autograd_variable(torch.rand(*feed_forward_ipt_shape)) | |||||
| # Assign execution mode to eval. | |||||
| self.model.eval() | |||||
| batched_sample = create_autograd_variable(torch.rand(*input_shape)) | |||||
| try: | try: | ||||
| # Assign execution mode to eval. | |||||
| self.model.eval() | |||||
| with OverloadTorchModuleTemporarily() as _: | with OverloadTorchModuleTemporarily() as _: | ||||
| # In pytorch higher version, trace function has a known. | # In pytorch higher version, trace function has a known. | ||||
| graph = onnx_tracer(self.model, batched_sample, | graph = onnx_tracer(self.model, batched_sample, | ||||
| OperatorExportTypes.ONNX) | OperatorExportTypes.ONNX) | ||||
| return graph | |||||
| except RuntimeError as error: | except RuntimeError as error: | ||||
| log.error(str(error)) | log.error(str(error)) | ||||
| log.exception(error) | log.exception(error) | ||||
| raise error | raise error | ||||
| def build(self, input_shape): | |||||
| """ | |||||
| Build graph tree. | |||||
| Args: | |||||
| input_shape (tuple): Input shape of model. | |||||
| """ | |||||
| self._check_input_shape(input_shape) | |||||
| feed_forward_ipt_shape = (1, *input_shape) | |||||
| graph = self._trace_torch_graph(feed_forward_ipt_shape) | |||||
| nodes = list(graph.nodes()) | nodes = list(graph.nodes()) | ||||
| for node in nodes: | for node in nodes: | ||||
| @@ -174,24 +196,43 @@ class PyTorchGraph(Graph): | |||||
| for node_input in list(node.inputs()): | for node_input in list(node.inputs()): | ||||
| # Connect input node and src node. | # Connect input node and src node. | ||||
| if PyTorchGraph.get_node_id(node_input.node()) and node_input.node().scopeName(): | |||||
| nd_id = PyTorchGraph.get_node_id(node_input.node()) | |||||
| nd_scope_name = node_input.node().kind() in NONE_SCOPE_OP or \ | |||||
| node_input.node().scopeName() | |||||
| if nd_id and nd_scope_name: | |||||
| node_input_name = normalize_scope_name( | node_input_name = normalize_scope_name( | ||||
| node_input.node() | node_input.node() | ||||
| ) | ) | ||||
| self.build_connection(node_input_name, node_name) | self.build_connection(node_input_name, node_name) | ||||
| super(PyTorchGraph, self).build(input_shape=input_shape) | super(PyTorchGraph, self).build(input_shape=input_shape) | ||||
| self._collect_ipt_shape_of_each_node(feed_forward_ipt_shape) | |||||
| # Add Input Node | |||||
| def _collect_ipt_shape_of_each_node(self, input_shape): | |||||
| """ | |||||
| Collect input tensor shape of each node. | |||||
| Args: | |||||
| input_shape (tuple): Input shape. | |||||
| """ | |||||
| input_node = InputNode(input_shape) | input_node = InputNode(input_shape) | ||||
| input_node_name = "{}InputNode" | |||||
| for node_name, node in self._nodes_collection.items(): | for node_name, node in self._nodes_collection.items(): | ||||
| if node_name in self._input_nodes: | if node_name in self._input_nodes: | ||||
| ipt_nd_name = input_node_name.format(input_node.scope_name) | |||||
| input_node.set_scope_name(node.scope_name) | input_node.set_scope_name(node.scope_name) | ||||
| node.precursor_nodes.append(input_node.scope_name) | |||||
| node.precursor_nodes.insert(0, ipt_nd_name) | |||||
| input_node.set_successor_nodes(node_name) | input_node.set_successor_nodes(node_name) | ||||
| self._nodes_collection[input_node.scope_name] = input_node | |||||
| self._input_shape[node_name] = feed_forward_ipt_shape | |||||
| break | |||||
| self._shape_dict[ipt_nd_name] = input_node.output_shape | |||||
| ipt_shape = [] | |||||
| for p_nd in node.precursor_nodes: | |||||
| shp = self._shape_dict.get(p_nd) | |||||
| ipt_shape.append(tuple(shp)) | |||||
| self._input_shape[node_name] = ipt_shape[0] if len(ipt_shape) == 1 else ipt_shape | |||||
| def sub_graph_merging(self): | def sub_graph_merging(self): | ||||
| """ | """ | ||||
| @@ -199,12 +240,6 @@ class PyTorchGraph(Graph): | |||||
| """ | """ | ||||
| raise NotImplementedError() | raise NotImplementedError() | ||||
| def to_ir(self, mapper): | |||||
| """ | |||||
| Convert graph to IR graph. | |||||
| """ | |||||
| raise NotImplementedError() | |||||
| def build_connection(self, src, tgt) -> NoReturn: | def build_connection(self, src, tgt) -> NoReturn: | ||||
| """ | """ | ||||
| Build connection between source node and target node. | Build connection between source node and target node. | ||||
| @@ -215,13 +250,11 @@ class PyTorchGraph(Graph): | |||||
| """ | """ | ||||
| # If src and tgt are the same node, src not in node_collection or | # If src and tgt are the same node, src not in node_collection or | ||||
| # tgt not in node_collection, | |||||
| # then skip this edge. | |||||
| # tgt not in node_collection, then skip this edge. | |||||
| if src == tgt or src not in self._nodes_collection or tgt not in self._nodes_collection: | if src == tgt or src not in self._nodes_collection or tgt not in self._nodes_collection: | ||||
| if src.split(':')[0] not in self._nodes_collection: | if src.split(':')[0] not in self._nodes_collection: | ||||
| warnings.warn(f"Graph construct a self-loop node {src}. Ignored.") | |||||
| log.warning("Graph construct a self-loop node %s. Ignored.", src) | |||||
| return | return | ||||
| if tgt not in self._nodes_collection[src.split(':')[0]].successor_nodes: | if tgt not in self._nodes_collection[src.split(':')[0]].successor_nodes: | ||||
| self._nodes_collection[src.split(':')[0]].successor_nodes.append(tgt) | self._nodes_collection[src.split(':')[0]].successor_nodes.append(tgt) | ||||
| if src not in self._nodes_collection[tgt].precursor_nodes: | if src not in self._nodes_collection[tgt].precursor_nodes: | ||||
| @@ -244,11 +277,10 @@ class PyTorchGraph(Graph): | |||||
| """ | """ | ||||
| Load graph metadata. | Load graph metadata. | ||||
| """ | """ | ||||
| error = NotImplementedError("class `PyTorchGraph` has not implemented " | |||||
| "`load_metadata()`.") | |||||
| log.error(str(error)) | |||||
| log.exception(error) | |||||
| raise error | |||||
| err_msg = "class `PyTorchGraph` has not implemented " \ | |||||
| "`load_metadata()`." | |||||
| log.error(err_msg) | |||||
| raise NotImplementedError(err_msg) | |||||
| @staticmethod | @staticmethod | ||||
| def load_graph(graph_path: str): | def load_graph(graph_path: str): | ||||
| @@ -13,6 +13,8 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Define PyTorch graph node.""" | """Define PyTorch graph node.""" | ||||
| from copy import deepcopy | |||||
| from .base import GraphNode | from .base import GraphNode | ||||
| from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \ | from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \ | ||||
| @@ -140,7 +142,7 @@ class PyTorchGraphNode(GraphNode): | |||||
| Returns: | Returns: | ||||
| str, op name. | str, op name. | ||||
| """ | """ | ||||
| return self._op_name # if self.is_empty() else self.tag | |||||
| return self._op_name | |||||
| @property | @property | ||||
| def real_name(self): | def real_name(self): | ||||
| @@ -177,8 +179,14 @@ class PyTorchGraphNode(GraphNode): | |||||
| args.update({"input_shape": self.input_shape, | args.update({"input_shape": self.input_shape, | ||||
| "output_shape": self.output_shape}) | "output_shape": self.output_shape}) | ||||
| expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}" | |||||
| for k, v in args.items()]) | |||||
| if self._node_type == NodeType.OPERATION.value: | |||||
| expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}" | |||||
| for k, v in args.items()]) | |||||
| else: | |||||
| # When it's type is module, class or func, | |||||
| # it's not necessary to replace var. | |||||
| expr = ", ".join([f"{k.replace(f'_{self._variable_name}', '')}={v}" | |||||
| for k, v in args.items()]) | |||||
| declare = f"self.{self._variable_name} = {operator}({expr})" | declare = f"self.{self._variable_name} = {operator}({expr})" | ||||
| call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_in_construct})" | call = f"{self._opt_var_name} = self.{self._variable_name}({ipt_args_in_construct})" | ||||
| @@ -211,15 +219,16 @@ class PyTorchGraphNode(GraphNode): | |||||
| raw_params[k] = getitem_of_node(node, k) | raw_params[k] = getitem_of_node(node, k) | ||||
| return raw_params | return raw_params | ||||
| def replace_with_arg(self, arg): | |||||
| def replace_with_arg(self, src_arg, tgt_arg): | |||||
| """ | """ | ||||
| Replace actual parameter with formal parameter. | Replace actual parameter with formal parameter. | ||||
| Args: | Args: | ||||
| arg (str): Arg name. | |||||
| src_arg (str): Original arg name. | |||||
| tgt_arg (str): Target arg name. | |||||
| """ | """ | ||||
| self._args_in_code[arg] = arg | |||||
| self._args_in_code[src_arg] = tgt_arg | |||||
| @staticmethod | @staticmethod | ||||
| def _extract_var_name(scope_name: str): | def _extract_var_name(scope_name: str): | ||||
| @@ -241,6 +250,13 @@ class PyTorchGraphNode(GraphNode): | |||||
| mapper (Mapper): Mapper of params. | mapper (Mapper): Mapper of params. | ||||
| """ | """ | ||||
| if self._node_type != NodeType.OPERATION.value: | |||||
| args = deepcopy(self._args_in_code) | |||||
| self._args_in_code = dict() | |||||
| for arg, value in args.items(): | |||||
| self._args_in_code[self._get_arg_name(arg)] = value | |||||
| return None, None | |||||
| if not self.transformed: | if not self.transformed: | ||||
| _, _ = super(PyTorchGraphNode, self).param_transform(mapper) | _, _ = super(PyTorchGraphNode, self).param_transform(mapper) | ||||