Browse Source

Fix bug that Some functions is not converted.

Some functions in the forward call chain are not found by code analyzer, supplement from the forward call chain.
tags/v0.5.0-beta
ggpolar 5 years ago
parent
commit
833cae4053
1 changed files with 15 additions and 5 deletions
  1. +15
    -5
      mindinsight/mindconverter/ast_edits.py

+ 15
- 5
mindinsight/mindconverter/ast_edits.py View File

@@ -152,14 +152,14 @@ class _LineColEditVisitor(ast.NodeVisitor):
if self._check_arg2update(arg):
args.append(arg)
for arg in args:
arg.lineno = dst_call.lineno
# line number starts from 1, column number starts from 0.
arg.lineno += dst_call.lineno - 1
arg.col_offset += dst_call.col_offset

@staticmethod
def _check_arg2update(arg):
# Only the col_offset of the first line code is re-counted, needs to be corrected.
# When the arg is a function call, its col_offset is handled separately.
if not isinstance(arg, ast.Call) and arg.lineno == 1:
if not isinstance(arg, ast.Call):
return True
return False

@@ -269,14 +269,24 @@ class AstEditVisitor(ast.NodeVisitor):
def _convert_api(self):
"""Convert PyTorch api call to MindSpore api call in a function."""
tasks = []
found_func_nodes = []
convert_elements = self._code_analyzer.network_definitions()
for func_node_scope in convert_elements.get("functions", []):
found_func_nodes.append(func_node_scope.node)
is_forward = self._judge_forward(func_node_scope)
tasks.append((self._convert_function, (func_node_scope, is_forward)))
for class_scope in convert_elements.get("cell", []).keys():
for class_scope, func_scopes in convert_elements.get("cell", []).items():
for func_node_scope in func_scopes:
found_func_nodes.append(func_node_scope.node)
tasks.append((self._convert_cell, (class_scope,)))

# Some functions in the forward call chain are not found by self._code_analyzer.
for func_node in self._forward_list.values():
is_forward = True
if func_node and func_node not in found_func_nodes:
func_node_scope = self._code_analyzer.lookup_scope(func_node)
tasks.append((self._convert_function, (func_node_scope, is_forward)))

for convert_fun, args in tasks:
convert_fun(*args)



Loading…
Cancel
Save