Browse Source

Fix AST bugs.

tags/v1.0.0
liuchongming 5 years ago
parent
commit
4a0f5e247a
1 changed files with 18 additions and 2 deletions
  1. +18
    -2
      mindinsight/mindconverter/ast_edits.py

+ 18
- 2
mindinsight/mindconverter/ast_edits.py View File

@@ -24,7 +24,7 @@ from pasta.base import formatting as fmt

from mindinsight.mindconverter.code_analysis import CodeAnalyzer
from mindinsight.mindconverter.code_analysis import APIAnalysisSpec
from mindinsight.mindconverter.config import ALL_MAPPING
from mindinsight.mindconverter.config import ALL_MAPPING, F_LIST
from mindinsight.mindconverter.config import NN_LIST
from mindinsight.mindconverter.config import ALL_TORCH_APIS
from mindinsight.mindconverter.config import ALL_2P_LIST
@@ -516,6 +516,7 @@ class AstEditVisitor(ast.NodeVisitor):
if self._check_tensor_object(call_func_node):
api_name = '.' + name_attributes[-1]
match_case = ApiMatchingEnum.API_INFER

return api_name, match_case

def _check_tensor_object(self, node):
@@ -531,13 +532,28 @@ class AstEditVisitor(ast.NodeVisitor):
if func_name not in TENSOR_DOT_LIST:
return False

extracted_api = []
for api in name_attributes[1:len(name_attributes) - 1]:
if "(" or ")" in api:
start = api.find("(")
start = start if start != -1 else len(api)
end = api.find(")")
end = end if end != -1 else len(api)
if start < end:
api = f"{api[:start]}{api[end + 1:]}"
extracted_api.append(api)

is_tensor_object = True
if self._code_analyzer:
# Check whether the object is external reference.
real_ref = None
for ref_name in self._code_analyzer.external_references:
if node_ref_name == ref_name:
is_tensor_object = False
real_ref = self._code_analyzer.external_references[ref_name]["external_ref_info"]
break
if real_ref and f"{real_ref.name}.{'.'.join(extracted_api)}" not in F_LIST:
is_tensor_object = False

return is_tensor_object

@staticmethod


Loading…
Cancel
Save