|
|
@@ -20,7 +20,6 @@ import re |
|
|
from enum import Enum |
|
|
from enum import Enum |
|
|
|
|
|
|
|
|
import pasta |
|
|
import pasta |
|
|
from pasta.augment import import_utils |
|
|
|
|
|
|
|
|
|
|
|
from mindinsight.mindconverter.code_analysis import CodeAnalyzer |
|
|
from mindinsight.mindconverter.code_analysis import CodeAnalyzer |
|
|
from mindinsight.mindconverter.code_analysis import APIAnalysisSpec |
|
|
from mindinsight.mindconverter.code_analysis import APIAnalysisSpec |
|
|
@@ -290,58 +289,93 @@ class AstEditVisitor(ast.NodeVisitor): |
|
|
for convert_fun, args in tasks: |
|
|
for convert_fun, args in tasks: |
|
|
convert_fun(*args) |
|
|
convert_fun(*args) |
|
|
|
|
|
|
|
|
def _convert_external_reference(self): |
|
|
|
|
|
"""Convert import statements.""" |
|
|
|
|
|
name_replace = APIAnalysisSpec.import_name_mapping |
|
|
|
|
|
replace_imports = list(name_replace.values()) |
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
|
def _dump_without_prefix(node): |
|
|
|
|
|
"""Get the python source for an AST.""" |
|
|
|
|
|
pos = 0 |
|
|
|
|
|
source_prefix = pasta.base.formatting.get(node, 'prefix') |
|
|
|
|
|
if source_prefix: |
|
|
|
|
|
pos = len(source_prefix) |
|
|
|
|
|
source_code = pasta.dump(node) |
|
|
|
|
|
return source_code[pos] |
|
|
|
|
|
|
|
|
|
|
|
def _replace_external_reference(self): |
|
|
|
|
|
""" |
|
|
|
|
|
Replace external reference statements. |
|
|
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
|
dict, key is external name, value is the new replaced node. |
|
|
|
|
|
""" |
|
|
|
|
|
all_name_mappings = APIAnalysisSpec.import_name_mapping |
|
|
|
|
|
names_replaced_with = dict() |
|
|
for ref_info in self._code_analyzer.external_references.values(): |
|
|
for ref_info in self._code_analyzer.external_references.values(): |
|
|
external_ref_info = ref_info['external_ref_info'] |
|
|
external_ref_info = ref_info['external_ref_info'] |
|
|
parent_node = ref_info['parent_node'] |
|
|
|
|
|
if parent_node is None: |
|
|
|
|
|
|
|
|
import_node = ref_info['parent_node'] |
|
|
|
|
|
if import_node is None: |
|
|
continue |
|
|
continue |
|
|
code = pasta.dump(parent_node) |
|
|
|
|
|
|
|
|
code = self._dump_without_prefix(import_node) |
|
|
|
|
|
import_parent_node = self._code_analyzer.root_scope.parent(import_node) |
|
|
|
|
|
# replace import with new name |
|
|
if external_ref_info.name in APIAnalysisSpec.get_convertible_external_names(): |
|
|
if external_ref_info.name in APIAnalysisSpec.get_convertible_external_names(): |
|
|
external_ref_info = ref_info['external_ref_info'] |
|
|
external_ref_info = ref_info['external_ref_info'] |
|
|
if external_ref_info.name in name_replace.keys(): |
|
|
|
|
|
import_utils.remove_import_alias_node(self._code_analyzer.root_scope, external_ref_info.node) |
|
|
|
|
|
replace_info = name_replace[external_ref_info.name] |
|
|
|
|
|
new_ref_name = replace_info[1] |
|
|
|
|
|
new_external_name = replace_info[0] |
|
|
|
|
|
if new_ref_name: |
|
|
|
|
|
new_code = f'import {new_external_name} as {new_ref_name}' |
|
|
|
|
|
else: |
|
|
|
|
|
new_code = f'import {new_external_name}' |
|
|
|
|
|
|
|
|
|
|
|
self._process_log.info(parent_node.lineno, parent_node.col_offset, LOG_FMT_CONVERT % |
|
|
|
|
|
|
|
|
if external_ref_info.name in all_name_mappings.keys(): |
|
|
|
|
|
replace_info = all_name_mappings[external_ref_info.name] |
|
|
|
|
|
new_node = self._make_import(name_to_import=replace_info[0], as_name=replace_info[1]) |
|
|
|
|
|
new_code = pasta.dump(new_node) |
|
|
|
|
|
pasta.ast_utils.replace_child(import_parent_node, import_node, new_node) |
|
|
|
|
|
names_replaced_with.update({external_ref_info.name: new_node}) |
|
|
|
|
|
self._process_log.info(import_node.lineno, import_node.col_offset, LOG_FMT_CONVERT % |
|
|
(code.strip(), new_code.strip())) |
|
|
(code.strip(), new_code.strip())) |
|
|
elif external_ref_info.name.startswith('torch.'): |
|
|
elif external_ref_info.name.startswith('torch.'): |
|
|
self._process_log.warning(parent_node.lineno, parent_node.col_offset, LOG_FMT_NOT_CONVERT % |
|
|
|
|
|
|
|
|
self._process_log.warning(import_node.lineno, import_node.col_offset, LOG_FMT_NOT_CONVERT % |
|
|
(code.strip(), LOG_SUGGESTION_MANUAL_CONVERT)) |
|
|
(code.strip(), LOG_SUGGESTION_MANUAL_CONVERT)) |
|
|
else: |
|
|
else: |
|
|
pass |
|
|
pass |
|
|
|
|
|
return names_replaced_with |
|
|
|
|
|
|
|
|
# Insert import in reverse order, display in forward order. |
|
|
|
|
|
for idx in range(len(replace_imports) - 1, -1, -1): |
|
|
|
|
|
replace_import = replace_imports[idx] |
|
|
|
|
|
if replace_import[1]: |
|
|
|
|
|
self._add_import(name_to_import=replace_import[0], as_name=replace_import[1]) |
|
|
|
|
|
|
|
|
def _convert_external_reference(self): |
|
|
|
|
|
"""Convert import statements.""" |
|
|
|
|
|
all_name_mappings = APIAnalysisSpec.import_name_mapping |
|
|
|
|
|
|
|
|
|
|
|
# Step1. Replace external reference first. |
|
|
|
|
|
names_replaced_with = self._replace_external_reference() |
|
|
|
|
|
new_import_node = dict() |
|
|
|
|
|
insert_pos = 0 |
|
|
|
|
|
# Step2. Find out remaining mapping name which not found in script. |
|
|
|
|
|
for src_name, new_import_name in all_name_mappings.items(): |
|
|
|
|
|
if src_name not in names_replaced_with: |
|
|
|
|
|
new_node = self._make_import(name_to_import=new_import_name[0], as_name=new_import_name[1]) |
|
|
|
|
|
new_import_node.update({insert_pos: new_node}) |
|
|
|
|
|
insert_pos += 1 |
|
|
else: |
|
|
else: |
|
|
self._add_import(name_to_import=replace_import[0]) |
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
replaced_with_node = names_replaced_with[src_name] |
|
|
|
|
|
insert_pos = self._tree.body.index(replaced_with_node) + 1 |
|
|
|
|
|
except ValueError: |
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
# Step3. Insert import reference in order. |
|
|
|
|
|
insert_cnt = 0 |
|
|
|
|
|
for insert_pos, new_node in new_import_node.items(): |
|
|
|
|
|
# Insert the node into the module |
|
|
|
|
|
self._tree.body.insert(insert_pos + insert_cnt, new_node) |
|
|
|
|
|
insert_cnt += 1 |
|
|
|
|
|
|
|
|
def _add_import(self, name_to_import, as_name=None): |
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
|
def _make_import(name_to_import, as_name=None): |
|
|
""" |
|
|
""" |
|
|
Adds an import to the ast tree. |
|
|
|
|
|
|
|
|
Create an import to the ast tree. |
|
|
|
|
|
|
|
|
Args: |
|
|
Args: |
|
|
name_to_import: (string) The absolute name to import. |
|
|
name_to_import: (string) The absolute name to import. |
|
|
as_name: (string) The alias for the import ("import name_to_import as asname") |
|
|
as_name: (string) The alias for the import ("import name_to_import as asname") |
|
|
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
|
ast.Import, a new ast.Import node. |
|
|
""" |
|
|
""" |
|
|
new_alias = ast.alias(name=name_to_import, asname=as_name) |
|
|
new_alias = ast.alias(name=name_to_import, asname=as_name) |
|
|
import_node = ast.Import(names=[new_alias]) |
|
|
import_node = ast.Import(names=[new_alias]) |
|
|
|
|
|
|
|
|
# Insert the node at the top of the module |
|
|
|
|
|
self._tree.body.insert(1 if pasta.base.ast_utils.has_docstring(self._tree) else 0, import_node) |
|
|
|
|
|
|
|
|
return import_node |
|
|
|
|
|
|
|
|
def _convert_function(self, func_scope, is_forward): |
|
|
def _convert_function(self, func_scope, is_forward): |
|
|
""" |
|
|
""" |
|
|
|