|
|
@@ -27,11 +27,13 @@ from mindinsight.mindconverter.config import ALL_MAPPING |
|
|
from mindinsight.mindconverter.config import NN_LIST |
|
|
from mindinsight.mindconverter.config import NN_LIST |
|
|
from mindinsight.mindconverter.config import ALL_TORCH_APIS |
|
|
from mindinsight.mindconverter.config import ALL_TORCH_APIS |
|
|
from mindinsight.mindconverter.config import ALL_2P_LIST |
|
|
from mindinsight.mindconverter.config import ALL_2P_LIST |
|
|
|
|
|
from mindinsight.mindconverter.config import get_corresponding_ms_name |
|
|
from mindinsight.mindconverter.config import get_prompt_info |
|
|
from mindinsight.mindconverter.config import get_prompt_info |
|
|
from mindinsight.mindconverter.common.log import logger |
|
|
from mindinsight.mindconverter.common.log import logger |
|
|
from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport |
|
|
from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport |
|
|
from mindinsight.mindconverter.forward_call import ForwardCall |
|
|
from mindinsight.mindconverter.forward_call import ForwardCall |
|
|
|
|
|
|
|
|
|
|
|
LOG_FMT_INSERT = "[Insert] '%s' is inserted to the converted file." |
|
|
LOG_FMT_CONVERT = "[Convert] '%s' is converted to '%s'." |
|
|
LOG_FMT_CONVERT = "[Convert] '%s' is converted to '%s'." |
|
|
LOG_FMT_CONVERT_WITH_TIPS = "[Convert] '%s' is converted to '%s'. %s" |
|
|
LOG_FMT_CONVERT_WITH_TIPS = "[Convert] '%s' is converted to '%s'. %s" |
|
|
LOG_FMT_NOT_CONVERT = "[UnConvert] '%s' didn't convert. %s" |
|
|
LOG_FMT_NOT_CONVERT = "[UnConvert] '%s' didn't convert. %s" |
|
|
@@ -54,16 +56,22 @@ class _ConvertReport: |
|
|
def __init__(self, is_stub=False): |
|
|
def __init__(self, is_stub=False): |
|
|
self._is_stub = is_stub |
|
|
self._is_stub = is_stub |
|
|
self._max_line = 0 |
|
|
self._max_line = 0 |
|
|
self._log = [] # report log, type is (severity, line, col, msg) |
|
|
|
|
|
|
|
|
self._log_head = [] |
|
|
|
|
|
self._log_body = [] # report log, type is (severity, line, col, msg) |
|
|
|
|
|
|
|
|
def _add_log(self, severity, line, col, msg): |
|
|
def _add_log(self, severity, line, col, msg): |
|
|
"""Add log.""" |
|
|
"""Add log.""" |
|
|
if self._is_stub: |
|
|
if self._is_stub: |
|
|
return |
|
|
return |
|
|
|
|
|
if line is None and col is None: |
|
|
|
|
|
self._log_head.append(msg) |
|
|
|
|
|
return |
|
|
if isinstance(line, int) and isinstance(col, int): |
|
|
if isinstance(line, int) and isinstance(col, int): |
|
|
self._log.append((severity, line, col, msg)) |
|
|
|
|
|
|
|
|
self._log_body.append((severity, line, col, msg)) |
|
|
if self._max_line < line: |
|
|
if self._max_line < line: |
|
|
self._max_line = line |
|
|
self._max_line = line |
|
|
|
|
|
else: |
|
|
|
|
|
raise TypeError('The parameter type is incorrect.') |
|
|
|
|
|
|
|
|
def info(self, line, col, msg): |
|
|
def info(self, line, col, msg): |
|
|
"""Interface to add infer log""" |
|
|
"""Interface to add infer log""" |
|
|
@@ -73,14 +81,24 @@ class _ConvertReport: |
|
|
"""Interface to add warning log""" |
|
|
"""Interface to add warning log""" |
|
|
self._add_log(logging.WARNING, line, col, msg) |
|
|
self._add_log(logging.WARNING, line, col, msg) |
|
|
|
|
|
|
|
|
|
|
|
def header_msg(self, msg): |
|
|
|
|
|
"""Interface to add header message log""" |
|
|
|
|
|
self._add_log(logging.INFO, None, None, msg) |
|
|
|
|
|
|
|
|
def get_logs(self): |
|
|
def get_logs(self): |
|
|
"""Get convert logs""" |
|
|
"""Get convert logs""" |
|
|
logs = [] |
|
|
logs = [] |
|
|
|
|
|
logs.extend(self._log_head) |
|
|
# sort rule: line * self._max_line + col |
|
|
# sort rule: line * self._max_line + col |
|
|
self._log.sort(key=lambda log: log[1] * self._max_line + log[2]) |
|
|
|
|
|
for log_info in self._log: |
|
|
|
|
|
|
|
|
self._log_body.sort(key=lambda log: log[1] * self._max_line + log[2]) |
|
|
|
|
|
for log_info in self._log_body: |
|
|
log_info = "line %d:%d: %s" % (log_info[1], log_info[2], log_info[3]) |
|
|
log_info = "line %d:%d: %s" % (log_info[1], log_info[2], log_info[3]) |
|
|
logs.append(log_info) |
|
|
|
|
|
|
|
|
if logs: |
|
|
|
|
|
# Deduplication for logs |
|
|
|
|
|
if logs[-1] != log_info: |
|
|
|
|
|
logs.append(log_info) |
|
|
|
|
|
else: |
|
|
|
|
|
logs.append(log_info) |
|
|
return logs |
|
|
return logs |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -262,7 +280,8 @@ class AstEditVisitor(ast.NodeVisitor): |
|
|
new_func_name = 'construct' |
|
|
new_func_name = 'construct' |
|
|
if func_ast_node.name == old_func_name: |
|
|
if func_ast_node.name == old_func_name: |
|
|
func_ast_node.name = new_func_name |
|
|
func_ast_node.name = new_func_name |
|
|
self._process_log.info(func_ast_node.lineno, func_ast_node.col_offset, |
|
|
|
|
|
|
|
|
real_line_number = self._get_real_line_number(func_ast_node) |
|
|
|
|
|
self._process_log.info(real_line_number, func_ast_node.col_offset, |
|
|
LOG_FMT_CONVERT % (old_func_name, new_func_name)) |
|
|
LOG_FMT_CONVERT % (old_func_name, new_func_name)) |
|
|
|
|
|
|
|
|
def _convert_api(self): |
|
|
def _convert_api(self): |
|
|
@@ -299,6 +318,15 @@ class AstEditVisitor(ast.NodeVisitor): |
|
|
source_code = pasta.dump(node) |
|
|
source_code = pasta.dump(node) |
|
|
return source_code[pos:] |
|
|
return source_code[pos:] |
|
|
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
|
def _get_real_line_number(node): |
|
|
|
|
|
"""Get the real line number of the node.""" |
|
|
|
|
|
try: |
|
|
|
|
|
line_number = node.lineno + len(node.decorator_list) |
|
|
|
|
|
except AttributeError: |
|
|
|
|
|
line_number = node.lineno |
|
|
|
|
|
return line_number |
|
|
|
|
|
|
|
|
def _replace_external_reference(self): |
|
|
def _replace_external_reference(self): |
|
|
""" |
|
|
""" |
|
|
Replace external reference statements. |
|
|
Replace external reference statements. |
|
|
@@ -349,6 +377,7 @@ class AstEditVisitor(ast.NodeVisitor): |
|
|
insert_pos += 1 |
|
|
insert_pos += 1 |
|
|
else: |
|
|
else: |
|
|
try: |
|
|
try: |
|
|
|
|
|
# insert pos after the last one, if last one name is replaced. |
|
|
replaced_with_node = names_replaced_with[src_name] |
|
|
replaced_with_node = names_replaced_with[src_name] |
|
|
insert_pos = self._tree.body.index(replaced_with_node) + 1 |
|
|
insert_pos = self._tree.body.index(replaced_with_node) + 1 |
|
|
except ValueError: |
|
|
except ValueError: |
|
|
@@ -359,6 +388,8 @@ class AstEditVisitor(ast.NodeVisitor): |
|
|
for insert_pos, new_node in new_import_node.items(): |
|
|
for insert_pos, new_node in new_import_node.items(): |
|
|
# Insert the node into the module |
|
|
# Insert the node into the module |
|
|
self._tree.body.insert(insert_pos + insert_cnt, new_node) |
|
|
self._tree.body.insert(insert_pos + insert_cnt, new_node) |
|
|
|
|
|
new_code = self._dump_without_prefix(new_node) |
|
|
|
|
|
self._process_log.header_msg(LOG_FMT_INSERT % new_code.strip()) |
|
|
insert_cnt += 1 |
|
|
insert_cnt += 1 |
|
|
|
|
|
|
|
|
@staticmethod |
|
|
@staticmethod |
|
|
@@ -445,8 +476,10 @@ class AstEditVisitor(ast.NodeVisitor): |
|
|
|
|
|
|
|
|
is_include_sub_call = self._is_include_sub_call(call_func_node) |
|
|
is_include_sub_call = self._is_include_sub_call(call_func_node) |
|
|
if is_include_sub_call: |
|
|
if is_include_sub_call: |
|
|
|
|
|
# x.y().z splits to ['x.y()', 'z'] |
|
|
name_attributes = call_name.rsplit('.', 1) |
|
|
name_attributes = call_name.rsplit('.', 1) |
|
|
else: |
|
|
else: |
|
|
|
|
|
# x.y.z splits to ['x', 'y', 'z'] |
|
|
name_attributes = call_name.split('.') |
|
|
name_attributes = call_name.split('.') |
|
|
|
|
|
|
|
|
# rewritten external module name |
|
|
# rewritten external module name |
|
|
@@ -665,7 +698,7 @@ class AstEditVisitor(ast.NodeVisitor): |
|
|
try: |
|
|
try: |
|
|
new_node = pasta.parse(new_code).body[0].value |
|
|
new_node = pasta.parse(new_code).body[0].value |
|
|
# find the first call name |
|
|
# find the first call name |
|
|
new_api_name = new_code[:new_code.find('(')] |
|
|
|
|
|
|
|
|
new_api_name = get_corresponding_ms_name(matched_api_name) |
|
|
except AttributeError: |
|
|
except AttributeError: |
|
|
new_node = pasta.parse(new_code).body[0] |
|
|
new_node = pasta.parse(new_code).body[0] |
|
|
new_api_name = new_code |
|
|
new_api_name = new_code |
|
|
|