# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Find out forward functions of script file"""
import ast

import pasta


class ForwardCall(ast.NodeVisitor):
    """
    AST visitor that processes forward calls.

    Find the sub functions called by the forward function in the script file.
    The tree is walked when the visitor is constructed; afterwards ``calls``
    maps each recorded function name to its ``ast.FunctionDef`` node, or to
    ``None`` when no definition with that name was found in the tree.

    Args:
        ast_tree (ast.AST): Root node of the script to analyse.
    """

    # Calls to these names are never recorded as part of the forward chain.
    _IGNORED_CALLS = ('super', 'str', 'repr')

    def __init__(self, ast_tree):
        self._tree = ast_tree
        self._name_stack = []  # names of the classes enclosing the current node.
        self._forward_stack = []  # chain functions currently being visited.
        self.calls = {}  # key is function name, value is forward function ast node.
        self._function_list = {}  # key is function name, value is function ast node.
        self.process()

    def process(self):
        """Visit the ast tree twice to find the forward functions."""
        # The first visit only serves to fill _function_list with every
        # function definition, so that calls made before the callee is
        # defined can still be resolved on the second visit.
        self.visit(self._tree)
        # Restore all variables except _function_list before the real visit.
        self._name_stack.clear()
        self._forward_stack.clear()
        self.calls.clear()
        self.visit(self._tree)

    def get_current_namespace(self):
        """
        Get the namespace when visit the AST node.

        Returns:
            str, dot-joined names of the enclosing classes, empty string at module level.
        """
        return '.'.join(self._name_stack)

    @classmethod
    def get_call_name(cls, node):
        """
        Get functional call name.

        Args:
            node (ast.AST): Node to inspect.

        Returns:
            Union[str, None], source text of the called expression, None if node is not a call.
        """
        if not isinstance(node, ast.Call):
            return None
        return pasta.dump(node.func)

    def visit_ClassDef(self, node):
        """Callback function when visit AST tree: track the enclosing class name."""
        self._name_stack.append(node.name)
        self.generic_visit(node)
        self._name_stack.pop()

    def visit_FunctionDef(self, node):
        """Callback function when visit AST tree: register the function and follow the chain."""
        namespace = self.get_current_namespace()
        # Module-level functions have an empty namespace and must be keyed by
        # their bare name, because that is the name visit_Call looks up.
        func_name = f'{namespace}.{node.name}' if namespace else node.name

        is_forward = node.name == 'forward'
        is_in_chain = is_forward or func_name in self.calls
        if is_in_chain:
            self._forward_stack.append(func_name)

        if is_forward:
            self.calls[func_name] = node

        self._function_list[func_name] = node
        self.generic_visit(node)

        if is_in_chain:
            self._forward_stack.pop()

    def visit_Call(self, node):
        """Callback function when visit AST tree: record calls made inside the forward chain."""
        for arg in node.args:
            self.visit(arg)
        for keyword in node.keywords:
            self.visit(keyword.value)

        func_name = self.get_call_name(node)
        in_chain = bool(self._forward_stack)

        if isinstance(node.func, ast.Name):
            if func_name not in self._IGNORED_CALLS:
                if in_chain:
                    self.calls[func_name] = self._function_list.get(func_name)
                self.visit(node.func)
            return

        if in_chain:
            if func_name.startswith('self.'):
                # Resolve a call made through self against the current class namespace.
                method_name = func_name.split('.')[-1]
                whole_name = f'{self.get_current_namespace()}.{method_name}'
                self.calls[whole_name] = self._function_list.get(whole_name)
            else:
                self.calls[func_name] = self._function_list.get(func_name)
        # Descend into the callee expression so nested calls are visited too.
        self.visit(node.func)