|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Find out forward functions of script file"""
- import ast
- import os
-
-
- class ForwardCall(ast.NodeVisitor):
- """
- AST visitor that processes forward calls.
-
- Find the sub functions called by the forward function in the script file.
- """
-
- def __init__(self, filename):
- self.filename = filename
- self.module_name = os.path.basename(filename).replace('.py', '')
- self.name_stack = []
- self.forward_stack = []
- self.calls = set()
- self.process()
-
- def process(self):
- """Parse the python source file to find the forward functions."""
- with open(self.filename, 'rt', encoding='utf-8') as file:
- content = file.read()
- self.visit(ast.parse(content, self.filename))
-
- def get_current_namespace(self):
- """Get the namespace when visit the AST node"""
- namespace = '.'.join(self.name_stack)
- return namespace
-
- @classmethod
- def get_ast_node_name(cls, node):
- """Get AST node name."""
- if isinstance(node, ast.Attribute):
- return f'{cls.get_ast_node_name(node.value)}.{node.attr}'
-
- if isinstance(node, ast.Name):
- return node.id
-
- return node
-
- def visit_ClassDef(self, node):
- """Callback function when visit AST tree"""
- self.name_stack.append(node.name)
- self.generic_visit(node)
- self.name_stack.pop()
-
- def visit_FunctionDef(self, node):
- """Callback function when visit AST tree"""
- func_name = f'{self.get_current_namespace()}.{node.name}'
- is_in_chain = func_name in self.calls or node.name == 'forward'
- if is_in_chain:
- self.forward_stack.append(func_name)
-
- if node.name == 'forward':
- self.calls.add(func_name)
-
- self.generic_visit(node)
-
- if is_in_chain:
- self.forward_stack.pop()
-
- def visit_Call(self, node):
- """Callback function when visit AST tree"""
- for arg in node.args:
- self.visit(arg)
- for kw in node.keywords:
- self.visit(kw.value)
- func_name = self.get_ast_node_name(node.func)
- if isinstance(node.func, ast.Name):
- if func_name not in ['super', 'str', 'repr']:
- if self.forward_stack:
- self.calls.add(func_name)
- self.visit(node.func)
- else:
- if self.forward_stack:
- if 'self' in func_name:
- self.calls.add(f'{self.get_current_namespace()}.{func_name.split(".")[-1]}')
- else:
- self.calls.add(func_name)
- self.visit(node.func)
|