|
|
|
@@ -24,7 +24,7 @@ from pasta.base import formatting as fmt |
|
|
|
|
|
|
|
from mindinsight.mindconverter.code_analysis import CodeAnalyzer |
|
|
|
from mindinsight.mindconverter.code_analysis import APIAnalysisSpec |
|
|
|
from mindinsight.mindconverter.config import ALL_MAPPING |
|
|
|
from mindinsight.mindconverter.config import ALL_MAPPING, F_LIST |
|
|
|
from mindinsight.mindconverter.config import NN_LIST |
|
|
|
from mindinsight.mindconverter.config import ALL_TORCH_APIS |
|
|
|
from mindinsight.mindconverter.config import ALL_2P_LIST |
|
|
|
@@ -516,6 +516,7 @@ class AstEditVisitor(ast.NodeVisitor): |
|
|
|
if self._check_tensor_object(call_func_node): |
|
|
|
api_name = '.' + name_attributes[-1] |
|
|
|
match_case = ApiMatchingEnum.API_INFER |
|
|
|
|
|
|
|
return api_name, match_case |
|
|
|
|
|
|
|
def _check_tensor_object(self, node): |
|
|
|
@@ -531,13 +532,28 @@ class AstEditVisitor(ast.NodeVisitor): |
|
|
|
if func_name not in TENSOR_DOT_LIST: |
|
|
|
return False |
|
|
|
|
|
|
|
extracted_api = [] |
|
|
|
for api in name_attributes[1:len(name_attributes) - 1]: |
|
|
|
if "(" or ")" in api: |
|
|
|
start = api.find("(") |
|
|
|
start = start if start != -1 else len(api) |
|
|
|
end = api.find(")") |
|
|
|
end = end if end != -1 else len(api) |
|
|
|
if start < end: |
|
|
|
api = f"{api[:start]}{api[end + 1:]}" |
|
|
|
extracted_api.append(api) |
|
|
|
|
|
|
|
is_tensor_object = True |
|
|
|
if self._code_analyzer: |
|
|
|
# Check whether the object is external reference. |
|
|
|
real_ref = None |
|
|
|
for ref_name in self._code_analyzer.external_references: |
|
|
|
if node_ref_name == ref_name: |
|
|
|
is_tensor_object = False |
|
|
|
real_ref = self._code_analyzer.external_references[ref_name]["external_ref_info"] |
|
|
|
break |
|
|
|
if real_ref and f"{real_ref.name}.{'.'.join(extracted_api)}" not in F_LIST: |
|
|
|
is_tensor_object = False |
|
|
|
|
|
|
|
return is_tensor_object |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|