| @@ -24,7 +24,7 @@ from pasta.base import formatting as fmt | |||||
| from mindinsight.mindconverter.code_analysis import CodeAnalyzer | from mindinsight.mindconverter.code_analysis import CodeAnalyzer | ||||
| from mindinsight.mindconverter.code_analysis import APIAnalysisSpec | from mindinsight.mindconverter.code_analysis import APIAnalysisSpec | ||||
| from mindinsight.mindconverter.config import ALL_MAPPING | |||||
| from mindinsight.mindconverter.config import ALL_MAPPING, F_LIST | |||||
| from mindinsight.mindconverter.config import NN_LIST | from mindinsight.mindconverter.config import NN_LIST | ||||
| from mindinsight.mindconverter.config import ALL_TORCH_APIS | from mindinsight.mindconverter.config import ALL_TORCH_APIS | ||||
| from mindinsight.mindconverter.config import ALL_2P_LIST | from mindinsight.mindconverter.config import ALL_2P_LIST | ||||
| @@ -516,6 +516,7 @@ class AstEditVisitor(ast.NodeVisitor): | |||||
| if self._check_tensor_object(call_func_node): | if self._check_tensor_object(call_func_node): | ||||
| api_name = '.' + name_attributes[-1] | api_name = '.' + name_attributes[-1] | ||||
| match_case = ApiMatchingEnum.API_INFER | match_case = ApiMatchingEnum.API_INFER | ||||
| return api_name, match_case | return api_name, match_case | ||||
| def _check_tensor_object(self, node): | def _check_tensor_object(self, node): | ||||
| @@ -531,13 +532,28 @@ class AstEditVisitor(ast.NodeVisitor): | |||||
| if func_name not in TENSOR_DOT_LIST: | if func_name not in TENSOR_DOT_LIST: | ||||
| return False | return False | ||||
| extracted_api = [] | |||||
| for api in name_attributes[1:len(name_attributes) - 1]: | |||||
| if "(" or ")" in api: | |||||
| start = api.find("(") | |||||
| start = start if start != -1 else len(api) | |||||
| end = api.find(")") | |||||
| end = end if end != -1 else len(api) | |||||
| if start < end: | |||||
| api = f"{api[:start]}{api[end + 1:]}" | |||||
| extracted_api.append(api) | |||||
| is_tensor_object = True | is_tensor_object = True | ||||
| if self._code_analyzer: | if self._code_analyzer: | ||||
| # Check whether the object is external reference. | # Check whether the object is external reference. | ||||
| real_ref = None | |||||
| for ref_name in self._code_analyzer.external_references: | for ref_name in self._code_analyzer.external_references: | ||||
| if node_ref_name == ref_name: | if node_ref_name == ref_name: | ||||
| is_tensor_object = False | |||||
| real_ref = self._code_analyzer.external_references[ref_name]["external_ref_info"] | |||||
| break | break | ||||
| if real_ref and f"{real_ref.name}.{'.'.join(extracted_api)}" not in F_LIST: | |||||
| is_tensor_object = False | |||||
| return is_tensor_object | return is_tensor_object | ||||
| @staticmethod | @staticmethod | ||||