Browse Source

!328 The annotate maybe lost when delete the import statement.

Merge pull request !328 from ggpolar/br_wzk_dev
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
97ca18d255
1 changed files with 65 additions and 31 deletions
  1. +65
    -31
      mindinsight/mindconverter/ast_edits.py

+ 65
- 31
mindinsight/mindconverter/ast_edits.py View File

@@ -20,7 +20,6 @@ import re
from enum import Enum from enum import Enum


import pasta import pasta
from pasta.augment import import_utils


from mindinsight.mindconverter.code_analysis import CodeAnalyzer from mindinsight.mindconverter.code_analysis import CodeAnalyzer
from mindinsight.mindconverter.code_analysis import APIAnalysisSpec from mindinsight.mindconverter.code_analysis import APIAnalysisSpec
@@ -290,58 +289,93 @@ class AstEditVisitor(ast.NodeVisitor):
for convert_fun, args in tasks: for convert_fun, args in tasks:
convert_fun(*args) convert_fun(*args)


def _convert_external_reference(self):
"""Convert import statements."""
name_replace = APIAnalysisSpec.import_name_mapping
replace_imports = list(name_replace.values())
@staticmethod
def _dump_without_prefix(node):
"""Get the python source for an AST."""
pos = 0
source_prefix = pasta.base.formatting.get(node, 'prefix')
if source_prefix:
pos = len(source_prefix)
source_code = pasta.dump(node)
return source_code[pos]

def _replace_external_reference(self):
"""
Replace external reference statements.


Returns:
dict, key is external name, value is the new replaced node.
"""
all_name_mappings = APIAnalysisSpec.import_name_mapping
names_replaced_with = dict()
for ref_info in self._code_analyzer.external_references.values(): for ref_info in self._code_analyzer.external_references.values():
external_ref_info = ref_info['external_ref_info'] external_ref_info = ref_info['external_ref_info']
parent_node = ref_info['parent_node']
if parent_node is None:
import_node = ref_info['parent_node']
if import_node is None:
continue continue
code = pasta.dump(parent_node)
code = self._dump_without_prefix(import_node)
import_parent_node = self._code_analyzer.root_scope.parent(import_node)
# replace import with new name
if external_ref_info.name in APIAnalysisSpec.get_convertible_external_names(): if external_ref_info.name in APIAnalysisSpec.get_convertible_external_names():
external_ref_info = ref_info['external_ref_info'] external_ref_info = ref_info['external_ref_info']
if external_ref_info.name in name_replace.keys():
import_utils.remove_import_alias_node(self._code_analyzer.root_scope, external_ref_info.node)
replace_info = name_replace[external_ref_info.name]
new_ref_name = replace_info[1]
new_external_name = replace_info[0]
if new_ref_name:
new_code = f'import {new_external_name} as {new_ref_name}'
else:
new_code = f'import {new_external_name}'

self._process_log.info(parent_node.lineno, parent_node.col_offset, LOG_FMT_CONVERT %
if external_ref_info.name in all_name_mappings.keys():
replace_info = all_name_mappings[external_ref_info.name]
new_node = self._make_import(name_to_import=replace_info[0], as_name=replace_info[1])
new_code = pasta.dump(new_node)
pasta.ast_utils.replace_child(import_parent_node, import_node, new_node)
names_replaced_with.update({external_ref_info.name: new_node})
self._process_log.info(import_node.lineno, import_node.col_offset, LOG_FMT_CONVERT %
(code.strip(), new_code.strip())) (code.strip(), new_code.strip()))
elif external_ref_info.name.startswith('torch.'): elif external_ref_info.name.startswith('torch.'):
self._process_log.warning(parent_node.lineno, parent_node.col_offset, LOG_FMT_NOT_CONVERT %
self._process_log.warning(import_node.lineno, import_node.col_offset, LOG_FMT_NOT_CONVERT %
(code.strip(), LOG_SUGGESTION_MANUAL_CONVERT)) (code.strip(), LOG_SUGGESTION_MANUAL_CONVERT))
else: else:
pass pass
return names_replaced_with


# Insert import in reverse order, display in forward order.
for idx in range(len(replace_imports) - 1, -1, -1):
replace_import = replace_imports[idx]
if replace_import[1]:
self._add_import(name_to_import=replace_import[0], as_name=replace_import[1])
def _convert_external_reference(self):
"""Convert import statements."""
all_name_mappings = APIAnalysisSpec.import_name_mapping

# Step1. Replace external reference first.
names_replaced_with = self._replace_external_reference()
new_import_node = dict()
insert_pos = 0
# Step2. Find out remaining mapping name which not found in script.
for src_name, new_import_name in all_name_mappings.items():
if src_name not in names_replaced_with:
new_node = self._make_import(name_to_import=new_import_name[0], as_name=new_import_name[1])
new_import_node.update({insert_pos: new_node})
insert_pos += 1
else: else:
self._add_import(name_to_import=replace_import[0])
try:
replaced_with_node = names_replaced_with[src_name]
insert_pos = self._tree.body.index(replaced_with_node) + 1
except ValueError:
pass

# Step3. Insert import reference in order.
insert_cnt = 0
for insert_pos, new_node in new_import_node.items():
# Insert the node into the module
self._tree.body.insert(insert_pos + insert_cnt, new_node)
insert_cnt += 1


def _add_import(self, name_to_import, as_name=None):
@staticmethod
def _make_import(name_to_import, as_name=None):
""" """
Adds an import to the ast tree.
Create an import to the ast tree.


Args: Args:
name_to_import: (string) The absolute name to import. name_to_import: (string) The absolute name to import.
as_name: (string) The alias for the import ("import name_to_import as asname") as_name: (string) The alias for the import ("import name_to_import as asname")

Returns:
ast.Import, a new ast.Import node.
""" """
new_alias = ast.alias(name=name_to_import, asname=as_name) new_alias = ast.alias(name=name_to_import, asname=as_name)
import_node = ast.Import(names=[new_alias]) import_node = ast.Import(names=[new_alias])

# Insert the node at the top of the module
self._tree.body.insert(1 if pasta.base.ast_utils.has_docstring(self._tree) else 0, import_node)
return import_node


def _convert_function(self, func_scope, is_forward): def _convert_function(self, func_scope, is_forward):
""" """


Loading…
Cancel
Save