You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

code_analysis.py 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless REQUIRED by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """code analysis module"""
  16. import ast
  17. import pasta
  18. from pasta.base import scope
  19. from mindinsight.mindconverter.common.exceptions import ScriptNotSupport
  20. class APIAnalysisSpec:
  21. """API analysis specifications"""
  22. import_name_mapping = {'torch': ['mindspore', None],
  23. 'torch.nn': ['mindspore.nn', 'nn'],
  24. 'torch.nn.functional': ['mindspore.ops.operations', 'P']}
  25. base_name_mapping = {'Module': 'Cell',
  26. 'Sequential': 'SequentialCell'
  27. }
  28. @classmethod
  29. def get_convertible_external_names(cls):
  30. """
  31. Obtain the convertible external names.
  32. The external name is the full dotted name being referenced.
  33. """
  34. return cls.import_name_mapping.keys()
  35. @staticmethod
  36. def get_network_base_class_names():
  37. """Obtain the base names which network class base from"""
  38. return ['Module',
  39. 'Sequential',
  40. 'ModuleList',
  41. 'ModuleDict',
  42. 'ParameterList',
  43. 'ParameterDict']
  44. @staticmethod
  45. def check_external_alias_ref(ref_name, external_name):
  46. """
  47. Check 'import as' is standard.
  48. Standard references are follow:
  49. import torch.nn as nn
  50. import torch.nn.functional as F
  51. Args:
  52. ref_name (str): The name that refers to the external_name.
  53. external_name (str): The full dotted name being referenced. For examples:
  54. 1. 'import torch.nn as nn', torch.nn is external_name, nn is ref_name.
  55. 2. 'from torch import nn as mm, torch.nn is external_name, mm is ref_name which is not a standard name.
  56. Returns:
  57. boolean, True if ref_name is standard else False.
  58. """
  59. if ref_name != 'nn' and external_name == 'torch.nn':
  60. is_standard = False
  61. elif ref_name != 'F' and external_name == 'torch.nn.functional':
  62. is_standard = False
  63. else:
  64. is_standard = True
  65. return is_standard
  66. class CodeAnalyzer(ast.NodeVisitor):
  67. """Code analyzer that analyzes PyTorch python script by AST Visitor.
  68. CodeAnalyzer find the codes that need to be converted to MindSpore,
  69. and provides the attributes related to the codes.
  70. """
  71. def __init__(self):
  72. self._stack = [] # Used to easily access the parent node
  73. self._external_references = {}
  74. self._is_standard_external_ref = True
  75. self._root_scope = None
  76. # Used to save functions that need to be converted, value type is pasta.base.scope.Scope
  77. self._network_functions = []
  78. # Used to easily trace the function node
  79. self._functions_stack = []
  80. # key type is pasta.base.scope.Scope, value type is list
  81. self._network_classes = {}
  82. @property
  83. def root_scope(self):
  84. """The root scope of the python script code."""
  85. return self._root_scope
  86. @property
  87. def is_standard_external_ref(self):
  88. """Obtain whether the result is a standard external reference."""
  89. return self._is_standard_external_ref
  90. @property
  91. def external_references(self):
  92. """Obtain all external references in the analyzed code."""
  93. return self._external_references
  94. def network_definitions(self):
  95. """Obtain the network definitions which need to be converted."""
  96. return {"functions": self._network_functions,
  97. "cell": self._network_classes}
  98. def process(self, ast_tree):
  99. """
  100. Start to analyze the code.
  101. Args:
  102. ast_tree (AST): The root node of the source code.
  103. """
  104. self.__init__()
  105. self._root_scope = scope.analyze(ast_tree)
  106. self._pre_process()
  107. self.visit(ast_tree)
  108. if not self._network_classes:
  109. msg = "model definition not be found."
  110. raise ScriptNotSupport(msg)
  111. @staticmethod
  112. def _check_external_standard(external_refs):
  113. """Check whether all external references are standard."""
  114. is_standard = True
  115. for external_name, external_ref_info in external_refs.items():
  116. is_standard = APIAnalysisSpec.check_external_alias_ref(external_name, external_ref_info.name)
  117. if not is_standard:
  118. break
  119. return is_standard
  120. def _is_base_from_cell(self, node):
  121. """
  122. Check whether the node bases from cell classes which are defined in APIAnalysisSpec.
  123. Args:
  124. node (ast.ClassDef): The node which is a class definition.
  125. Returns:
  126. boolean, True if the check result is Passed else False.
  127. """
  128. if self._is_ref_convertible_imports(node):
  129. whole_name = self._get_whole_name(node)
  130. if whole_name.split('.')[-1] in APIAnalysisSpec.get_network_base_class_names():
  131. return True
  132. return False
  133. def _pre_process(self):
  134. """Preprocessor checks the code before analyzing."""
  135. is_torch = False
  136. # check whether the code imports torch.
  137. for ref_name in self._root_scope.external_references.keys():
  138. if ref_name.split('.')[0] in APIAnalysisSpec.get_convertible_external_names():
  139. is_torch = True
  140. break
  141. if not is_torch:
  142. msg = "The source code does not import torch, model definition can not be found."
  143. raise ScriptNotSupport(msg)
  144. # Find out external reference in the code and save it.
  145. external_refs = self._analyze_import_references(self._root_scope)
  146. self._is_standard_external_ref = self._check_external_standard(external_refs)
  147. self._check_external_standard(external_refs)
  148. for external_name, external_ref_info in external_refs.items():
  149. self._external_references.update({
  150. external_name: {
  151. 'external_ref_info': external_ref_info,
  152. 'parent_node': None
  153. }
  154. })
  155. @staticmethod
  156. def _analyze_import_references(root_scope):
  157. """
  158. Find out all references from the import statements.
  159. Case1: (from)import alias, node_ref.name_ref.id is node_ref.name_ref.definition.asname.
  160. Case2: import without alias, node_ref.name_ref.definition.asname is None.
  161. e.g., import a.b.c, the reference definition id maybe is a, a.b or a.b.c.
  162. The reference id a.b.c is really wanted.
  163. """
  164. external_name_ref = dict()
  165. all_node_references = []
  166. for node_references in root_scope.external_references.values():
  167. all_node_references.extend(node_references)
  168. for node_ref in all_node_references:
  169. name_ref = node_ref.name_ref
  170. if not name_ref:
  171. continue
  172. definition = name_ref.definition
  173. if node_ref.name_ref.id in [definition.asname, definition.name]:
  174. external_name_ref[name_ref.id] = node_ref
  175. return external_name_ref
  176. def visit(self, node):
  177. """Overridden visit of the base class to maintain stack information to access parent node."""
  178. self._stack.append(node)
  179. super(CodeAnalyzer, self).visit(node)
  180. self._stack.pop()
  181. @staticmethod
  182. def _get_full_name(node):
  183. """Get the full name of the node."""
  184. if not isinstance(node, (ast.Attribute, ast.Name)):
  185. return None
  186. return pasta.dump(node)
  187. def _get_whole_name(self, node):
  188. """
  189. Get the whole name of the node.
  190. For example, nn.Module is spliced two nodes, nn node and Module node.
  191. When visit ast nodes,
  192. Module node is first visited, the full name is the same as the whole name, that is nn.Module.
  193. And then nn node is visited, the full name is nn, the whole name is nn.Module.
  194. """
  195. full_name = self._get_full_name(node)
  196. if not full_name:
  197. return None
  198. whole_name = full_name
  199. # node is in stack top pos
  200. if node is self._stack[-1]:
  201. parent_index = -1
  202. while isinstance(self._stack[parent_index], ast.Attribute):
  203. parent_index -= 1
  204. whole_name = self._get_full_name(self._stack[parent_index])
  205. return whole_name
  206. def _is_ref_convertible_imports(self, node):
  207. """Check whether the node references convertible imports."""
  208. check_result = False
  209. whole_name = self._get_whole_name(node)
  210. if whole_name:
  211. module_name = whole_name.split('.')[0]
  212. for ref_name, ref_info in self._external_references.items():
  213. external_ref = ref_info['external_ref_info']
  214. # external reference is convertible module
  215. if external_ref.name in APIAnalysisSpec.get_convertible_external_names():
  216. # import from the same external module
  217. if module_name == ref_name.split('.')[0]:
  218. check_result = True
  219. break
  220. return check_result
  221. @staticmethod
  222. def _get_external_node(external_references, only_convertible=False):
  223. """Get all external reference nodes."""
  224. external_nodes = {}
  225. for ref_name, ref_info in external_references.items():
  226. is_add = False
  227. if only_convertible:
  228. if ref_info['external_ref_info'].name in APIAnalysisSpec.get_convertible_external_names():
  229. is_add = True
  230. else:
  231. is_add = True
  232. if is_add:
  233. external_nodes.update({ref_info['external_ref_info'].node: ref_name})
  234. return external_nodes
  235. def _update_external_ref_parent(self, node):
  236. """Set external reference parent node info."""
  237. external_nodes = self._get_external_node(self._external_references, only_convertible=False)
  238. convertible_external_nodes = self._get_external_node(self._external_references, only_convertible=True)
  239. for name_node in node.names:
  240. if name_node in convertible_external_nodes.keys():
  241. if len(node.names) > 1:
  242. msg = """\
  243. Not support multiple imports of torch on one line in your script. line:%s: %s
  244. """ % (node.lineno, pasta.dump(node))
  245. raise ScriptNotSupport(msg)
  246. if name_node in external_nodes.keys():
  247. ref_name = external_nodes[name_node]
  248. self._external_references[ref_name]['parent_node'] = node
  249. @staticmethod
  250. def _get_class_scope(node_scope):
  251. """Find the class scope of the node_scope."""
  252. parent_scope = node_scope.parent_scope
  253. class_scope = None
  254. while parent_scope:
  255. if isinstance(parent_scope.node, ast.ClassDef):
  256. class_scope = parent_scope
  257. break
  258. parent_scope = parent_scope.parent_scope
  259. return class_scope
  260. def _update_convertible_functions(self, node):
  261. """Update convertible functions."""
  262. node_scope = self._root_scope.lookup_scope(node)
  263. class_scope = self._get_class_scope(node_scope)
  264. if class_scope:
  265. network_classes = self._network_classes.get(class_scope, [])
  266. if node_scope not in network_classes:
  267. network_classes.append(node_scope)
  268. else:
  269. if node_scope not in self._network_functions:
  270. self._network_functions.append(node_scope)
  271. def visit_ClassDef(self, node):
  272. """Callback function when visit AST tree"""
  273. if not self._stack[-1] is node:
  274. return
  275. for base in node.bases:
  276. if self._is_ref_convertible_imports(base):
  277. self._network_classes[self._root_scope.lookup_scope(node)] = []
  278. self.generic_visit(node)
  279. def _update_external_when_visit(self, node):
  280. """Update external reference when visiting import and import from statements."""
  281. self._update_external_ref_parent(node)
  282. self.generic_visit(node)
  283. def visit_Import(self, node):
  284. """Callback function when visit AST tree"""
  285. self._update_external_when_visit(node)
  286. def visit_ImportFrom(self, node):
  287. """Callback function when visit AST tree"""
  288. self._update_external_when_visit(node)
  289. def visit_Call(self, node):
  290. """Callback function when visit AST tree"""
  291. if not self._stack[-1] is node:
  292. return
  293. is_in_network_function = False
  294. # If torch call is happened in the function, save the function for network definition.
  295. if self._functions_stack and self._is_ref_convertible_imports(node.func):
  296. self._update_convertible_functions(self._functions_stack[-1])
  297. is_in_network_function = True
  298. if not is_in_network_function:
  299. self.generic_visit(node)
  300. def visit_FunctionDef(self, node):
  301. """Callback function when visit AST tree"""
  302. if not self._stack[-1] is node:
  303. return
  304. if node.name == "forward":
  305. self._update_convertible_functions(node)
  306. self._functions_stack.append(node)
  307. self.generic_visit(node)
  308. self._functions_stack.pop()
  309. def get_name(self, node):
  310. """
  311. Get the node name.
  312. Args:
  313. node (AST): The ast node of the source code.
  314. Returns:
  315. str, the name of the node
  316. """
  317. if isinstance(node, pasta.base.scope.Scope):
  318. items = [self.get_name(node.node)]
  319. parent_scope = node.parent_scope
  320. while parent_scope:
  321. if not isinstance(parent_scope.node, ast.Module):
  322. items.append(self.get_name(parent_scope.node))
  323. parent_scope = parent_scope.parent_scope
  324. return '.'.join(reversed(items))
  325. if isinstance(node, (ast.ClassDef, ast.FunctionDef)):
  326. return node.name
  327. if isinstance(node, (ast.Name, ast.Attribute)):
  328. return self._get_full_name(node)
  329. return str(node)
  330. def lookup_scope(self, node):
  331. """
  332. Search the scope of the node.
  333. Args:
  334. node (AST): The ast node of the source code.
  335. Returns:
  336. scope, the scope of the node
  337. """
  338. if isinstance(node, pasta.base.scope.Scope):
  339. return node
  340. return self._root_scope.lookup_scope(node)