|
|
|
@@ -152,14 +152,14 @@ class _LineColEditVisitor(ast.NodeVisitor): |
|
|
|
if self._check_arg2update(arg): |
|
|
|
args.append(arg) |
|
|
|
for arg in args: |
|
|
|
arg.lineno = dst_call.lineno |
|
|
|
# line number starts from 1, column number starts from 0. |
|
|
|
arg.lineno += dst_call.lineno - 1 |
|
|
|
arg.col_offset += dst_call.col_offset |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def _check_arg2update(arg): |
|
|
|
# Only the col_offset of the first line code is re-counted, needs to be corrected. |
|
|
|
# When the arg is a function call, its col_offset is handled separately. |
|
|
|
if not isinstance(arg, ast.Call) and arg.lineno == 1: |
|
|
|
if not isinstance(arg, ast.Call): |
|
|
|
return True |
|
|
|
return False |
|
|
|
|
|
|
|
@@ -269,14 +269,24 @@ class AstEditVisitor(ast.NodeVisitor): |
|
|
|
def _convert_api(self): |
|
|
|
"""Convert PyTorch api call to MindSpore api call in a function.""" |
|
|
|
tasks = [] |
|
|
|
|
|
|
|
found_func_nodes = [] |
|
|
|
convert_elements = self._code_analyzer.network_definitions() |
|
|
|
for func_node_scope in convert_elements.get("functions", []): |
|
|
|
found_func_nodes.append(func_node_scope.node) |
|
|
|
is_forward = self._judge_forward(func_node_scope) |
|
|
|
tasks.append((self._convert_function, (func_node_scope, is_forward))) |
|
|
|
for class_scope in convert_elements.get("cell", []).keys(): |
|
|
|
for class_scope, func_scopes in convert_elements.get("cell", []).items(): |
|
|
|
for func_node_scope in func_scopes: |
|
|
|
found_func_nodes.append(func_node_scope.node) |
|
|
|
tasks.append((self._convert_cell, (class_scope,))) |
|
|
|
|
|
|
|
# Some functions in the forward call chain are not found by self._code_analyzer. |
|
|
|
for func_node in self._forward_list.values(): |
|
|
|
is_forward = True |
|
|
|
if func_node and func_node not in found_func_nodes: |
|
|
|
func_node_scope = self._code_analyzer.lookup_scope(func_node) |
|
|
|
tasks.append((self._convert_function, (func_node_scope, is_forward))) |
|
|
|
|
|
|
|
for convert_fun, args in tasks: |
|
|
|
convert_fun(*args) |
|
|
|
|
|
|
|
|