|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403 |
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless REQUIRED by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """code analysis module"""
- import ast
-
- import pasta
- from pasta.base import scope
-
- from mindinsight.mindconverter.common.exceptions import ScriptNotSupport
-
-
- class APIAnalysisSpec:
- """API analysis specifications"""
-
- import_name_mapping = {'torch': ['mindspore', None],
- 'torch.nn': ['mindspore.nn', 'nn'],
- 'torch.nn.functional': ['mindspore.ops.operations', 'P']}
-
- base_name_mapping = {'Module': 'Cell',
- 'Sequential': 'SequentialCell'
- }
-
- @classmethod
- def get_convertible_external_names(cls):
- """
- Obtain the convertible external names.
-
- The external name is the full dotted name being referenced.
- """
- return cls.import_name_mapping.keys()
-
- @staticmethod
- def get_network_base_class_names():
- """Obtain the base names which network class base from"""
- return ['Module',
- 'Sequential',
- 'ModuleList',
- 'ModuleDict',
- 'ParameterList',
- 'ParameterDict']
-
- @staticmethod
- def check_external_alias_ref(ref_name, external_name):
- """
- Check 'import as' is standard.
-
- Standard references are follow:
- import torch.nn as nn
- import torch.nn.functional as F
-
- Args:
- ref_name (str): The name that refers to the external_name.
- external_name (str): The full dotted name being referenced. For examples:
- 1. 'import torch.nn as nn', torch.nn is external_name, nn is ref_name.
- 2. 'from torch import nn as mm, torch.nn is external_name, mm is ref_name which is not a standard name.
-
- Returns:
- boolean, True if ref_name is standard else False.
- """
- if ref_name != 'nn' and external_name == 'torch.nn':
- is_standard = False
- elif ref_name != 'F' and external_name == 'torch.nn.functional':
- is_standard = False
- else:
- is_standard = True
-
- return is_standard
-
-
- class CodeAnalyzer(ast.NodeVisitor):
- """Code analyzer that analyzes PyTorch python script by AST Visitor.
-
- CodeAnalyzer find the codes that need to be converted to MindSpore,
- and provides the attributes related to the codes.
- """
-
- def __init__(self):
- self._stack = [] # Used to easily access the parent node
- self._external_references = {}
- self._is_standard_external_ref = True
- self._root_scope = None
- # Used to save functions that need to be converted, value type is pasta.base.scope.Scope
- self._network_functions = []
-
- # Used to easily trace the function node
- self._functions_stack = []
-
- # key type is pasta.base.scope.Scope, value type is list
- self._network_classes = {}
-
- @property
- def root_scope(self):
- """The root scope of the python script code."""
- return self._root_scope
-
- @property
- def is_standard_external_ref(self):
- """Obtain whether the result is a standard external reference."""
- return self._is_standard_external_ref
-
- @property
- def external_references(self):
- """Obtain all external references in the analyzed code."""
- return self._external_references
-
- def network_definitions(self):
- """Obtain the network definitions which need to be converted."""
- return {"functions": self._network_functions,
- "cell": self._network_classes}
-
- def process(self, ast_tree):
- """
- Start to analyze the code.
-
- Args:
- ast_tree (AST): The root node of the source code.
- """
- self.__init__()
- self._root_scope = scope.analyze(ast_tree)
- self._pre_process()
- self.visit(ast_tree)
- if not self._network_classes:
- msg = "model definition not be found."
- raise ScriptNotSupport(msg)
-
- @staticmethod
- def _check_external_standard(external_refs):
- """Check whether all external references are standard."""
- is_standard = True
- for external_name, external_ref_info in external_refs.items():
- is_standard = APIAnalysisSpec.check_external_alias_ref(external_name, external_ref_info.name)
- if not is_standard:
- break
- return is_standard
-
- def _is_base_from_cell(self, node):
- """
- Check whether the node bases from cell classes which are defined in APIAnalysisSpec.
-
- Args:
- node (ast.ClassDef): The node which is a class definition.
-
- Returns:
- boolean, True if the check result is Passed else False.
- """
- if self._is_ref_convertible_imports(node):
- whole_name = self._get_whole_name(node)
- if whole_name.split('.')[-1] in APIAnalysisSpec.get_network_base_class_names():
- return True
- return False
-
- def _pre_process(self):
- """Preprocessor checks the code before analyzing."""
- is_torch = False
-
- # check whether the code imports torch.
- for ref_name in self._root_scope.external_references.keys():
- if ref_name.split('.')[0] in APIAnalysisSpec.get_convertible_external_names():
- is_torch = True
- break
- if not is_torch:
- msg = "The source code does not import torch, model definition can not be found."
- raise ScriptNotSupport(msg)
-
- # Find out external reference in the code and save it.
- external_refs = self._analyze_import_references(self._root_scope)
- self._is_standard_external_ref = self._check_external_standard(external_refs)
- self._check_external_standard(external_refs)
- for external_name, external_ref_info in external_refs.items():
- self._external_references.update({
- external_name: {
- 'external_ref_info': external_ref_info,
- 'parent_node': None
- }
- })
-
- @staticmethod
- def _analyze_import_references(root_scope):
- """
- Find out all references from the import statements.
-
- Case1: (from)import alias, node_ref.name_ref.id is node_ref.name_ref.definition.asname.
- Case2: import without alias, node_ref.name_ref.definition.asname is None.
- e.g., import a.b.c, the reference definition id maybe is a, a.b or a.b.c.
- The reference id a.b.c is really wanted.
- """
- external_name_ref = dict()
- all_node_references = []
- for node_references in root_scope.external_references.values():
- all_node_references.extend(node_references)
-
- for node_ref in all_node_references:
- name_ref = node_ref.name_ref
- if not name_ref:
- continue
- definition = name_ref.definition
- if node_ref.name_ref.id in [definition.asname, definition.name]:
- external_name_ref[name_ref.id] = node_ref
-
- return external_name_ref
-
- def visit(self, node):
- """Overridden visit of the base class to maintain stack information to access parent node."""
- self._stack.append(node)
- super(CodeAnalyzer, self).visit(node)
- self._stack.pop()
-
- @staticmethod
- def _get_full_name(node):
- """Get the full name of the node."""
- if not isinstance(node, (ast.Attribute, ast.Name)):
- return None
- return pasta.dump(node)
-
- def _get_whole_name(self, node):
- """
- Get the whole name of the node.
-
- For example, nn.Module is spliced two nodes, nn node and Module node.
- When visit ast nodes,
- Module node is first visited, the full name is the same as the whole name, that is nn.Module.
- And then nn node is visited, the full name is nn, the whole name is nn.Module.
- """
- full_name = self._get_full_name(node)
- if not full_name:
- return None
- whole_name = full_name
- # node is in stack top pos
- if node is self._stack[-1]:
- parent_index = -1
- while isinstance(self._stack[parent_index], ast.Attribute):
- parent_index -= 1
-
- whole_name = self._get_full_name(self._stack[parent_index])
- return whole_name
-
- def _is_ref_convertible_imports(self, node):
- """Check whether the node references convertible imports."""
- check_result = False
- whole_name = self._get_whole_name(node)
- if whole_name:
- module_name = whole_name.split('.')[0]
- for ref_name, ref_info in self._external_references.items():
- external_ref = ref_info['external_ref_info']
- # external reference is convertible module
- if external_ref.name in APIAnalysisSpec.get_convertible_external_names():
- # import from the same external module
- if module_name == ref_name.split('.')[0]:
- check_result = True
- break
-
- return check_result
-
- @staticmethod
- def _get_external_node(external_references, only_convertible=False):
- """Get all external reference nodes."""
- external_nodes = {}
- for ref_name, ref_info in external_references.items():
- is_add = False
- if only_convertible:
- if ref_info['external_ref_info'].name in APIAnalysisSpec.get_convertible_external_names():
- is_add = True
- else:
- is_add = True
- if is_add:
- external_nodes.update({ref_info['external_ref_info'].node: ref_name})
- return external_nodes
-
- def _update_external_ref_parent(self, node):
- """Set external reference parent node info."""
- external_nodes = self._get_external_node(self._external_references, only_convertible=False)
- convertible_external_nodes = self._get_external_node(self._external_references, only_convertible=True)
- for name_node in node.names:
- if name_node in convertible_external_nodes.keys():
- if len(node.names) > 1:
- msg = """\
- Not support multiple imports of torch on one line in your script. line:%s: %s
- """ % (node.lineno, pasta.dump(node))
- raise ScriptNotSupport(msg)
- if name_node in external_nodes.keys():
- ref_name = external_nodes[name_node]
- self._external_references[ref_name]['parent_node'] = node
-
- @staticmethod
- def _get_class_scope(node_scope):
- """Find the class scope of the node_scope."""
- parent_scope = node_scope.parent_scope
- class_scope = None
- while parent_scope:
- if isinstance(parent_scope.node, ast.ClassDef):
- class_scope = parent_scope
- break
- parent_scope = parent_scope.parent_scope
- return class_scope
-
- def _update_convertible_functions(self, node):
- """Update convertible functions."""
- node_scope = self._root_scope.lookup_scope(node)
- class_scope = self._get_class_scope(node_scope)
- if class_scope:
- network_classes = self._network_classes.get(class_scope, [])
- if node_scope not in network_classes:
- network_classes.append(node_scope)
- else:
- if node_scope not in self._network_functions:
- self._network_functions.append(node_scope)
-
- def visit_ClassDef(self, node):
- """Callback function when visit AST tree"""
- if not self._stack[-1] is node:
- return
-
- for base in node.bases:
- if self._is_ref_convertible_imports(base):
- self._network_classes[self._root_scope.lookup_scope(node)] = []
-
- self.generic_visit(node)
-
- def _update_external_when_visit(self, node):
- """Update external reference when visiting import and import from statements."""
- self._update_external_ref_parent(node)
- self.generic_visit(node)
-
- def visit_Import(self, node):
- """Callback function when visit AST tree"""
- self._update_external_when_visit(node)
-
- def visit_ImportFrom(self, node):
- """Callback function when visit AST tree"""
- self._update_external_when_visit(node)
-
- def visit_Call(self, node):
- """Callback function when visit AST tree"""
- if not self._stack[-1] is node:
- return
- is_in_network_function = False
- # If torch call is happened in the function, save the function for network definition.
- if self._functions_stack and self._is_ref_convertible_imports(node.func):
- self._update_convertible_functions(self._functions_stack[-1])
- is_in_network_function = True
- if not is_in_network_function:
- self.generic_visit(node)
-
- def visit_FunctionDef(self, node):
- """Callback function when visit AST tree"""
- if not self._stack[-1] is node:
- return
- if node.name == "forward":
- self._update_convertible_functions(node)
-
- self._functions_stack.append(node)
- self.generic_visit(node)
- self._functions_stack.pop()
-
- def get_name(self, node):
- """
- Get the node name.
-
- Args:
- node (AST): The ast node of the source code.
-
- Returns:
- str, the name of the node
- """
- if isinstance(node, pasta.base.scope.Scope):
- items = [self.get_name(node.node)]
- parent_scope = node.parent_scope
- while parent_scope:
- if not isinstance(parent_scope.node, ast.Module):
- items.append(self.get_name(parent_scope.node))
- parent_scope = parent_scope.parent_scope
- return '.'.join(reversed(items))
- if isinstance(node, (ast.ClassDef, ast.FunctionDef)):
- return node.name
- if isinstance(node, (ast.Name, ast.Attribute)):
- return self._get_full_name(node)
- return str(node)
-
- def lookup_scope(self, node):
- """
- Search the scope of the node.
-
- Args:
- node (AST): The ast node of the source code.
-
- Returns:
- scope, the scope of the node
- """
- if isinstance(node, pasta.base.scope.Scope):
- return node
- return self._root_scope.lookup_scope(node)
|