From 833cae4053c68ff0a3e1ed401d028e7ea89d0d2c Mon Sep 17 00:00:00 2001 From: ggpolar Date: Sat, 20 Jun 2020 13:14:53 +0800 Subject: [PATCH] Fix bug that Some functions is not converted. Some functions in the forward call chain are not found by code analyzer, supplement from the forward call chain. --- mindinsight/mindconverter/ast_edits.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/mindinsight/mindconverter/ast_edits.py b/mindinsight/mindconverter/ast_edits.py index c7dc7351..6dfa7618 100644 --- a/mindinsight/mindconverter/ast_edits.py +++ b/mindinsight/mindconverter/ast_edits.py @@ -152,14 +152,14 @@ class _LineColEditVisitor(ast.NodeVisitor): if self._check_arg2update(arg): args.append(arg) for arg in args: - arg.lineno = dst_call.lineno + # line number starts from 1, column number starts from 0. + arg.lineno += dst_call.lineno - 1 arg.col_offset += dst_call.col_offset @staticmethod def _check_arg2update(arg): - # Only the col_offset of the first line code is re-counted, needs to be corrected. # When the arg is a function call, its col_offset is handled separately. - if not isinstance(arg, ast.Call) and arg.lineno == 1: + if not isinstance(arg, ast.Call): return True return False @@ -269,14 +269,24 @@ class AstEditVisitor(ast.NodeVisitor): def _convert_api(self): """Convert PyTorch api call to MindSpore api call in a function.""" tasks = [] - + found_func_nodes = [] convert_elements = self._code_analyzer.network_definitions() for func_node_scope in convert_elements.get("functions", []): + found_func_nodes.append(func_node_scope.node) is_forward = self._judge_forward(func_node_scope) tasks.append((self._convert_function, (func_node_scope, is_forward))) - for class_scope in convert_elements.get("cell", []).keys(): + for class_scope, func_scopes in convert_elements.get("cell", []).items(): + for func_node_scope in func_scopes: + found_func_nodes.append(func_node_scope.node) tasks.append((self._convert_cell, (class_scope,))) + # Some functions in the forward call chain are not found by self._code_analyzer. + for func_node in self._forward_list.values(): + is_forward = True + if func_node and func_node not in found_func_nodes: + func_node_scope = self._code_analyzer.lookup_scope(func_node) + tasks.append((self._convert_function, (func_node_scope, is_forward))) + for convert_fun, args in tasks: convert_fun(*args)