You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gen_stubapi.py 22 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573
  1. import os
  2. import re
  3. import sys
  4. import logging
  5. logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] [%(lineno)s] %(levelname)s: %(message)s',
  6. level=logging.INFO)
  7. """
  8. this attr is used for symbol table visible
  9. """
  10. GE_ATTR = 'GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY'
  11. """
  12. generate stub func body by return type
  13. """
  14. RETURN_STATEMENTS = {
  15. 'graphStatus': ' return GRAPH_SUCCESS;',
  16. 'Status': ' return SUCCESS;',
  17. 'Graph': ' return Graph();',
  18. 'Graph&': ' return *this;',
  19. 'Format': ' return Format();',
  20. 'Format&': ' return *this;',
  21. 'Shape': ' return Shape();',
  22. 'Shape&': ' return *this;',
  23. 'TensorDesc': ' return TensorDesc();',
  24. 'TensorDesc&': ' return *this;',
  25. 'Tensor': ' return Tensor();',
  26. 'Tensor&': ' return *this;',
  27. 'Operator': ' return Operator();',
  28. 'Operator&': ' return *this;',
  29. 'Ptr': ' return nullptr;',
  30. 'std::string': ' return "";',
  31. 'std::string&': ' return "";',
  32. 'string': ' return "";',
  33. 'int': ' return 0;',
  34. 'DataType': ' return DT_FLOAT;',
  35. 'InferenceContextPtr': ' return nullptr;',
  36. 'SubgraphBuilder': ' return nullptr;',
  37. 'OperatorImplPtr': ' return nullptr;',
  38. 'OutHandler': ' return nullptr;',
  39. 'std::vector<std::string>': ' return {};',
  40. 'std::vector<int64_t>': ' return {};',
  41. 'std::map': ' return {};',
  42. 'uint32_t': ' return 0;',
  43. 'int64_t': ' return 0;',
  44. 'uint64_t': ' return 0;',
  45. 'size_t': ' return 0;',
  46. 'float': ' return 0.0f;',
  47. 'bool': ' return false;',
  48. }
  49. """
  50. max code len per line in hua_wei software programming specifications
  51. """
  52. max_code_len_per_line = 100
  53. """
  54. white_list_for_debug, include_dir_key_words is to
  55. determines which header files to generate cc files from
  56. when DEBUG on
  57. """
  58. white_list_for_debug = ["operator.h", "tensor.h",
  59. "graph.h", "operator_factory.h",
  60. "ge_ir_build.h"]
  61. include_dir_key_words = ["ge", "graph"]
  62. DEBUG = True
  63. def need_generate_func(func_line):
  64. """
  65. :param func_line:
  66. :return:
  67. """
  68. if func_line.strip().endswith("default") or func_line.strip().endswith("delete") \
  69. or func_line.strip().startswith("typedef") or func_line.strip().startswith("using"):
  70. return False
  71. return True
  72. def file_endswith_white_list_suffix(file):
  73. """
  74. :param file:
  75. :return:
  76. """
  77. if DEBUG:
  78. for suffix in white_list_for_debug:
  79. if file.endswith(suffix):
  80. return True
  81. return False
  82. else:
  83. return True
  84. """
  85. belows are patterns used for analyse .h file
  86. """
# Function declaration (verbose regex, applied to one possibly multi-line
# declaration): group 1 is the leading whitespace, group 2 the declaration up
# to but excluding the terminating ';', group 3 the ';' and whatever follows.
# Declarations with a '{' after the closing ')' are rejected by the lookahead.
pattern_func = re.compile(r"""(^[\s]*) #leading with space,we will find and delete after
([a-zA-Z~_] # void int likely
.*
[)] #we find )
(?!.*{) # we do not want the case int abc() const { return 1;}
.*)
(;.*) #we want to find ; and after for we will replace these later
\n$
""", re.VERBOSE | re.MULTILINE | re.DOTALL)
# Comments: a '//' line, the opening '/*' of a block comment at the start of a
# line, and the closing '*/' at the end of a line.
pattern_comment = re.compile(r'^\s*//')
pattern_comment_2_start = re.compile(r'^\s*/[*]')
pattern_comment_2_end = re.compile(r'[*]/\s*$')
# Preprocessor '#define' line, and a line ending with a backslash
# (i.e. a definition continued on the next line).
pattern_define = re.compile(r'^\s*#define')
pattern_define_return = re.compile(r'\\\s*$')
# Blank (empty or whitespace-only) line.
pattern_blank_line = re.compile(r'^\s*$')
# Declaration keywords virtual, explicit, friend, static, each together with
# the whitespace that follows it.
pattern_keyword = re.compile(r'(virtual\s+|explicit\s+|friend\s+|static\s+)')
# Leading whitespace (group 1) of a line whose first token starts with a
# letter, '~' or '_'.
pattern_leading_space = re.compile(r'(^[\s]*)[a-zA-Z~_]')
# Function name: functions look like "func (" or "func(".
# Operators are the exception: the text in front of '(' is an operator
# spelling rather than an identifier, e.g. "operator = ()".
pattern_func_name = re.compile(r'([a-zA-Z0-9~_\-]+\s*|operator?.*)[(]')
# Start of a template header, and a line that ends with '>'
# (used to find the end of a multi-line template header).
pattern_template = re.compile(r'^\s*template')
pattern_template_end = re.compile(r'>\s*$')
# Namespace opening: 'namespace' followed by '{' on the same line.
pattern_namespace = re.compile(r'namespace.*{')
# Class or struct definition. The '{' may be on a later line, but a ';' after
# the name on the same line means the line is not treated as a definition.
# Group 3 is the class name, with a trailing '<' when one directly follows;
# the GE_ATTR text is allowed between the keyword and the name.
pattern_class = re.compile(r'^[\s]*(class|struct)\s+(%s\s+)?([a-zA-Z0-9_\-]+<?)(?!.*;)' % GE_ATTR)
# Opening and closing braces anywhere in a line.
pattern_start = re.compile('{')
pattern_end = re.compile('}')
# Module-level line counter initialised to 0; H2CC keeps its own
# per-instance self.line_index.
line_index = 0
  125. class H2CC(object):
  126. def __init__(self, input_file, output_file, shared_includes_content):
  127. """
  128. :param input_file:
  129. :param output_file:
  130. :param shared_includes_content:
  131. """
  132. self.input_file = input_file
  133. self.output_file = output_file
  134. self.shared_includes_content = shared_includes_content
  135. self.line_index = 0
  136. self.input_fd = open(self.input_file, 'r')
  137. self.input_content = self.input_fd.readlines()
  138. self.output_fd = open(self.output_file, 'w')
  139. # The state may be normal_now(in the middle of {}),class_now,namespace_now
  140. self.stack = []
  141. self.stack_class = []
  142. self.stack_template = []
  143. # record funcs generated by h2cc func
  144. self.func_list_exist = []
  145. def __del__(self):
  146. self.input_fd.close()
  147. self.output_fd.close()
  148. del self.stack
  149. del self.stack_class
  150. del self.stack_template
  151. del self.func_list_exist
  152. def just_skip(self):
  153. # skip blank line or comment
  154. if pattern_blank_line.search(self.input_content[self.line_index]) or pattern_comment.search(
  155. self.input_content[self.line_index]): # /n or comment using //
  156. self.line_index += 1
  157. if pattern_comment_2_start.search(self.input_content[self.line_index]): # comment using /*
  158. while not pattern_comment_2_end.search(self.input_content[self.line_index]): # */
  159. self.line_index += 1
  160. self.line_index += 1
  161. # skip define
  162. if pattern_define.search(self.input_content[self.line_index]):
  163. while pattern_blank_line.search(self.input_content[self.line_index]) or pattern_define_return.search(
  164. self.input_content[self.line_index]):
  165. self.line_index += 1
  166. self.line_index += 1
  167. def write_inc_content(self):
  168. for shared_include_content in self.shared_includes_content:
  169. self.output_fd.write(shared_include_content)
  170. def h2cc(self):
  171. """
  172. :return:
  173. """
  174. logging.info("start generate cc_file[%s] from h_file[%s]", self.output_file, self.input_file)
  175. global pattern_comment
  176. global pattern_comment_2_start
  177. global pattern_comment_2_end
  178. global pattern_blank_line
  179. global pattern_func
  180. global pattern_keyword
  181. global pattern_leading_space
  182. global pattern_func_name
  183. global pattern_template
  184. global pattern_template_end
  185. global pattern_namespace
  186. global pattern_class
  187. global pattern_start
  188. global pattern_end
  189. global line_index
  190. # write inc content
  191. self.write_inc_content()
  192. # core processing cycle, process the input .h file by line
  193. while self.line_index < len(self.input_content):
  194. # handle comment and blank line
  195. self.just_skip()
  196. # match namespace
  197. self.handle_namespace()
  198. # match template
  199. template_string = self.handle_template()
  200. # match class
  201. line = self.input_content[self.line_index]
  202. match_class = pattern_class.search(line)
  203. match_start = pattern_start.search(line)
  204. handle_class_result = self.handle_class(template_string, line, match_start, match_class)
  205. if handle_class_result == "continue":
  206. continue
  207. # match "}"
  208. handle_stack_result = self.handle_stack(match_start)
  209. if handle_stack_result == "continue":
  210. continue
  211. # handle func
  212. handle_func1_result, line, start_i = self.handle_func1(line)
  213. if handle_func1_result == "continue":
  214. continue
  215. # here means func is found
  216. # delete key word
  217. line = pattern_keyword.sub('', line)
  218. logging.info("line[%s]", line)
  219. # Class member function
  220. # if friend we will not add class name
  221. friend_match = re.search('friend ', line)
  222. if len(self.stack_class) > 0 and not friend_match:
  223. line, func_name = self.handle_class_member_func(line, template_string)
  224. # Normal functions
  225. else:
  226. line, func_name = self.handle_normal_func(line, template_string)
  227. need_generate = need_generate_func(line)
  228. # func body
  229. line += self.implement_function(line)
  230. # comment
  231. line = self.gen_comment(start_i) + line
  232. # write to out file
  233. self.write_func_content(line, func_name, need_generate)
  234. # next loop
  235. self.line_index += 1
  236. logging.info('Added %s functions', len(self.func_list_exist))
  237. logging.info('Successfully converted,please see ' + self.output_file)
  238. def handle_func1(self, line):
  239. """
  240. :param line:
  241. :return:
  242. """
  243. find1 = re.search('[(]', line)
  244. if not find1:
  245. self.line_index += 1
  246. return "continue", line, None
  247. find2 = re.search('[)]', line)
  248. start_i = self.line_index
  249. space_match = pattern_leading_space.search(line)
  250. # deal with
  251. # int abc(int a,
  252. # int b)
  253. if find1 and (not find2):
  254. self.line_index += 1
  255. line2 = self.input_content[self.line_index]
  256. if space_match:
  257. line2 = re.sub('^' + space_match.group(1), '', line2)
  258. line += line2
  259. while self.line_index < len(self.input_content) and (not re.search('[)]', line2)):
  260. self.line_index += 1
  261. line2 = self.input_content[self.line_index]
  262. line2 = re.sub('^' + space_match.group(1), '', line2)
  263. line += line2
  264. match_start = pattern_start.search(self.input_content[self.line_index])
  265. match_end = pattern_end.search(self.input_content[self.line_index])
  266. if match_start: # like ) { or ) {} int the last line
  267. if not match_end:
  268. self.stack.append('normal_now')
  269. ii = start_i
  270. while ii <= self.line_index:
  271. ii += 1
  272. self.line_index += 1
  273. return "continue", line, start_i
  274. logging.info("line[%s]", line)
  275. # ' int abc();'->'int abc()'
  276. (line, match) = pattern_func.subn(r'\2\n', line)
  277. logging.info("line[%s]", line)
  278. # deal with case:
  279. # 'int \n abc(int a, int b)'
  280. if re.search(r'^\s*(inline)?\s*[a-zA-Z0-9_]+\s*$', self.input_content[start_i - 1]):
  281. line = self.input_content[start_i - 1] + line
  282. line = line.lstrip()
  283. if not match:
  284. self.line_index += 1
  285. return "continue", line, start_i
  286. return "pass", line, start_i
  287. def handle_stack(self, match_start):
  288. """
  289. :param match_start:
  290. :return:
  291. """
  292. line = self.input_content[self.line_index]
  293. match_end = pattern_end.search(line)
  294. if match_start:
  295. self.stack.append('normal_now')
  296. if match_end:
  297. top_status = self.stack.pop()
  298. if top_status == 'namespace_now':
  299. self.output_fd.write(line + '\n')
  300. elif top_status == 'class_now':
  301. self.stack_class.pop()
  302. self.stack_template.pop()
  303. if match_start or match_end:
  304. self.line_index += 1
  305. return "continue"
  306. if len(self.stack) > 0 and self.stack[-1] == 'normal_now':
  307. self.line_index += 1
  308. return "continue"
  309. return "pass"
  310. def handle_class(self, template_string, line, match_start, match_class):
  311. """
  312. :param template_string:
  313. :param line:
  314. :param match_start:
  315. :param match_class:
  316. :return:
  317. """
  318. if match_class: # we face a class
  319. self.stack_template.append(template_string)
  320. self.stack.append('class_now')
  321. class_name = match_class.group(3)
  322. # class template specializations: class A<u,Node<u> >
  323. if '<' in class_name:
  324. k = line.index('<')
  325. fit = 1
  326. for ii in range(k + 1, len(line)):
  327. if line[ii] == '<':
  328. fit += 1
  329. if line[ii] == '>':
  330. fit -= 1
  331. if fit == 0:
  332. break
  333. class_name += line[k + 1:ii + 1]
  334. logging.info('class_name[%s]', class_name)
  335. self.stack_class.append(class_name)
  336. while not match_start:
  337. self.line_index += 1
  338. line = self.input_content[self.line_index]
  339. match_start = pattern_start.search(line)
  340. self.line_index += 1
  341. return "continue"
  342. return "pass"
  343. def handle_template(self):
  344. line = self.input_content[self.line_index]
  345. match_template = pattern_template.search(line)
  346. template_string = ''
  347. if match_template:
  348. match_template_end = pattern_template_end.search(line)
  349. template_string = line
  350. while not match_template_end:
  351. self.line_index += 1
  352. line = self.input_content[self.line_index]
  353. template_string += line
  354. match_template_end = pattern_template_end.search(line)
  355. self.line_index += 1
  356. return template_string
  357. def handle_namespace(self):
  358. line = self.input_content[self.line_index]
  359. match_namespace = pattern_namespace.search(line)
  360. if match_namespace: # we face namespace
  361. self.output_fd.write(line + '\n')
  362. self.stack.append('namespace_now')
  363. self.line_index += 1
  364. def handle_normal_func(self, line, template_string):
  365. template_line = ''
  366. self.stack_template.append(template_string)
  367. if self.stack_template[-1] != '':
  368. template_line = re.sub(r'\s*template', 'template', self.stack_template[-1])
  369. # change '< class T = a, class U = A(3)>' to '<class T, class U>'
  370. template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line)
  371. template_line = re.sub(r'\s*=.*,', ',', template_line)
  372. template_line = re.sub(r'\s*=.*', '', template_line)
  373. line = re.sub(r'\s*=.*,', ',', line)
  374. line = re.sub(r'\s*=.*\)', ')', line)
  375. line = template_line + line
  376. self.stack_template.pop()
  377. func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group()
  378. logging.info("line[%s]", line)
  379. logging.info("func_name[%s]", func_name)
  380. return line, func_name
  381. def handle_class_member_func(self, line, template_string):
  382. template_line = ''
  383. x = ''
  384. if template_string != '':
  385. template_string = re.sub(r'\s*template', 'template', template_string)
  386. template_string = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_string)
  387. template_string = re.sub(r'\s*=.*,', ',', template_string)
  388. template_string = re.sub(r'\s*=.*', '', template_string)
  389. if self.stack_template[-1] != '':
  390. if not (re.search(r'<\s*>', stack_template[-1])):
  391. template_line = re.sub(r'^\s*template', 'template', stack_template[-1])
  392. if not (re.search(r'<.*>', self.stack_class[-1])):
  393. # for x we get like template<class T, typename U> -> <T,U>
  394. x = re.sub(r'template\s*<', '<', template_line) # remove template -> <class T, typename U>
  395. x = re.sub(r'\n', '', x)
  396. x = re.sub(r'\s*=.*,', ',', x)
  397. x = re.sub(r'\s*=.*\>', '>', x)
  398. x = x.rstrip() # remove \n
  399. x = re.sub(r'(class|typename)\s+|(<class>|<typename>\s*class)', '',
  400. x) # remove class,typename -> <T, U>
  401. x = re.sub(r'<\s+', '<', x)
  402. x = re.sub(r'\s+>', '>', x)
  403. x = re.sub(r'\s+,', ',', x)
  404. x = re.sub(r',\s+', ', ', x)
  405. line = re.sub(r'\s*=\s+0', '', line)
  406. line = re.sub(r'\s*=\s+.*,', ',', line)
  407. line = re.sub(r'\s*=\s+.*\)', ')', line)
  408. logging.info("x[%s]\nline[%s]", x, line)
  409. # if the function is long, void ABC::foo()
  410. # breaks into two lines void ABC::\n foo()
  411. temp_line = pattern_func_name.sub(self.stack_class[-1] + x + '::' + r'\1(', line, count=1)
  412. if len(temp_line) > max_code_len_per_line:
  413. line = pattern_func_name.sub(self.stack_class[-1] + x + '::\n' + r'\1(', line, count=1)
  414. else:
  415. line = temp_line
  416. logging.info("line[%s]", line)
  417. # add template as the above if there is one
  418. template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line)
  419. template_line = re.sub(r'\s*=.*,', ',', template_line)
  420. template_line = re.sub(r'\s*=.*', '', template_line)
  421. line = template_line + template_string + line
  422. func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group()
  423. logging.info("line[%s]", line)
  424. logging.info("func_name[%s]", func_name)
  425. return line, func_name
  426. def write_func_content(self, content, func_name, need_generate):
  427. if not (func_name in self.func_list_exist) and need_generate:
  428. self.output_fd.write(content)
  429. self.func_list_exist.append(func_name)
  430. logging.info('add func:[%s]', func_name)
  431. def gen_comment(self, start_i):
  432. comment_line = ''
  433. # Function comments are on top of function declarations, copy them over
  434. k = start_i - 1 # one line before this func start
  435. if pattern_template.search(self.input_content[k]):
  436. k -= 1
  437. if pattern_comment_2_end.search(self.input_content[k]):
  438. comment_line = self.input_content[k].lstrip()
  439. while not pattern_comment_2_start.search(self.input_content[k]):
  440. k -= 1
  441. comment_line = self.input_content[k].lstrip() + comment_line
  442. else:
  443. for j in range(k, 0, -1):
  444. c_line = self.input_content[j]
  445. if pattern_comment.search(c_line):
  446. c_line = re.sub(r'\s*//', '//', c_line)
  447. comment_line = c_line + comment_line
  448. else:
  449. break
  450. return comment_line
  451. @staticmethod
  452. def implement_function(func):
  453. function_def = ''
  454. function_def += '{\n'
  455. all_items = func.split()
  456. start = 0
  457. return_type = all_items[start]
  458. if return_type == "const":
  459. start += 1
  460. return_type = all_items[start]
  461. if return_type.startswith(('std::map', 'std::set', 'std::vector')):
  462. return_type = "std::map"
  463. if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')):
  464. return_type = "Ptr"
  465. if len(all_items) > start + 1 and all_items[start + 1].startswith('&'):
  466. return_type += "&"
  467. if RETURN_STATEMENTS.__contains__(return_type):
  468. function_def += RETURN_STATEMENTS[return_type]
  469. else:
  470. logging.warning("Unhandled return type[%s]", return_type)
  471. function_def += '\n'
  472. function_def += '}\n'
  473. function_def += '\n'
  474. return function_def
  475. def collect_header_files(path):
  476. """
  477. :param path:
  478. :return:
  479. """
  480. header_files = []
  481. shared_includes_content = []
  482. for root, dirs, files in os.walk(path):
  483. files.sort()
  484. for file in files:
  485. if file.find("git") >= 0:
  486. continue
  487. if not file.endswith('.h'):
  488. continue
  489. file_path = os.path.join(root, file)
  490. file_path = file_path.replace('\\', '/')
  491. header_files.append(file_path)
  492. include_str = '#include "{}"\n'.format(file_path[path.rindex('/') + 1:])
  493. shared_includes_content.append(include_str)
  494. return header_files, shared_includes_content
  495. def generate_stub_file(inc_dir, out_cc_dir):
  496. """
  497. :param inc_dir:
  498. :param out_cc_dir:
  499. :return:
  500. """
  501. target_header_files, shared_includes_content = collect_header_files(inc_dir)
  502. for header_file in target_header_files:
  503. if not file_endswith_white_list_suffix(header_file):
  504. continue
  505. cc_file = re.sub('.h*$', '.cc', header_file)
  506. h_2_cc = H2CC(header_file, out_cc_dir + cc_file[cc_file.rindex('/') + 1:], shared_includes_content)
  507. h_2_cc.h2cc()
  508. def gen_code(inc_dir, out_cc_dir):
  509. """
  510. :param inc_dir:
  511. :param out_cc_dir:
  512. :return:
  513. """
  514. if not inc_dir.endswith('/'):
  515. inc_dir += '/'
  516. if not out_cc_dir.endswith('/'):
  517. out_cc_dir += '/'
  518. for include_dir_key_word in include_dir_key_words:
  519. generate_stub_file(inc_dir + include_dir_key_word, out_cc_dir)
  520. if __name__ == '__main__':
  521. inc_dir = sys.argv[1]
  522. out_cc_dir = sys.argv[2]
  523. gen_code(inc_dir, out_cc_dir)

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用而用户并不感知。