You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

ast_edits.py 31 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless REQUIRED by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Convert for Python scripts according API mapping information."""
  16. import ast
  17. import logging
  18. import re
  19. from enum import Enum
  20. import pasta
  21. from pasta.augment import import_utils
  22. from mindinsight.mindconverter.code_analysis import CodeAnalyzer
  23. from mindinsight.mindconverter.code_analysis import APIAnalysisSpec
  24. from mindinsight.mindconverter.config import ALL_MAPPING
  25. from mindinsight.mindconverter.config import NN_LIST
  26. from mindinsight.mindconverter.config import ALL_TORCH_APIS
  27. from mindinsight.mindconverter.config import ALL_2P_LIST
  28. from mindinsight.mindconverter.config import get_prompt_info
  29. from mindinsight.mindconverter.common.log import logger
  30. from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport
  31. from mindinsight.mindconverter.forward_call import ForwardCall
  32. LOG_FMT_CONVERT = "[Convert] '%s' is converted to '%s'."
  33. LOG_FMT_CONVERT_WITH_TIPS = "[Convert] '%s' is converted to '%s'. %s"
  34. LOG_FMT_NOT_CONVERT = "[UnConvert] '%s' didn't convert. %s"
  35. LOG_FMT_PROMPT_INFO = "[INFO] %s"
  36. LOG_SUGGESTION_MANUAL_CONVERT = "Please manual convert the code, along with the code associated with it."
  37. class ApiMatchingEnum(Enum):
  38. """Node edge type enum."""
  39. NOT_API = 'not an api name'
  40. API_INFER = 'infer api name to map'
  41. API_STANDARD = 'api name in the correct format'
  42. API_FOUND = 'found an api name in api list'
  43. API_MATCHED = 'api is matched to map'
  44. class _ConvertReport:
  45. """Report log of converting source code."""
  46. def __init__(self, is_stub=False):
  47. self._is_stub = is_stub
  48. self._max_line = 0
  49. self._log = [] # report log, type is (severity, line, col, msg)
  50. def _add_log(self, severity, line, col, msg):
  51. """Add log."""
  52. if self._is_stub:
  53. return
  54. if isinstance(line, int) and isinstance(col, int):
  55. self._log.append((severity, line, col, msg))
  56. if self._max_line < line:
  57. self._max_line = line
  58. def info(self, line, col, msg):
  59. """Interface to add infer log"""
  60. self._add_log(logging.INFO, line, col, msg)
  61. def warning(self, line, col, msg):
  62. """Interface to add warning log"""
  63. self._add_log(logging.WARNING, line, col, msg)
  64. def get_logs(self):
  65. """Get convert logs"""
  66. logs = []
  67. # sort rule: line * self._max_line + col
  68. self._log.sort(key=lambda log: log[1] * self._max_line + log[2])
  69. for log_info in self._log:
  70. log_info = "line %d:%d: %s" % (log_info[1], log_info[2], log_info[3])
  71. logs.append(log_info)
  72. return logs
  73. class _LineColEditVisitor(ast.NodeVisitor):
  74. """
  75. Update line number and col offset of ast node.
  76. Use the line and column number of the original code to update
  77. the line and column number of the new code replaced with the original code.
  78. """
  79. class _NodeInfo:
  80. """NodeInfo class definition."""
  81. def __init__(self, node):
  82. self.node = node
  83. self.call_list = [] # Used to save all ast.Call node in self._node
  84. def __init__(self):
  85. self._dst_node_info = None
  86. self._src_node_info = None
  87. self._visiting = self._src_node_info # Used to point to the visiting node
  88. def update(self, replace_with_node, src_node):
  89. """Update the line and column number of the new code replaced with the original code."""
  90. replace_with_node.lineno = src_node.lineno
  91. replace_with_node.col_offset = src_node.col_offset
  92. self._dst_node_info = self._NodeInfo(replace_with_node)
  93. self._src_node_info = self._NodeInfo(src_node)
  94. self._visiting = self._src_node_info
  95. self.visit(self._visiting.node)
  96. self._visiting = self._dst_node_info
  97. self.visit(self._visiting.node)
  98. self._update_line_col()
  99. def visit_Call(self, node):
  100. """Callback function when visit AST tree"""
  101. self._visiting.call_list.append(node)
  102. self.generic_visit(node)
  103. def _update_line_col(self):
  104. """Update the line and column number information for all ast.Call node."""
  105. dst_call_list = list(self._dst_node_info.call_list)
  106. src_call_list = list(self._src_node_info.call_list)
  107. len_diff = len(dst_call_list) - len(src_call_list)
  108. # After MindSpore api replaces Torch api, more calls are generated.
  109. # For example, out.view() is replaced with P.Reshape()(out).
  110. # out.view() has only one call, but P.Reshape()(out) has two calls.
  111. # To match the replaced calls, the calls of out.view is padded to the same quantity.
  112. if len_diff > 0:
  113. src_call_list = [src_call_list[0]] * len_diff + src_call_list
  114. for dst_call, src_call in zip(dst_call_list, src_call_list):
  115. dst_call.lineno = src_call.lineno
  116. dst_call.col_offset = src_call.col_offset
  117. if not dst_call.args:
  118. continue
  119. # When out.size().view(1, ...) transforms to P.Reshape()(out.size(), 1, ...),
  120. # in this case, the column of parameter out.size() will be bigger than the following parameters.
  121. # To ensure the sequence of parameters, adjust the column of the second parameter.
  122. args = []
  123. for arg in dst_call.args:
  124. if self._check_arg2update(arg):
  125. args.append(arg)
  126. for arg in args:
  127. arg.lineno = dst_call.lineno
  128. arg.col_offset += dst_call.col_offset
  129. @staticmethod
  130. def _check_arg2update(arg):
  131. # Only the col_offset of the first line code is re-counted, needs to be corrected.
  132. # When the arg is a function call, its col_offset is handled separately.
  133. if not isinstance(arg, ast.Call) and arg.lineno == 1:
  134. return True
  135. return False
  136. class AstEditVisitor(ast.NodeVisitor):
  137. """AST Visitor that process function calls.
  138. Converts function calls from torch api to MindSpore api using api mapping information.
  139. """
  140. def __init__(self):
  141. self._process_log = _ConvertReport()
  142. self._tree = None
  143. self._code_analyzer = None
  144. self._stack = [] # Used to easily access the parent node
  145. self._forward_list = {}
  146. self._is_forward_function = False # Used to allow access the visiting function forward attribute
  147. self._new_call_nodes = [] # Used to save new ast.call nodes
  148. def process(self, ast_tree):
  149. """
  150. Convert source code to MindSpore code.
  151. Args:
  152. ast_tree (AST): The root node of the source code.
  153. """
  154. self.__init__()
  155. self._tree = ast_tree
  156. self._code_analyzer = CodeAnalyzer()
  157. self._code_analyzer.process(self._tree)
  158. self._forward_list = ForwardCall(self._tree).calls
  159. # replace python function under nn.Module
  160. self._convert_api()
  161. # replace external reference statements
  162. self._convert_external_reference()
  163. def get_logs(self):
  164. """Get conversion report."""
  165. return self._process_log.get_logs()
  166. def _convert_cell(self, cell_scope):
  167. """
  168. Convert a PyTorch Module class into MindSpore Cell class.
  169. Args:
  170. cell_scope (pasta.base.Scope): The network class definition node inherits from torch.nn.Module.
  171. """
  172. cell_ast_node = cell_scope.node
  173. line_no = cell_ast_node.lineno
  174. logger.info("Line %3d: start converting nn.Module %s", line_no, self._code_analyzer.get_name(cell_ast_node))
  175. class_elements = self._code_analyzer.network_definitions()['cell']
  176. # step1. update function definition
  177. for func_scope in class_elements.get(cell_scope, []):
  178. self._update_function_def(func_scope)
  179. # step2. update base name of class
  180. self._update_base_name(cell_scope)
  181. def _update_base_name(self, class_def_scope):
  182. """
  183. Update base name of class.
  184. Args:
  185. class_def_scope (ast.ClassDef): Class definition node.
  186. """
  187. base_name_mapping = APIAnalysisSpec.base_name_mapping
  188. class_def_node = class_def_scope.node
  189. base_class_nodes = class_def_scope.node.bases
  190. # update base class name
  191. for base_class_node in base_class_nodes:
  192. base_name = base_class_node.attr
  193. if base_name in APIAnalysisSpec.get_network_base_class_names():
  194. old_code = pasta.dump(base_class_node)
  195. if base_name in base_name_mapping:
  196. new_code = 'nn.' + base_name_mapping[base_class_node.attr]
  197. new_node = pasta.parse(new_code)
  198. pasta.ast_utils.replace_child(class_def_node, base_class_node, new_node)
  199. self._process_log.info(base_class_node.lineno, base_class_node.col_offset, LOG_FMT_CONVERT %
  200. (old_code, new_code))
  201. else:
  202. self._process_log.info(base_class_node.lineno, base_class_node.col_offset, LOG_FMT_NOT_CONVERT %
  203. (old_code, ''))
  204. def _update_function_def(self, func_scope):
  205. """
  206. Convert a PyTorch function into MindSpore function.
  207. Args:
  208. func_scope (pasta.base.scope.Scope): The node scope of function definition.
  209. """
  210. is_forward = self._judge_forward(func_scope)
  211. # step1. convert the content of the function.
  212. self._convert_function(func_scope, is_forward)
  213. # step2. replace function name if name is forward
  214. func_ast_node = func_scope.node
  215. old_func_name = 'forward'
  216. new_func_name = 'construct'
  217. if func_ast_node.name == old_func_name:
  218. func_ast_node.name = new_func_name
  219. self._process_log.info(func_ast_node.lineno, func_ast_node.col_offset,
  220. LOG_FMT_CONVERT % (old_func_name, new_func_name))
  221. def _convert_api(self):
  222. """Convert PyTorch api call to MindSpore api call in a function."""
  223. tasks = []
  224. convert_elements = self._code_analyzer.network_definitions()
  225. for func_node_scope in convert_elements.get("functions", []):
  226. is_forward = self._judge_forward(func_node_scope)
  227. tasks.append((self._convert_function, (func_node_scope, is_forward)))
  228. for class_scope in convert_elements.get("cell", []).keys():
  229. tasks.append((self._convert_cell, (class_scope,)))
  230. for convert_fun, args in tasks:
  231. convert_fun(*args)
  232. def _convert_external_reference(self):
  233. """Convert import statements."""
  234. name_replace = APIAnalysisSpec.import_name_mapping
  235. replace_imports = list(name_replace.values())
  236. for ref_info in self._code_analyzer.external_references.values():
  237. external_ref_info = ref_info['external_ref_info']
  238. parent_node = ref_info['parent_node']
  239. if parent_node is None:
  240. continue
  241. code = pasta.dump(parent_node)
  242. if external_ref_info.name in APIAnalysisSpec.get_convertible_external_names():
  243. external_ref_info = ref_info['external_ref_info']
  244. if external_ref_info.name in name_replace.keys():
  245. import_utils.remove_import_alias_node(self._code_analyzer.root_scope, external_ref_info.node)
  246. replace_info = name_replace[external_ref_info.name]
  247. new_ref_name = replace_info[1]
  248. new_external_name = replace_info[0]
  249. if new_ref_name:
  250. new_code = f'import {new_external_name} as {new_ref_name}'
  251. else:
  252. new_code = f'import {new_external_name}'
  253. self._process_log.info(parent_node.lineno, parent_node.col_offset, LOG_FMT_CONVERT %
  254. (code.strip(), new_code.strip()))
  255. elif external_ref_info.name.startswith('torch.'):
  256. self._process_log.warning(parent_node.lineno, parent_node.col_offset, LOG_FMT_NOT_CONVERT %
  257. (code.strip(), LOG_SUGGESTION_MANUAL_CONVERT))
  258. else:
  259. pass
  260. # Insert import in reverse order, display in forward order.
  261. for idx in range(len(replace_imports) - 1, -1, -1):
  262. replace_import = replace_imports[idx]
  263. if replace_import[1]:
  264. self._add_import(name_to_import=replace_import[0], as_name=replace_import[1])
  265. else:
  266. self._add_import(name_to_import=replace_import[0])
  267. def _add_import(self, name_to_import, as_name=None):
  268. """
  269. Adds an import to the ast tree.
  270. Args:
  271. name_to_import: (string) The absolute name to import.
  272. as_name: (string) The alias for the import ("import name_to_import as asname")
  273. """
  274. new_alias = ast.alias(name=name_to_import, asname=as_name)
  275. import_node = ast.Import(names=[new_alias])
  276. # Insert the node at the top of the module
  277. self._tree.body.insert(1 if pasta.base.ast_utils.has_docstring(self._tree) else 0, import_node)
  278. def _convert_function(self, func_scope, is_forward):
  279. """
  280. Convert a PyTorch function into MindSpore function.
  281. Args:
  282. func_scope (pasta.base.scope.Scope): The node scope of function definition.
  283. is_forward (boolean): If the function is defined in forward function in nn.Module in torch.
  284. """
  285. func_ast_node = func_scope.node
  286. line_no = func_ast_node.lineno
  287. logger.info("Line %3d: start converting function %s()", line_no, func_ast_node.name)
  288. parent = func_scope.parent_scope.node
  289. self._stack.clear()
  290. self._new_call_nodes.clear()
  291. if parent:
  292. self._stack.append(parent)
  293. self._is_forward_function = is_forward
  294. self.visit(func_scope.node)
  295. def _judge_forward(self, func_scope):
  296. """
  297. Check if function is a forward function.
  298. Args:
  299. func_scope (pasta.base.scope.Scope): The node scope of function definition.
  300. Returns:
  301. boolean, True or False
  302. """
  303. is_forward = func_scope.node in self._forward_list.values()
  304. if is_forward:
  305. logger.debug("%s is a forward function", self._code_analyzer.get_name(func_scope))
  306. return is_forward
  307. # Overridden to maintain stack information to access parent node
  308. def visit(self, node):
  309. """Visit a ast tree."""
  310. self._stack.append(node)
  311. super(AstEditVisitor, self).visit(node)
  312. self._stack.pop()
  313. def _mapping_standard_api_name(self, api_name):
  314. """Get mapping from external reference name to standard external reference name"""
  315. standard_name = api_name
  316. if not self._code_analyzer.is_standard_external_ref:
  317. # key is real ref name, value is standard ref name.
  318. mapping_names = self._mapping_standard_external_ref()
  319. api_name_parts = api_name.split('.')
  320. api_name_parts[0] = mapping_names.get(api_name_parts[0], api_name_parts[0])
  321. standard_name = '.'.join(api_name_parts)
  322. return standard_name
  323. def _infer_api_name(self, call_func_node, check_context=True):
  324. """Infer the call name.
  325. Examples:
  326. 1. nn.Sequential inferred to nn.Sequential
  327. 2. mmm.size inferred to .size if import torch.nn as nn
  328. 3. mmm.size inferred to mmm.size if import torch.nn as mmm
  329. """
  330. match_case = ApiMatchingEnum.NOT_API
  331. api_name = None
  332. call_name = pasta.dump(call_func_node)
  333. is_include_sub_call = self._is_include_sub_call(call_func_node)
  334. if is_include_sub_call:
  335. name_attributes = call_name.rsplit('.', 1)
  336. else:
  337. name_attributes = call_name.split('.')
  338. # rewritten external module name
  339. # e.g., mm.ReLU will be written to nn.ReLU if 'import torch.nn as mm' in script.
  340. if check_context and not self._code_analyzer.is_standard_external_ref:
  341. standard_name = self._mapping_standard_api_name(name_attributes[0])
  342. else:
  343. standard_name = name_attributes[0]
  344. if standard_name in ["nn", "F", "torch"]:
  345. match_case = ApiMatchingEnum.API_STANDARD
  346. api_name = call_name
  347. else:
  348. # only infer function for tensor object.
  349. # e.g., api_call_name is out.view, .view is an api name for out which is maybe a tensor object.
  350. # e.g., 'xxxx'.size can be not inferred to .size, because string is not a tensor object.
  351. first_name = standard_name.split('.')[0]
  352. if not re.search(r'\W', first_name) and len(name_attributes) > 1:
  353. api_name = '.' + name_attributes[-1]
  354. match_case = ApiMatchingEnum.API_INFER
  355. return api_name, match_case
  356. @staticmethod
  357. def _is_include_sub_call(call_func_node):
  358. """"Inspect a sub call in call expression.
  359. Examples:
  360. 1. nn.functional.relu() return False
  361. 2. nn.functional.relu(out).size() return True. nn.functional.relu(out) is sub call.
  362. 3. nn.functional.relu(out=out.size()).size() return False. out.size() is not sub call of argument.
  363. """
  364. is_include_call = False
  365. try:
  366. sub_node = call_func_node
  367. while sub_node and not isinstance(sub_node, ast.Call):
  368. sub_node = sub_node.value
  369. if isinstance(sub_node, ast.Call):
  370. is_include_call = True
  371. except AttributeError:
  372. is_include_call = False
  373. return is_include_call
  374. def match_api(self, call_func_node, is_forward, check_context=True):
  375. """
  376. Check api name to convert, check api name ok with a is_forward condition.
  377. Args:
  378. call_func_node (ast.Attribute): The call.func node.
  379. is_forward (bool): whether api belong to forward.
  380. check_context (boolean): If True, the code context will be checked. Default is True.
  381. Returns:
  382. str, the standard api name used to match.
  383. ApiMappingEnum, the match result.
  384. """
  385. match_case = ApiMatchingEnum.NOT_API
  386. api_call_name = pasta.dump(call_func_node)
  387. if api_call_name.startswith('self.'):
  388. return api_call_name, match_case
  389. api_name, match_case = self._infer_api_name(call_func_node, check_context)
  390. api_call_name = pasta.dump(call_func_node)
  391. is_tensor_obj_call = False
  392. if api_name != api_call_name:
  393. is_tensor_obj_call = True
  394. standard_api_call_name = api_name
  395. # rewritten external module name
  396. # e.g., mm.ReLU will be written to nn.ReLU if 'import torch.nn as mm' in script.
  397. if not is_tensor_obj_call:
  398. standard_api_call_name = self._get_api_whole_name(call_func_node, check_context)
  399. if standard_api_call_name in ALL_TORCH_APIS:
  400. match_case = ApiMatchingEnum.API_FOUND
  401. if (not is_forward and standard_api_call_name in NN_LIST) or \
  402. (is_forward and standard_api_call_name in ALL_2P_LIST):
  403. match_case = ApiMatchingEnum.API_MATCHED
  404. else:
  405. if standard_api_call_name and standard_api_call_name.startswith('torch.nn.init'):
  406. match_case = ApiMatchingEnum.API_MATCHED
  407. return standard_api_call_name, match_case
  408. @staticmethod
  409. def _get_call_parameters_str(call_node):
  410. """Get parameters string for a call node."""
  411. if not isinstance(call_node, ast.Call):
  412. raise NodeTypeNotSupport('It is not ast.Call node type.')
  413. parameters_str = ''
  414. call_str = pasta.dump(call_node)
  415. call_name = pasta.dump(call_node.func)
  416. last_parameter_str = ''
  417. if call_node.args:
  418. last_parameter_str = pasta.dump(call_node.args[-1])
  419. if call_node.keywords:
  420. last_parameter_str = pasta.dump(call_node.keywords[-1])
  421. if last_parameter_str:
  422. left_parenthesis_pos = call_str.find(call_name) + len(call_name)
  423. # call is like abc.call(a, b,), last parameter is b,
  424. # but parameters string must have last ',' character after the last parameter b.
  425. last_parameter_pos = call_str.rfind(last_parameter_str) + len(last_parameter_str)
  426. right_parenthesis_pos = call_str.find(')', last_parameter_pos)
  427. # parameters start pos must skip '(' character for calling.
  428. parameters_str = call_str[left_parenthesis_pos + 1:right_parenthesis_pos]
  429. return parameters_str
  430. def _get_api_whole_name(self, call_func_node, check_context=True):
  431. """
  432. Get the whole name for the call node.
  433. Args:
  434. call_func_node (AST): The func attribute of ast.Call.
  435. check_context (boolean): If True, the code context will be checked. Default is True.
  436. Returns:
  437. str, the whole name.
  438. """
  439. api_name, match_case = self._infer_api_name(call_func_node, check_context)
  440. if match_case == ApiMatchingEnum.API_STANDARD:
  441. api_name_splits = api_name.split('.')
  442. api_name_splits[0] = self._get_external_ref_whole_name(api_name_splits[0])
  443. if api_name_splits[0]:
  444. api_name = '.'.join(api_name_splits)
  445. return api_name
  446. def mapping_api(self, call_node, check_context=True):
  447. """
  448. Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert.
  449. If do not check context of the script, the code represented by the node must be written in the standard way.
  450. Args:
  451. call_node (ast.Call): The ast node to convert.
  452. check_context (boolean): If True, the code context will be checked. Default is True.
  453. Returns:
  454. str, the converted code.
  455. """
  456. if not isinstance(call_node, ast.Call):
  457. raise NodeTypeNotSupport("It is not ast.Call node.")
  458. code = pasta.dump(call_node)
  459. api_call_name = pasta.dump(call_node.func)
  460. if api_call_name.startswith('self.'):
  461. return code
  462. new_code = self._mapping_api(call_node, check_context)
  463. return new_code
  464. def _mapping_api(self, call_node, check_context=True):
  465. """
  466. Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert.
  467. If do not check context of the script, the code represented by the node must be written in the standard way.
  468. Args:
  469. call_node (ast.Call): The ast node to convert.
  470. check_context (boolean): If True, the code context will be checked. Default is True.
  471. Returns:
  472. str, the converted code.
  473. """
  474. code = pasta.dump(call_node)
  475. api_call_name = pasta.dump(call_node.func)
  476. # find full api expected to be converted. eg:expr="nn.Conv2d(1,2,3)" args_str="(1,2,3)"
  477. args_str = '(' + self._get_call_parameters_str(call_node) + ')'
  478. try:
  479. api_name, _ = self._infer_api_name(call_node.func, check_context)
  480. standard_api_call_name = api_call_name
  481. if api_name != api_call_name:
  482. # api name .view inferred from out.view, split tensor object name is out
  483. tensor_obj_name = api_call_name[:-len(api_name)]
  484. map_helper = ALL_MAPPING[api_name]
  485. new_code = map_helper.convert(tensor_obj_name, args_str)
  486. else:
  487. # change to external ref name
  488. # e.g., mm.ReLU will be changed to nn.ReLU if 'import torch.nn as mm' in script.
  489. if check_context and not self._code_analyzer.is_standard_external_ref:
  490. standard_api_call_name = self._mapping_standard_api_name(api_name)
  491. map_helper = ALL_MAPPING[standard_api_call_name]
  492. new_code = map_helper.convert(standard_api_call_name, args_str)
  493. except KeyError:
  494. return code
  495. return new_code
    def visit_Call(self, node):
        """Callback function when visit AST tree.

        Matches the call's api and takes one of these actions:
        - If the api is in ALL_MAPPING, converts the call and logs the conversion.
        - If it is matched or inferred but has no mapping, logs it as unsupported.
        - If it is standard or found but not matched, logs it as not converted.
        When a replacement node was built, it is spliced into the parent and visited instead of
        the original node.
        """
        code = pasta.dump(node)
        api_name = pasta.dump(node.func)
        # The parent node first call is equal to this node, skip when parent node is replaced.
        # This scenario occurs, for example, when out.view(out.size(0), -1) is first converted to
        # P.Reshape()(out, (out.size(0), -1)), will skip P.Reshape() in following visiting.
        # Access from the penultimate element in reverse order.
        for parent_node in self._stack[-2::-1]:
            if parent_node in self._new_call_nodes and pasta.dump(parent_node).startswith(api_name):
                return
        # self._stack[-1] is this node (pushed by visit()), so [-2] is its parent.
        parent = self._stack[-2]
        new_node = None
        new_code = code
        matched_api_name, match_case = self.match_api(node.func, self._is_forward_function)
        if match_case in [ApiMatchingEnum.API_INFER, ApiMatchingEnum.API_MATCHED]:
            warning_info = get_prompt_info(matched_api_name)
            if warning_info is None:
                warning_info = ''
            if matched_api_name in ALL_MAPPING:
                logger.info("Line %3d start converting API: %s", node.lineno, api_name)
                new_code = self.mapping_api(node)
                if new_code != code:
                    try:
                        # Expression statement: take its value as the replacement node.
                        new_node = pasta.parse(new_code).body[0].value
                        # find the first call name
                        new_api_name = new_code[:new_code.find('(')]
                    except AttributeError:
                        # The first statement has no .value: use the statement itself.
                        new_node = pasta.parse(new_code).body[0]
                        new_api_name = new_code
                    self._process_log.info(node.lineno, node.col_offset,
                                           LOG_FMT_CONVERT_WITH_TIPS % (api_name, new_api_name, warning_info))
            else:
                logger.warning("Line %3d: found unsupported API: %s%s", node.lineno, api_name, warning_info)
                self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, warning_info))
        elif match_case in [ApiMatchingEnum.API_STANDARD, ApiMatchingEnum.API_FOUND]:
            self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, ''))
        else:
            pass
        if parent and new_node:
            # Copy positions from the original node so report entries for the new calls keep the source lines.
            update_line_col = _LineColEditVisitor()
            update_line_col.update(new_node, node)
            pasta.ast_utils.replace_child(parent, node, new_node)
            # Remembered so the calls nested at the head of the new node are skipped (see the check above).
            self._new_call_nodes.append(new_node)
            node = new_node
            self._stack[-1] = node
        try:
            self.generic_visit(node)
        except Exception:
            logger.error('original code:%s, new code:%s', code, new_code, exc_info=True)
            raise
  547. def _mapping_standard_external_ref(self):
  548. """Obtain the mapping dict of mapping the external references to standard external references."""
  549. renames = {}
  550. external_refs = self._code_analyzer.external_references
  551. for ref_name, ref_info in external_refs.items():
  552. external_ref_info = ref_info['external_ref_info']
  553. if ref_name != 'nn' and external_ref_info.name == 'torch.nn':
  554. renames[ref_name] = 'nn'
  555. elif ref_name != 'F' and external_ref_info.name == 'torch.nn.functional':
  556. renames[ref_name] = 'F'
  557. return renames
  558. def _get_external_ref_whole_name(self, ref_name):
  559. """
  560. Find out external reference whole name.
  561. For example:
  562. In the parsed source code, there is import statement
  563. import torch.nn as new_name
  564. _get_external_ref_whole_name('new_name') will return 'torch.nn' string.
  565. """
  566. external_refs = self._code_analyzer.external_references
  567. for external_ref_name, ref_info in external_refs.items():
  568. external_ref_info = ref_info['external_ref_info']
  569. if external_ref_name == ref_name:
  570. return external_ref_info.name
  571. return None
  572. def _check_isinstance_parameter(self, node):
  573. """Check whether the second parameter of isinstance function contains the torch type."""
  574. is_isinstance_arg = False
  575. # Check whether node is the second parameter of the isinstance function call.
  576. # Access from the penultimate element in reverse order.
  577. for parent_node in self._stack[-2::-1]:
  578. if isinstance(parent_node, ast.Call) and pasta.dump(parent_node.func) == 'isinstance':
  579. isinstance_node = parent_node
  580. seconde_arg_type_nodes = []
  581. if isinstance(isinstance_node.args[1], ast.Tuple):
  582. seconde_arg_type_nodes.extend(isinstance_node.args[1].elts)
  583. else:
  584. seconde_arg_type_nodes.append(isinstance_node.args[1])
  585. if node in seconde_arg_type_nodes:
  586. is_isinstance_arg = True
  587. break
  588. if not is_isinstance_arg:
  589. return False
  590. isinstance_type_arg = pasta.dump(node)
  591. check_torch_type = False
  592. if isinstance_type_arg:
  593. type_splits = isinstance_type_arg.split('.')
  594. whole_name = self._get_external_ref_whole_name(type_splits[0])
  595. if whole_name and whole_name.startswith('torch'):
  596. check_torch_type = True
  597. if check_torch_type:
  598. _, match_case = self.match_api(node, False)
  599. if match_case != ApiMatchingEnum.NOT_API:
  600. warn_info = 'Manually determine the conversion type.'
  601. self._process_log.warning(node.lineno, node.col_offset,
  602. LOG_FMT_NOT_CONVERT % (isinstance_type_arg, warn_info))
  603. return check_torch_type
  604. def visit_Attribute(self, node):
  605. """Callback function when visit AST tree"""
  606. self._check_isinstance_parameter(node)
  607. self.generic_visit(node)