You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

ast_edits.py 38 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Convert for Python scripts according API mapping information."""
  16. import ast
  17. import logging
  18. import re
  19. from enum import Enum
  20. import pasta
  21. from pasta.base import formatting as fmt
  22. from mindinsight.mindconverter.code_analysis import CodeAnalyzer
  23. from mindinsight.mindconverter.code_analysis import APIAnalysisSpec
  24. from mindinsight.mindconverter.config import ALL_MAPPING, F_LIST
  25. from mindinsight.mindconverter.config import NN_LIST
  26. from mindinsight.mindconverter.config import ALL_TORCH_APIS
  27. from mindinsight.mindconverter.config import ALL_2P_LIST
  28. from mindinsight.mindconverter.config import TENSOR_DOT_LIST
  29. from mindinsight.mindconverter.config import get_prompt_info
  30. from mindinsight.mindconverter.common.log import logger
  31. from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport
  32. from mindinsight.mindconverter.forward_call import ForwardCall
# Message templates for the conversion report. The '%s' slots are filled by the
# callers in this module with the original code, the converted code and, where
# present, an extra hint.
LOG_FMT_INSERT = "[Insert] '%s' is inserted to the converted file."
LOG_FMT_CONVERT = "[Convert] '%s' is converted to '%s'."
LOG_FMT_CONVERT_WITH_TIPS = "[Convert] '%s' is converted to '%s'. %s"
LOG_FMT_NOT_CONVERT = "[UnConvert] '%s' didn't convert. %s"
LOG_FMT_PROMPT_INFO = "[INFO] %s"
# Hint appended to "not converted" messages for torch imports without a mapping.
LOG_SUGGESTION_MANUAL_CONVERT = "Please manual convert the code, along with the code associated with it."
class ApiMatchingEnum(Enum):
    """Result of matching a call name against the known API names."""
    # Default result: the call name was not recognised as an API.
    NOT_API = 'not an api name'
    # The API name was inferred from a method call on a (possible) tensor object.
    API_INFER = 'infer api name to map'
    # The call name starts with one of the standard prefixes (nn, F, torch).
    API_STANDARD = 'api name in the correct format'
    # The name is present in the list of torch APIs.
    API_FOUND = 'found an api name in api list'
    # The name is present in the API list and is eligible for mapping.
    API_MATCHED = 'api is matched to map'
  46. class _ConvertReport:
  47. """Report log of converting source code."""
  48. def __init__(self, is_stub=False):
  49. self._is_stub = is_stub
  50. self._max_line = 0
  51. self._log_head = []
  52. self._log_body = [] # report log, type is (severity, line, col, msg)
  53. def _add_log(self, severity, line, col, msg):
  54. """Add log."""
  55. if self._is_stub:
  56. return
  57. if line is None and col is None:
  58. self._log_head.append(msg)
  59. return
  60. if isinstance(line, int) and isinstance(col, int):
  61. self._log_body.append((severity, line, col, msg))
  62. if self._max_line < line:
  63. self._max_line = line
  64. else:
  65. raise TypeError('The parameter type is incorrect.')
  66. def info(self, line, col, msg):
  67. """Interface to add infer log"""
  68. self._add_log(logging.INFO, line, col, msg)
  69. def warning(self, line, col, msg):
  70. """Interface to add warning log"""
  71. self._add_log(logging.WARNING, line, col, msg)
  72. def header_msg(self, msg):
  73. """Interface to add header message log"""
  74. self._add_log(logging.INFO, None, None, msg)
  75. def get_logs(self):
  76. """Get convert logs"""
  77. logs = []
  78. logs.extend(self._log_head)
  79. # sort rule: line * self._max_line + col
  80. self._log_body.sort(key=lambda log: log[1] * self._max_line + log[2])
  81. for log_info in self._log_body:
  82. log_info = "line %d:%d: %s" % (log_info[1], log_info[2], log_info[3])
  83. if logs:
  84. # Deduplication for logs
  85. if logs[-1] != log_info:
  86. logs.append(log_info)
  87. else:
  88. logs.append(log_info)
  89. return logs
  90. class _LineColEditVisitor(ast.NodeVisitor):
  91. """
  92. Update line number and col offset of ast node.
  93. Use the line and column number of the original code to update
  94. the line and column number of the new code replaced with the original code.
  95. """
  96. class _NodeInfo:
  97. """NodeInfo class definition."""
  98. def __init__(self, node):
  99. self.node = node
  100. self.call_list = [] # Used to save all ast.Call node in self._node
  101. def __init__(self):
  102. self._dst_node_info = None
  103. self._src_node_info = None
  104. self._visiting = self._src_node_info # Used to point to the visiting node
  105. def update(self, replace_with_node, src_node):
  106. """Update the line and column number of the new code replaced with the original code."""
  107. replace_with_node.lineno = src_node.lineno
  108. replace_with_node.col_offset = src_node.col_offset
  109. self._dst_node_info = self._NodeInfo(replace_with_node)
  110. self._src_node_info = self._NodeInfo(src_node)
  111. self._visiting = self._src_node_info
  112. self.visit(self._visiting.node)
  113. self._visiting = self._dst_node_info
  114. self.visit(self._visiting.node)
  115. self._update_line_col()
  116. def visit_Call(self, node):
  117. """Callback function when visit AST tree"""
  118. self._visiting.call_list.append(node)
  119. self.generic_visit(node)
  120. def _update_line_col(self):
  121. """Update the line and column number information for all ast.Call node."""
  122. dst_call_list = list(self._dst_node_info.call_list)
  123. src_call_list = list(self._src_node_info.call_list)
  124. len_diff = len(dst_call_list) - len(src_call_list)
  125. # After MindSpore api replaces Torch api, more calls are generated.
  126. # For example, out.view() is replaced with P.Reshape()(out).
  127. # out.view() has only one call, but P.Reshape()(out) has two calls.
  128. # To match the replaced calls, the calls of out.view is padded to the same quantity.
  129. if len_diff > 0:
  130. src_call_list = [src_call_list[0]] * len_diff + src_call_list
  131. for dst_call, src_call in zip(dst_call_list, src_call_list):
  132. dst_call.lineno = src_call.lineno
  133. dst_call.col_offset = src_call.col_offset
  134. if not dst_call.args:
  135. continue
  136. # When out.size().view(1, ...) transforms to P.Reshape()(out.size(), 1, ...),
  137. # in this case, the column of parameter out.size() will be bigger than the following parameters.
  138. # To ensure the sequence of parameters, adjust the column of the second parameter.
  139. args = []
  140. for arg in dst_call.args:
  141. if self._check_arg2update(arg):
  142. args.append(arg)
  143. for arg in args:
  144. # line number starts from 1, column number starts from 0.
  145. arg.lineno += dst_call.lineno - 1
  146. arg.col_offset += dst_call.col_offset
  147. @staticmethod
  148. def _check_arg2update(arg):
  149. # When the arg is a function call, its col_offset is handled separately.
  150. if not isinstance(arg, ast.Call):
  151. return True
  152. return False
  153. class AstEditVisitor(ast.NodeVisitor):
  154. """AST Visitor that process function calls.
  155. Converts function calls from torch api to MindSpore api using api mapping information.
  156. """
    def __init__(self):
        """Create the visitor with empty per-run state."""
        # Conversion report filled while the tree is edited.
        self._process_log = _ConvertReport()
        # Root node and code analyzer; both are assigned in process().
        self._tree = None
        self._code_analyzer = None
        self._stack = []  # Used to easily access the parent node
        # Assigned from ForwardCall(...).calls in process().
        self._forward_list = {}
        self._is_forward_function = False  # Used to allow access the visiting function forward attribute
        self._new_call_nodes = []  # Used to save new ast.call nodes
  165. def process(self, ast_tree):
  166. """
  167. Convert source code to MindSpore code.
  168. Args:
  169. ast_tree (AST): The root node of the source code.
  170. """
  171. self.__init__()
  172. self._tree = ast_tree
  173. self._code_analyzer = CodeAnalyzer()
  174. self._code_analyzer.process(self._tree)
  175. self._forward_list = ForwardCall(self._tree).calls
  176. # replace python function under nn.Module
  177. self._convert_api()
  178. # replace external reference statements
  179. self._convert_external_reference()
  180. def get_logs(self):
  181. """Get conversion report."""
  182. return self._process_log.get_logs()
  183. def _convert_cell(self, cell_scope):
  184. """
  185. Convert a PyTorch Module class into MindSpore Cell class.
  186. Args:
  187. cell_scope (pasta.base.Scope): The network class definition node inherits from torch.nn.Module.
  188. """
  189. cell_ast_node = cell_scope.node
  190. line_no = cell_ast_node.lineno
  191. logger.info("Line %3d: start converting nn.Module %s", line_no, self._code_analyzer.get_name(cell_ast_node))
  192. class_elements = self._code_analyzer.network_definitions()['cell']
  193. # step1. update function definition
  194. for func_scope in class_elements.get(cell_scope, []):
  195. self._update_function_def(func_scope)
  196. # step2. update base name of class
  197. self._update_base_name(cell_scope)
  198. def _update_base_name(self, class_def_scope):
  199. """
  200. Update base name of class.
  201. Args:
  202. class_def_scope (ast.ClassDef): Class definition node.
  203. """
  204. base_name_mapping = APIAnalysisSpec.base_name_mapping
  205. class_def_node = class_def_scope.node
  206. base_class_nodes = class_def_scope.node.bases
  207. # update base class name
  208. for base_class_node in base_class_nodes:
  209. base_name = base_class_node.attr
  210. if base_name in APIAnalysisSpec.get_network_base_class_names():
  211. old_code = pasta.dump(base_class_node)
  212. if base_name in base_name_mapping:
  213. new_code = 'nn.' + base_name_mapping[base_class_node.attr]
  214. new_node = pasta.parse(new_code)
  215. pasta.ast_utils.replace_child(class_def_node, base_class_node, new_node)
  216. self._process_log.info(base_class_node.lineno, base_class_node.col_offset, LOG_FMT_CONVERT %
  217. (old_code, new_code))
  218. else:
  219. self._process_log.info(base_class_node.lineno, base_class_node.col_offset, LOG_FMT_NOT_CONVERT %
  220. (old_code, ''))
  221. @staticmethod
  222. def _modify_function_name(func_def_node, new_func_name):
  223. """Modify function name"""
  224. if not isinstance(func_def_node, ast.FunctionDef):
  225. raise NodeTypeNotSupport('It is not ast.FunctionDef node type.')
  226. old_func_name = func_def_node.name
  227. func_def_node.name = new_func_name
  228. # Modify formatting information stored by pasta
  229. old_function_def = fmt.get(func_def_node, 'function_def')
  230. if old_function_def:
  231. new_function_def = old_function_def.replace(old_func_name, new_func_name)
  232. fmt.set(func_def_node, 'function_def', new_function_def)
  233. fmt.set(func_def_node, 'name__src', new_func_name)
  234. def _update_function_def(self, func_scope):
  235. """
  236. Convert a PyTorch function into MindSpore function.
  237. Args:
  238. func_scope (pasta.base.scope.Scope): The node scope of function definition.
  239. """
  240. is_forward = self._judge_forward(func_scope)
  241. # step1. convert the content of the function.
  242. self._convert_function(func_scope, is_forward)
  243. # step2. replace function name if name is forward
  244. func_ast_node = func_scope.node
  245. old_func_name = 'forward'
  246. new_func_name = 'construct'
  247. if func_ast_node.name == old_func_name:
  248. self._modify_function_name(func_ast_node, new_func_name)
  249. real_line_number = self._get_real_line_number(func_ast_node)
  250. self._process_log.info(real_line_number, func_ast_node.col_offset,
  251. LOG_FMT_CONVERT % (old_func_name, new_func_name))
  252. def _convert_api(self):
  253. """Convert PyTorch api call to MindSpore api call in a function."""
  254. tasks = []
  255. found_func_nodes = []
  256. convert_elements = self._code_analyzer.network_definitions()
  257. for func_node_scope in convert_elements.get("functions", []):
  258. found_func_nodes.append(func_node_scope.node)
  259. is_forward = self._judge_forward(func_node_scope)
  260. tasks.append((self._convert_function, (func_node_scope, is_forward)))
  261. for class_scope, func_scopes in convert_elements.get("cell", []).items():
  262. for func_node_scope in func_scopes:
  263. found_func_nodes.append(func_node_scope.node)
  264. tasks.append((self._convert_cell, (class_scope,)))
  265. # Some functions in the forward call chain are not found by self._code_analyzer.
  266. for func_node in self._forward_list.values():
  267. is_forward = True
  268. if func_node and func_node not in found_func_nodes:
  269. func_node_scope = self._code_analyzer.lookup_scope(func_node)
  270. tasks.append((self._convert_function, (func_node_scope, is_forward)))
  271. for convert_fun, args in tasks:
  272. convert_fun(*args)
  273. @staticmethod
  274. def _dump_without_prefix(node):
  275. """Get the python source for an AST."""
  276. pos = 0
  277. source_prefix = pasta.base.formatting.get(node, 'prefix')
  278. if source_prefix:
  279. pos = len(source_prefix)
  280. source_code = pasta.dump(node)
  281. return source_code[pos:]
  282. @staticmethod
  283. def _get_real_line_number(node):
  284. """Get the real line number of the node."""
  285. try:
  286. line_number = node.lineno + len(node.decorator_list)
  287. except AttributeError:
  288. line_number = node.lineno
  289. return line_number
    def _replace_external_reference(self):
        """
        Replace external reference statements.

        Goes through the external references recorded by the code analyzer.
        An import with a mapped replacement is swapped for a new import node;
        a non-convertible import whose name starts with 'torch.' is reported
        as not converted.

        Returns:
            dict, key is external name, value is the new replaced node.
        """
        all_name_mappings = APIAnalysisSpec.import_name_mapping
        names_replaced_with = dict()
        for ref_info in self._code_analyzer.external_references.values():
            external_ref_info = ref_info['external_ref_info']
            import_node = ref_info['parent_node']
            # Nothing to rewrite when no import statement node is recorded.
            if import_node is None:
                continue
            # Original import statement text, used in the report messages.
            code = self._dump_without_prefix(import_node)
            import_parent_node = self._code_analyzer.root_scope.parent(import_node)
            # replace import with new name
            if external_ref_info.name in APIAnalysisSpec.get_convertible_external_names():
                external_ref_info = ref_info['external_ref_info']
                if external_ref_info.name in all_name_mappings.keys():
                    # replace_info[0] is the name to import, replace_info[1] its alias.
                    replace_info = all_name_mappings[external_ref_info.name]
                    new_node = self._make_import(name_to_import=replace_info[0], as_name=replace_info[1])
                    new_code = pasta.dump(new_node)
                    pasta.ast_utils.replace_child(import_parent_node, import_node, new_node)
                    names_replaced_with.update({external_ref_info.name: new_node})
                    self._process_log.info(import_node.lineno, import_node.col_offset, LOG_FMT_CONVERT %
                                           (code.strip(), new_code.strip()))
            elif external_ref_info.name.startswith('torch.'):
                # A torch import without a mapping: ask the user to convert it manually.
                self._process_log.warning(import_node.lineno, import_node.col_offset, LOG_FMT_NOT_CONVERT %
                                          (code.strip(), LOG_SUGGESTION_MANUAL_CONVERT))
            else:
                # Any other import is left as it is.
                pass
        return names_replaced_with
    def _convert_external_reference(self):
        """Convert import statements.

        Existing imports are replaced first. Every mapped import that did not
        replace an existing statement is then inserted into the module body,
        and each insertion is reported as a header message.

        NOTE(review): pending imports are stored in a dict keyed by insert
        position, so two pending imports that end up with the same position
        would overwrite each other - confirm this cannot happen with the
        configured mapping order.
        """
        all_name_mappings = APIAnalysisSpec.import_name_mapping
        # Step1. Replace external reference first.
        names_replaced_with = self._replace_external_reference()
        new_import_node = dict()
        insert_pos = 0
        # Step2. Find out remaining mapping name which not found in script.
        for src_name, new_import_name in all_name_mappings.items():
            if src_name not in names_replaced_with:
                new_node = self._make_import(name_to_import=new_import_name[0], as_name=new_import_name[1])
                new_import_node.update({insert_pos: new_node})
                insert_pos += 1
            else:
                try:
                    # insert pos after the last one, if last one name is replaced.
                    replaced_with_node = names_replaced_with[src_name]
                    insert_pos = self._tree.body.index(replaced_with_node) + 1
                except ValueError:
                    # The replaced node is not a direct child of the module body;
                    # keep the current insert position.
                    pass
        # Step3. Insert import reference in order.
        insert_cnt = 0
        for insert_pos, new_node in new_import_node.items():
            # Insert the node into the module; insert_cnt shifts the position by
            # the number of nodes already inserted.
            self._tree.body.insert(insert_pos + insert_cnt, new_node)
            new_code = self._dump_without_prefix(new_node)
            self._process_log.header_msg(LOG_FMT_INSERT % new_code.strip())
            insert_cnt += 1
  350. @staticmethod
  351. def _make_import(name_to_import, as_name=None):
  352. """
  353. Create an import to the ast tree.
  354. Args:
  355. name_to_import: (string) The absolute name to import.
  356. as_name: (string) The alias for the import ("import name_to_import as asname")
  357. Returns:
  358. ast.Import, a new ast.Import node.
  359. """
  360. new_alias = ast.alias(name=name_to_import, asname=as_name)
  361. import_node = ast.Import(names=[new_alias])
  362. return import_node
  363. def _convert_function(self, func_scope, is_forward):
  364. """
  365. Convert a PyTorch function into MindSpore function.
  366. Args:
  367. func_scope (pasta.base.scope.Scope): The node scope of function definition.
  368. is_forward (boolean): If the function is defined in forward function in nn.Module in torch.
  369. """
  370. func_ast_node = func_scope.node
  371. line_no = func_ast_node.lineno
  372. logger.info("Line %3d: start converting function %s()", line_no, func_ast_node.name)
  373. parent = func_scope.parent_scope.node
  374. self._stack.clear()
  375. self._new_call_nodes.clear()
  376. if parent:
  377. self._stack.append(parent)
  378. self._is_forward_function = is_forward
  379. self.visit(func_scope.node)
  380. def _judge_forward(self, func_scope):
  381. """
  382. Check if function is a forward function.
  383. Args:
  384. func_scope (pasta.base.scope.Scope): The node scope of function definition.
  385. Returns:
  386. boolean, True or False
  387. """
  388. is_forward = func_scope.node in self._forward_list.values()
  389. if is_forward:
  390. logger.debug("%s is a forward function", self._code_analyzer.get_name(func_scope))
  391. return is_forward
  392. # Overridden to maintain stack information to access parent node
  393. def visit(self, node):
  394. """Visit a ast tree."""
  395. self._stack.append(node)
  396. super(AstEditVisitor, self).visit(node)
  397. self._stack.pop()
  398. def _mapping_standard_api_name(self, api_name):
  399. """Get mapping from external reference name to standard external reference name"""
  400. standard_name = api_name
  401. if not self._code_analyzer.is_standard_external_ref:
  402. # key is real ref name, value is standard ref name.
  403. mapping_names = self._mapping_standard_external_ref()
  404. api_name_parts = api_name.split('.')
  405. api_name_parts[0] = mapping_names.get(api_name_parts[0], api_name_parts[0])
  406. standard_name = '.'.join(api_name_parts)
  407. return standard_name
    def _infer_api_name(self, call_func_node, check_context=True):
        """Infer the call name.

        Examples:
            1. nn.Sequential inferred to nn.Sequential
            2. mmm.size inferred to .size if import torch.nn as nn
            3. mmm.size inferred to mmm.size if import torch.nn as mmm

        Args:
            call_func_node (AST): The func attribute of ast.Call.
            check_context (boolean): If True, the leading name is first mapped to
                its standard reference name. Default is True.

        Returns:
            str or None, the inferred api name; None when nothing could be inferred.
            ApiMatchingEnum, the match result.
        """
        match_case = ApiMatchingEnum.NOT_API
        api_name = None
        call_name = pasta.dump(call_func_node)
        is_include_sub_call = self._is_include_sub_call(call_func_node)
        if is_include_sub_call:
            # x.y().z splits to ['x.y()', 'z']
            name_attributes = call_name.rsplit('.', 1)
        else:
            # x.y.z splits to ['x', 'y', 'z']
            name_attributes = call_name.split('.')
        # rewritten external module name
        # e.g., mm.ReLU will be written to nn.ReLU if 'import torch.nn as mm' in script.
        if check_context and not self._code_analyzer.is_standard_external_ref:
            standard_name = self._mapping_standard_api_name(name_attributes[0])
        else:
            standard_name = name_attributes[0]
        if standard_name in ["nn", "F", "torch"]:
            # The call already starts with a standard prefix: keep the dumped name.
            match_case = ApiMatchingEnum.API_STANDARD
            api_name = call_name
        else:
            # only infer function for tensor object.
            # e.g., api_call_name is out.view, .view is an api name for out which is maybe a tensor object.
            # e.g., 'xxxx'.size can be not inferred to .size, because string is not a tensor object.
            if self._check_tensor_object(call_func_node):
                api_name = '.' + name_attributes[-1]
                match_case = ApiMatchingEnum.API_INFER
        return api_name, match_case
  442. def _check_tensor_object(self, node):
  443. """Check whether the reference object of the node is a tensor object."""
  444. if not isinstance(node, (ast.Attribute, ast.Name)):
  445. return False
  446. name_attributes = self._dump_without_prefix(node).split('.')
  447. node_ref_name = name_attributes[0]
  448. if re.search(r'\W', node_ref_name) or len(name_attributes) == 1:
  449. return False
  450. func_name = '.' + name_attributes[-1]
  451. if func_name not in TENSOR_DOT_LIST:
  452. return False
  453. extracted_api = []
  454. for api in name_attributes[1:len(name_attributes) - 1]:
  455. if "(" or ")" in api:
  456. start = api.find("(")
  457. start = start if start != -1 else len(api)
  458. end = api.find(")")
  459. end = end if end != -1 else len(api)
  460. if start < end:
  461. api = f"{api[:start]}{api[end + 1:]}"
  462. extracted_api.append(api)
  463. is_tensor_object = True
  464. if self._code_analyzer:
  465. # Check whether the object is external reference.
  466. real_ref = None
  467. for ref_name in self._code_analyzer.external_references:
  468. if node_ref_name == ref_name:
  469. real_ref = self._code_analyzer.external_references[ref_name]["external_ref_info"]
  470. break
  471. if real_ref and f"{real_ref.name}.{'.'.join(extracted_api)}" not in F_LIST:
  472. is_tensor_object = False
  473. return is_tensor_object
  474. @staticmethod
  475. def _is_include_sub_call(call_func_node):
  476. """"Inspect a sub call in call expression.
  477. Examples:
  478. 1. nn.functional.relu() return False
  479. 2. nn.functional.relu(out).size() return True. nn.functional.relu(out) is sub call.
  480. 3. nn.functional.relu(out=out.size()).size() return False. out.size() is not sub call of argument.
  481. """
  482. is_include_call = False
  483. try:
  484. sub_node = call_func_node
  485. while sub_node and not isinstance(sub_node, ast.Call):
  486. sub_node = sub_node.value
  487. if isinstance(sub_node, ast.Call):
  488. is_include_call = True
  489. except AttributeError:
  490. is_include_call = False
  491. return is_include_call
    def match_api(self, call_func_node, is_forward, check_context=True):
        """
        Check api name to convert, check api name ok with a is_forward condition.

        Calls whose dumped name starts with 'self.' are never treated as an API.

        Args:
            call_func_node (ast.Attribute): The call.func node.
            is_forward (bool): whether api belong to forward.
            check_context (boolean): If True, the code context will be checked. Default is True.

        Returns:
            str, the standard api name used to match.
            ApiMappingEnum, the match result.
        """
        match_case = ApiMatchingEnum.NOT_API
        api_call_name = pasta.dump(call_func_node)
        if api_call_name.startswith('self.'):
            return api_call_name, match_case
        api_name, match_case = self._infer_api_name(call_func_node, check_context)
        api_call_name = pasta.dump(call_func_node)
        # An inferred name that differs from the dumped call name is handled as
        # a method call on a tensor object (e.g. out.view -> .view).
        is_tensor_obj_call = False
        if api_name != api_call_name:
            is_tensor_obj_call = True
        standard_api_call_name = api_name
        # rewritten external module name
        # e.g., mm.ReLU will be written to nn.ReLU if 'import torch.nn as mm' in script.
        if not is_tensor_obj_call:
            standard_api_call_name = self._get_api_whole_name(call_func_node, check_context)
        if standard_api_call_name in ALL_TORCH_APIS:
            match_case = ApiMatchingEnum.API_FOUND
            # Outside forward the name must be in NN_LIST, inside forward in ALL_2P_LIST.
            if (not is_forward and standard_api_call_name in NN_LIST) or \
                    (is_forward and standard_api_call_name in ALL_2P_LIST):
                match_case = ApiMatchingEnum.API_MATCHED
        else:
            # Names starting with 'torch.nn.init' are matched even though they
            # are not in ALL_TORCH_APIS.
            if standard_api_call_name and standard_api_call_name.startswith('torch.nn.init'):
                match_case = ApiMatchingEnum.API_MATCHED
        return standard_api_call_name, match_case
    @staticmethod
    def _get_call_parameters_str(call_node):
        """Get parameters string for a call node.

        The text between the parentheses is sliced out of the dumped source of
        the call.

        Args:
            call_node (ast.Call): The call node.

        Returns:
            str, the parameters text without the surrounding parentheses; an
            empty string when no last parameter text was found.

        Raises:
            NodeTypeNotSupport: If call_node is not an ast.Call.
        """
        if not isinstance(call_node, ast.Call):
            raise NodeTypeNotSupport('It is not ast.Call node type.')
        parameters_str = ''
        call_str = pasta.dump(call_node)
        call_name = pasta.dump(call_node.func)
        # Text of the last parameter; the last keyword takes precedence over
        # the last positional argument when both are present.
        last_parameter_str = ''
        if call_node.args:
            last_parameter_str = pasta.dump(call_node.args[-1])
        if call_node.keywords:
            last_parameter_str = pasta.dump(call_node.keywords[-1])
        if last_parameter_str:
            # Position right after the first occurrence of the call name, i.e. the '('.
            left_parenthesis_pos = call_str.find(call_name) + len(call_name)
            # call is like abc.call(a, b,), last parameter is b,
            # but parameters string must have last ',' character after the last parameter b.
            last_parameter_pos = call_str.rfind(last_parameter_str) + len(last_parameter_str)
            right_parenthesis_pos = call_str.find(')', last_parameter_pos)
            # parameters start pos must skip '(' character for calling.
            parameters_str = call_str[left_parenthesis_pos + 1:right_parenthesis_pos]
        return parameters_str
  548. def _get_api_whole_name(self, call_func_node, check_context=True):
  549. """
  550. Get the whole name for the call node.
  551. Args:
  552. call_func_node (AST): The func attribute of ast.Call.
  553. check_context (boolean): If True, the code context will be checked. Default is True.
  554. Returns:
  555. str, the whole name.
  556. """
  557. api_name, match_case = self._infer_api_name(call_func_node, check_context)
  558. if match_case == ApiMatchingEnum.API_STANDARD:
  559. api_name_splits = api_name.split('.')
  560. api_name_splits[0] = self._get_external_ref_whole_name(api_name_splits[0])
  561. if api_name_splits[0]:
  562. api_name = '.'.join(api_name_splits)
  563. return api_name
  564. def mapping_api(self, call_node, check_context=True):
  565. """
  566. Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert.
  567. If do not check context of the script, the code represented by the node must be written in the standard way.
  568. Args:
  569. call_node (ast.Call): The ast node to convert.
  570. check_context (boolean): If True, the code context will be checked. Default is True.
  571. Returns:
  572. str, the converted code.
  573. """
  574. if not isinstance(call_node, ast.Call):
  575. raise NodeTypeNotSupport("It is not ast.Call node.")
  576. code = pasta.dump(call_node)
  577. api_call_name = pasta.dump(call_node.func)
  578. if api_call_name.startswith('self.'):
  579. return code
  580. new_code = self._mapping_api(call_node, check_context)
  581. return new_code
  582. def _mapping_api(self, call_node, check_context=True):
  583. """
  584. Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert.
  585. If do not check context of the script, the code represented by the node must be written in the standard way.
  586. Args:
  587. call_node (ast.Call): The ast node to convert.
  588. check_context (boolean): If True, the code context will be checked. Default is True.
  589. Returns:
  590. str, the converted code.
  591. """
  592. code = pasta.dump(call_node)
  593. api_call_name = pasta.dump(call_node.func)
  594. # find full api expected to be converted. eg:expr="nn.Conv2d(1,2,3)" args_str="(1,2,3)"
  595. args_str = '(' + self._get_call_parameters_str(call_node) + ')'
  596. try:
  597. api_name, _ = self._infer_api_name(call_node.func, check_context)
  598. standard_api_call_name = api_call_name
  599. if api_name != api_call_name:
  600. # api name .view inferred from out.view, split tensor object name is out
  601. tensor_obj_name = api_call_name[:-len(api_name)]
  602. map_helper = ALL_MAPPING[api_name]
  603. new_code = map_helper.convert(tensor_obj_name, args_str)
  604. else:
  605. # change to external ref name
  606. # e.g., mm.ReLU will be changed to nn.ReLU if 'import torch.nn as mm' in script.
  607. if check_context and not self._code_analyzer.is_standard_external_ref:
  608. standard_api_call_name = self._mapping_standard_api_name(api_name)
  609. map_helper = ALL_MAPPING[standard_api_call_name]
  610. new_code = map_helper.convert(standard_api_call_name, args_str)
  611. except KeyError:
  612. return code
  613. return new_code
  614. @staticmethod
  615. def _get_detail_prompt_msg(old_node, new_node):
  616. """Get detail converted prompt information."""
  617. msg = None
  618. if isinstance(old_node, ast.Call) and isinstance(new_node, ast.Call):
  619. old_api_name = pasta.dump(old_node.func)
  620. new_api_name = pasta.dump(new_node.func)
  621. if new_api_name == old_api_name:
  622. old_parameter_num = len(old_node.args) + len(old_node.keywords)
  623. new_parameter_num = len(new_node.args) + len(new_node.keywords)
  624. if old_parameter_num > 1:
  625. msg = 'Parameters are converted.'
  626. else:
  627. if old_parameter_num == 0 and new_parameter_num == 0:
  628. msg = 'The API name is converted to mindspore API'
  629. else:
  630. msg = 'Parameter is converted.'
  631. return msg
  def _convert_call(self, node, matched_api_name):
      """
      Convert a matched call node into the node parsed from its mapped code.

      Args:
          node (ast.Call): The call node to convert.
          matched_api_name (str): The API name matched for this call.

      Returns:
          ast.AST or None, the node parsed from the converted code. None when
          the matched name is not in ALL_MAPPING or when the mapped code is
          identical to the original code.
      """
      new_node = None
      code = pasta.dump(node)
      api_name = pasta.dump(node.func)
      warning_info = get_prompt_info(matched_api_name)
      if warning_info is None:
          warning_info = ''

      if matched_api_name in ALL_MAPPING:
          logger.info("Line %3d start converting API: %s", node.lineno, api_name)
          new_code = self.mapping_api(node)
          if new_code != code:
              try:
                  new_node = pasta.parse(new_code).body[0].value
                  # The API name is the text before the first '('. partition()
                  # keeps the whole string when there is no '(', whereas
                  # slicing with str.find() (-1 on a miss) would silently drop
                  # the last character of the name.
                  new_api_name = new_code.partition('(')[0]
                  detail_msg = self._get_detail_prompt_msg(node, new_node)
                  if detail_msg:
                      warning_info = detail_msg + ' ' + warning_info
              except AttributeError:
                  # The parsed statement has no `.value`; keep the statement
                  # itself and report the whole converted code as the name.
                  new_node = pasta.parse(new_code).body[0]
                  new_api_name = new_code
              self._process_log.info(node.lineno, node.col_offset,
                                     LOG_FMT_CONVERT_WITH_TIPS % (api_name, new_api_name, warning_info))
      else:
          logger.warning("Line %3d: found unsupported API: %s%s", node.lineno, api_name, warning_info)
          self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, warning_info))

      return new_node
  def visit_Call(self, node):
      """
      Visit a call node and replace it in the tree when a conversion applies.

      Args:
          node (ast.Call): The call node being visited.

      Returns:
          None. The tree is edited in place through the parent node.
      """
      code = pasta.dump(node)
      api_name = pasta.dump(node.func)

      # When an ancestor has already been replaced and its code starts with this
      # call's name, this node is part of the replacement and must be skipped.
      # For example, out.view(out.size(0), -1) is first converted to
      # P.Reshape()(out, (out.size(0), -1)); the inner P.Reshape() must not be
      # converted again. Ancestors are scanned from the penultimate stack
      # element in reverse order.
      for parent_node in self._stack[-2::-1]:
          if parent_node in self._new_call_nodes and pasta.dump(parent_node).startswith(api_name):
              return

      parent = self._stack[-2]
      new_node = None
      new_code = code
      matched_api_name, match_case = self.match_api(node.func, self._is_forward_function)
      if match_case in (ApiMatchingEnum.API_INFER, ApiMatchingEnum.API_MATCHED):
          new_node = self._convert_call(node, matched_api_name)
      elif match_case in (ApiMatchingEnum.API_STANDARD, ApiMatchingEnum.API_FOUND):
          self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, ''))

      if parent and new_node:
          update_line_col = _LineColEditVisitor()
          update_line_col.update(new_node, node)
          pasta.ast_utils.replace_child(parent, node, new_node)
          self._new_call_nodes.append(new_node)

          # Remember the replacement's source so the error log below reports
          # the converted code rather than repeating the original.
          new_code = pasta.dump(new_node)

          # Continue the traversal from the replacement node.
          node = new_node
          self._stack[-1] = node

      try:
          self.generic_visit(node)
      except Exception:
          logger.error('original code:%s, new code:%s', code, new_code, exc_info=True)
          raise
  693. def _mapping_standard_external_ref(self):
  694. """Obtain the mapping dict of mapping the external references to standard external references."""
  695. renames = {}
  696. external_refs = self._code_analyzer.external_references
  697. for ref_name, ref_info in external_refs.items():
  698. external_ref_info = ref_info['external_ref_info']
  699. if ref_name != 'nn' and external_ref_info.name == 'torch.nn':
  700. renames[ref_name] = 'nn'
  701. elif ref_name != 'F' and external_ref_info.name == 'torch.nn.functional':
  702. renames[ref_name] = 'F'
  703. return renames
  704. def _get_external_ref_whole_name(self, ref_name):
  705. """
  706. Find out external reference whole name.
  707. For example:
  708. In the parsed source code, there is import statement
  709. import torch.nn as new_name
  710. _get_external_ref_whole_name('new_name') will return 'torch.nn' string.
  711. """
  712. external_refs = self._code_analyzer.external_references
  713. for external_ref_name, ref_info in external_refs.items():
  714. external_ref_info = ref_info['external_ref_info']
  715. if external_ref_name == ref_name:
  716. return external_ref_info.name
  717. return None
  718. def _check_isinstance_parameter(self, node):
  719. """Check whether the second parameter of isinstance function contains the torch type."""
  720. is_isinstance_arg = False
  721. # Check whether node is the second parameter of the isinstance function call.
  722. # Access from the penultimate element in reverse order.
  723. for parent_node in self._stack[-2::-1]:
  724. if isinstance(parent_node, ast.Call) and pasta.dump(parent_node.func) == 'isinstance':
  725. isinstance_node = parent_node
  726. seconde_arg_type_nodes = []
  727. if isinstance(isinstance_node.args[1], ast.Tuple):
  728. seconde_arg_type_nodes.extend(isinstance_node.args[1].elts)
  729. else:
  730. seconde_arg_type_nodes.append(isinstance_node.args[1])
  731. if node in seconde_arg_type_nodes:
  732. is_isinstance_arg = True
  733. break
  734. if not is_isinstance_arg:
  735. return False
  736. isinstance_type_arg = pasta.dump(node)
  737. check_torch_type = False
  738. if isinstance_type_arg:
  739. type_splits = isinstance_type_arg.split('.')
  740. whole_name = self._get_external_ref_whole_name(type_splits[0])
  741. if whole_name and whole_name.startswith('torch'):
  742. check_torch_type = True
  743. if check_torch_type:
  744. _, match_case = self.match_api(node, False)
  745. if match_case != ApiMatchingEnum.NOT_API:
  746. warn_info = 'Manually determine the conversion type.'
  747. self._process_log.warning(node.lineno, node.col_offset,
  748. LOG_FMT_NOT_CONVERT % (isinstance_type_arg, warn_info))
  749. return check_torch_type
  def visit_Attribute(self, node):
      """Check isinstance usage for the attribute node, then visit its children."""
      # The return value of the check is not used here; the call is made
      # before the traversal continues into the children.
      self._check_isinstance_parameter(node)
      self.generic_visit(node)