diff --git a/mindinsight/mindconverter/ast_edits.py b/mindinsight/mindconverter/ast_edits.py index af4b0df1..131fc647 100644 --- a/mindinsight/mindconverter/ast_edits.py +++ b/mindinsight/mindconverter/ast_edits.py @@ -24,7 +24,7 @@ from pasta.base import formatting as fmt from mindinsight.mindconverter.code_analysis import CodeAnalyzer from mindinsight.mindconverter.code_analysis import APIAnalysisSpec -from mindinsight.mindconverter.config import ALL_MAPPING +from mindinsight.mindconverter.config import ALL_MAPPING, F_LIST from mindinsight.mindconverter.config import NN_LIST from mindinsight.mindconverter.config import ALL_TORCH_APIS from mindinsight.mindconverter.config import ALL_2P_LIST @@ -516,6 +516,7 @@ class AstEditVisitor(ast.NodeVisitor): if self._check_tensor_object(call_func_node): api_name = '.' + name_attributes[-1] match_case = ApiMatchingEnum.API_INFER + return api_name, match_case def _check_tensor_object(self, node): @@ -531,13 +532,28 @@ class AstEditVisitor(ast.NodeVisitor): if func_name not in TENSOR_DOT_LIST: return False + extracted_api = [] + for api in name_attributes[1:len(name_attributes) - 1]: + if "(" or ")" in api: + start = api.find("(") + start = start if start != -1 else len(api) + end = api.find(")") + end = end if end != -1 else len(api) + if start < end: + api = f"{api[:start]}{api[end + 1:]}" + extracted_api.append(api) + is_tensor_object = True if self._code_analyzer: # Check whether the object is external reference. + real_ref = None for ref_name in self._code_analyzer.external_references: if node_ref_name == ref_name: - is_tensor_object = False + real_ref = self._code_analyzer.external_references[ref_name]["external_ref_info"] break + if real_ref and f"{real_ref.name}.{'.'.join(extracted_api)}" not in F_LIST: + is_tensor_object = False + return is_tensor_object @staticmethod