diff --git a/mindinsight/mindconverter/ast_edits.py b/mindinsight/mindconverter/ast_edits.py index 74d08466..c7dc7351 100644 --- a/mindinsight/mindconverter/ast_edits.py +++ b/mindinsight/mindconverter/ast_edits.py @@ -715,3 +715,4 @@ class AstEditVisitor(ast.NodeVisitor): def visit_Attribute(self, node): """Callback function when visit AST tree""" self._check_isinstance_parameter(node) + self.generic_visit(node) diff --git a/mindinsight/mindconverter/code_analysis.py b/mindinsight/mindconverter/code_analysis.py index 8f1eca84..eacfd87a 100644 --- a/mindinsight/mindconverter/code_analysis.py +++ b/mindinsight/mindconverter/code_analysis.py @@ -188,20 +188,26 @@ class CodeAnalyzer(ast.NodeVisitor): @staticmethod def _analyze_import_references(root_scope): - """Find out all references from the import statements.""" - external_name_ref = {} + """ + Find out all references from the import statements. + + Case1: (from)import alias, node_ref.name_ref.id is node_ref.name_ref.definition.asname. + Case2: import without alias, node_ref.name_ref.definition.asname is None. + e.g., import a.b.c, the reference definition id maybe is a, a.b or a.b.c. + The reference id a.b.c is really wanted. + """ + external_name_ref = dict() + all_node_references = [] for node_references in root_scope.external_references.values(): - for node_ref in node_references: - if node_ref.name_ref: - # case1: (from)import alias, node_ref.name_ref.id is node_ref.name_ref.definition.asname. - # case2: import without alias, node_ref.name_ref.definition.asname is None. - # e.g., import a.b.c, the reference definition id maybe is a, a.b or a.b.c. - # The reference id a.b.c is really wanted. - if node_ref.name_ref.id in [node_ref.name_ref.definition.asname, - node_ref.name_ref.definition.name]: - external_name_ref[node_ref.name_ref.id] = node_ref - else: - pass + all_node_references.extend(node_references) + + for node_ref in all_node_references: + name_ref = node_ref.name_ref + if not name_ref: + continue + definition = name_ref.definition + if node_ref.name_ref.id in [definition.asname, definition.name]: + external_name_ref[name_ref.id] = node_ref return external_name_ref @@ -230,7 +236,7 @@ class CodeAnalyzer(ast.NodeVisitor): full_name = self._get_full_name(node) if not full_name: return None - + whole_name = full_name # node is in stack top pos if node is self._stack[-1]: parent_index = -1 @@ -238,8 +244,6 @@ class CodeAnalyzer(ast.NodeVisitor): parent_index -= 1 whole_name = self._get_full_name(self._stack[parent_index]) - else: - whole_name = full_name return whole_name def _is_ref_convertible_imports(self, node): @@ -260,26 +264,24 @@ class CodeAnalyzer(ast.NodeVisitor): return check_result @staticmethod - def _get_external_node(external_references): + def _get_external_node(external_references, only_convertible=False): """Get all external reference nodes.""" external_nodes = {} for ref_name, ref_info in external_references.items(): - external_nodes.update({ref_info['external_ref_info'].node: ref_name}) + is_add = False + if only_convertible: + if ref_info['external_ref_info'].name in APIAnalysisSpec.get_convertible_external_names(): + is_add = True + else: + is_add = True + if is_add: + external_nodes.update({ref_info['external_ref_info'].node: ref_name}) return external_nodes - @staticmethod - def _get_convertible_external_node(external_name_ref): - """Get all convertible external reference nodes.""" - convertible_external_nodes = {} - for ref_name, ref_info in external_name_ref.items(): - if ref_info['external_ref_info'].name in APIAnalysisSpec.get_convertible_external_names(): - convertible_external_nodes.update({ref_info['external_ref_info'].node: ref_name}) - return convertible_external_nodes - def _update_external_ref_parent(self, node): """Set external reference parent node info.""" - external_nodes = self._get_external_node(self._external_references) - convertible_external_nodes = self._get_convertible_external_node(self._external_references) + external_nodes = self._get_external_node(self._external_references, only_convertible=False) + convertible_external_nodes = self._get_external_node(self._external_references, only_convertible=True) for name_node in node.names: if name_node in convertible_external_nodes.keys(): if len(node.names) > 1: @@ -326,15 +328,18 @@ class CodeAnalyzer(ast.NodeVisitor): self.generic_visit(node) - def visit_Import(self, node): - """Callback function when visit AST tree""" + def _update_external_when_visit(self, node): + """Update external reference when visiting import and import from statements.""" self._update_external_ref_parent(node) self.generic_visit(node) + def visit_Import(self, node): + """Callback function when visit AST tree""" + self._update_external_when_visit(node) + def visit_ImportFrom(self, node): """Callback function when visit AST tree""" - self._update_external_ref_parent(node) - self.generic_visit(node) + self._update_external_when_visit(node) def visit_Call(self, node): """Callback function when visit AST tree"""