You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

converter.py 6.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless REQUIRED by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """converter module"""
  16. import os
  17. import stat
  18. import pasta
  19. from mindinsight.mindconverter.common.exceptions import ScriptNotSupport
  20. from mindinsight.mindconverter.common.log import logger
  21. from mindinsight.mindconverter.ast_edits import AstEditVisitor
  22. class Converter:
  23. """Convert class"""
  24. flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
  25. modes = stat.S_IWUSR | stat.S_IRUSR
  26. def __init__(self):
  27. self._tree = None
  28. self._infile = None
  29. self._code_analyzer = None
  30. self._ast_editor = None
  31. self._report = []
  32. def convert(self, infile, output_dir, report_dir):
  33. """
  34. Convert a module's code, code converted will be save in output_dir, and a report will be save in report_dir.
  35. Args:
  36. infile (str): The script to convert.
  37. output_dir (str): The path to save converted file.
  38. report_dir (str): The path to save report file.
  39. """
  40. in_file_split = _path_split(infile)
  41. in_file_split[-1], _ = _get_name_ext(in_file_split[-1])
  42. module_name = '.'.join(in_file_split)
  43. with open(infile, 'r') as file:
  44. content = ''.join(file.readlines())
  45. self._infile = infile
  46. self._tree = pasta.parse(content)
  47. self._report.clear()
  48. try:
  49. logger.info("Script file is %s", infile)
  50. logger.info("Start converting %s", module_name)
  51. self._report.append('[Start Convert]')
  52. self._ast_editor = AstEditVisitor()
  53. self._ast_editor.process(self._tree)
  54. self._report.extend(self._ast_editor.get_logs())
  55. self._report.append('[Convert Over]')
  56. dest_file = os.path.join(output_dir, os.path.basename(infile))
  57. with os.fdopen(os.open(dest_file, self.flags, self.modes), 'w') as file:
  58. script = pasta.dump(self._tree)
  59. script = adjust_mindspore_import_position(script)
  60. file.write(script)
  61. logger.info("Convert success. Result is wrote to %s.", dest_file)
  62. except ScriptNotSupport as error:
  63. self._report.append('[ScriptNotSupport] ' + error.message)
  64. self._report.append('[Convert failed]')
  65. raise error
  66. except Exception as error:
  67. self._report.clear()
  68. raise error
  69. finally:
  70. if self._report:
  71. dest_report_file = os.path.join(report_dir, f"report_of_{os.path.basename(infile).split('.')[0]}.txt")
  72. with os.fdopen(os.open(dest_report_file, self.flags, self.modes), 'a') as file:
  73. file.write('\n'.join(self._report))
  74. logger.info("Convert report is saved in %s", dest_report_file)
  75. @staticmethod
  76. def convert_api(source_code):
  77. """
  78. Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert.
  79. Args:
  80. source_code (ast.Call): The ast node to convert.
  81. Returns:
  82. str, the converted code.
  83. """
  84. ast_node = pasta.parse(source_code).body[0].value
  85. check_context = False
  86. replaced_code = AstEditVisitor().mapping_api(ast_node, check_context)
  87. return replaced_code
  88. def get_code_start_line_num(source_lines):
  89. """
  90. Get the start code line number exclude comments.
  91. Args:
  92. source_lines (list[str]): Split results of code.
  93. Returns:
  94. int, the start line number.
  95. """
  96. stack = []
  97. index = 0
  98. for i, line in enumerate(source_lines):
  99. line_strip = line.strip()
  100. if line_strip.startswith('#'):
  101. continue
  102. if line_strip.startswith('"""'):
  103. if not line_strip.endswith('"""'):
  104. stack.append('"""')
  105. continue
  106. if line_strip.startswith("'''"):
  107. if not line_strip.endswith("'''"):
  108. stack.append("'''")
  109. continue
  110. if line_strip.endswith('"""') or line_strip.endswith("'''"):
  111. stack.pop()
  112. continue
  113. if line_strip != '' and not stack:
  114. index = i
  115. break
  116. return index
  117. def adjust_mindspore_import_position(script):
  118. """
  119. Adjust code sentence `import mindspore` in script to a proper position if the sentence is set before a comment.
  120. Args:
  121. script (str): code script before adjust.
  122. Returns:
  123. str, code script adjusted.
  124. """
  125. script_list = script.split('\n')
  126. import_ms_sentence = 'import mindspore'
  127. if import_ms_sentence in script_list:
  128. import_index = script_list.index(import_ms_sentence)
  129. if script_list[import_index + 1].startswith('"""') or script_list[import_index + 1].startswith("'''"):
  130. script_list.pop(import_index)
  131. new_index = get_code_start_line_num(script_list)
  132. script_list.insert(new_index, import_ms_sentence)
  133. script = '\n'.join(script_list)
  134. return script
  135. def _get_name_ext(file):
  136. """
  137. Split a file name in name and extension.
  138. Args:
  139. file (str): Full file path.
  140. Returns:
  141. tuple (str, str), name and extension.
  142. """
  143. _, name = os.path.split(file)
  144. return os.path.splitext(name)
  145. def _path_split(file):
  146. """
  147. Split a path in head and tail.
  148. Args:
  149. file (str): The file path.
  150. Returns:
  151. list[str], list of file tail
  152. """
  153. file_dir, name = os.path.split(file)
  154. if file_dir:
  155. sep = file[len(file_dir) - 1]
  156. if file_dir.startswith(sep):
  157. return file.split(sep)[1:]
  158. return file.split(sep)
  159. return [name]
  160. def main(files_config):
  161. """
  162. The entrance for converter, script files will be converted.
  163. Args:
  164. files_config (dict): The config of files which to convert.
  165. """
  166. convert_ins = Converter()
  167. in_files = files_config['in_files']
  168. for in_file in in_files:
  169. convert_ins.convert(in_file, files_config['outfile_dir'], files_config['report_dir'])