
config.py 31 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""API config"""
import ast
from collections import OrderedDict
from functools import partial
import json
import os

import pasta

from mindinsight.mindconverter.enums import RequriedType
from mindinsight.mindconverter.common.log import logger

REQUIRED = RequriedType.REQUIRED.name
UNREQUIRED = RequriedType.UNREQUIRED.name


class APIPt:
    """Base API for args parse, and API for one frame."""

    def __init__(self, name: str, params: OrderedDict):
        self.name = name
        self.params = OrderedDict()

        for k, value in params.items():
            self.params[k] = self.to_str(value)

    @staticmethod
    def to_str(value):
        """
        Trans value to str.

        Args:
            value (Union[str, Number, int]): Each value for params of OrderedDict.

        Returns:
            str, str type of value.
        """
        if value is REQUIRED:
            return value
        if isinstance(value, str):
            return "'{}'".format(value)
        return str(value)

    def parse_args(self, call_name: str, args_str: str):
        """
        Parse call_name and args_str.

        Args:
            call_name (str): str of the call function.
            args_str (str): str of args for the function, which starts with '(' and ends with ')'.

        Returns:
            OrderedDict, all args parsed.

        Raises:
            ValueError: If ast fails to parse args_str, if the parsed node is not of type ast.Call,
                or if the given args_str is not valid.
        """
        # args_str is required to match the "(...)" format
        if not (len(args_str) >= 2 and args_str[0] == "(" and args_str[-1] == ")"):
            raise ValueError('[{}] is treated as an args str, it should start with "(" and end with ")"'.format(args_str))

        try:
            ast_node = ast.parse("whatever_call_name" + args_str)
            call_node = ast_node.body[0].value
            if not isinstance(call_node, ast.Call):
                raise ValueError('call name with args str [{}] is not an instance of ast.Call'.format(args_str))
        except:
            raise ValueError("can't parse code:\n{}".format(args_str))

        # regard all actual parameters as one parameter
        if len(self.params) == 1:
            k = list(self.params.keys())[0]
            if k.startswith('*'):
                value = args_str[1:-1]
                return OrderedDict([(k, value), ("call_name", call_name)])

        args = OrderedDict()

        # positional parameters (name not assigned)
        param_iter = iter(self.params.keys())
        if len(call_node.args) > len(self.params):
            raise ValueError('Parse args of torch in {}, but there is a problem with the params'.format(call_name))
        for arg in call_node.args:
            if isinstance(arg, ast.Starred):
                logger.debug("Find *%s", arg.value.id)
                args['*'] = arg.value.id
            else:
                # remove \n
                args[next(param_iter)] = pasta.dump(arg).strip()

        # keyword parameters (name assigned)
        for keyword in call_node.keywords:
            if keyword.arg is None:
                logger.info("Find **%s", keyword.value.id)
                args['**'] = keyword.value.id
            else:
                # remove \n
                args[keyword.arg] = pasta.dump(keyword.value).strip()

        args["call_name"] = call_name
        return args
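
# A minimal, comment-only sketch of how parse_args behaves; the call below is
# illustrative and is not executed as part of this module:
#
#   pt_conv = APIPt('nn.Conv2d', OrderedDict(in_channels=REQUIRED,
#                                            out_channels=REQUIRED,
#                                            kernel_size=REQUIRED,
#                                            stride=1))
#   pt_conv.parse_args('nn.Conv2d', '(3, 64, kernel_size=3)')
#   # -> OrderedDict([('in_channels', '3'), ('out_channels', '64'),
#   #                 ('kernel_size', '3'), ('call_name', 'nn.Conv2d')])
# Positional args are matched to param names in declaration order; keyword args
# are matched by name; every value is kept as a source-code string.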


class APIMs(APIPt):
    """API for MindSpore"""

    def __init__(self, name: str, params: OrderedDict, p_attrs=None):
        self.is_primitive = name.startswith('P.')
        if self.is_primitive:
            self.p_attrs = p_attrs if p_attrs else set()
        super(APIMs, self).__init__(name, params)

    def create_args(self, params_pt, args_pt, ms2pt_map, explicit_map):
        """
        Create args for MindSpore according to the other frame's op info.

        Args:
            params_pt (OrderedDict): Params used for the initialize function of APIPt.
            args_pt (OrderedDict): Args parsed from APIPt.
            ms2pt_map (dict): Dict of the params mapping relation for ops between frames.
            explicit_map (func): Function to generate the mapping relation for ops between frames.

        Returns:
            OrderedDict, args for MindSpore.
        """
        args = OrderedDict()

        # traverse MindSpore's params
        for k in self.params.keys():
            # the param has a counterpart in the other frame
            if k in ms2pt_map:
                if ms2pt_map[k] in args_pt:
                    # user assigned a value
                    args[k] = args_pt[ms2pt_map[k]]
                elif self.params[k] != params_pt[ms2pt_map[k]]:
                    # user didn't assign a value, but the default values differ between the two frames
                    args[k] = params_pt[ms2pt_map[k]]
            # the param has no counterpart in the other frame
            else:
                # params forced to display
                if k in explicit_map:
                    args[k] = explicit_map[k]
                elif self.params[k] is REQUIRED:
                    args[k] = "<REQUIRED>"

        # find * or ** in the frame's actual parameters
        for star in ('*', '**'):
            if star in args_pt:
                args[star] = args_pt[star]

        return args


class MappingHelper:
    """Mapping from one frame to another frame"""

    def __init__(self, ms_api: APIMs, pt_api: APIPt, **kwargs):
        ms2pt_mapping = kwargs.get('ms2pt_mapping')
        gen_explicit_map = kwargs.get('gen_explicit_map')
        export_key = kwargs.get('export_key')

        if ms2pt_mapping is None:
            ms2pt_mapping = {}
        if gen_explicit_map is None:
            gen_explicit_map = lambda params_pt, args_pt: {}
        self.ms_api = ms_api
        self.pt_api = pt_api
        self.ms2pt_mapping = ms2pt_mapping
        self.gen_explicit_map = gen_explicit_map
        if export_key is not None:
            self.export_key = export_key
        else:
            self.export_key = not ms_api.is_primitive

    def gen_args_expr(self, args):
        """
        Generate a str assignment statement from the given dict.

        Args:
            args (OrderedDict): Key, value pairs for the assignment source.

        Returns:
            str, generated str.
        """
        expr = ''
        for key, value in args.items():
            if expr:
                expr += ', '
            sym = '' if key in ('*', '**') else '='
            if self.export_key:
                expr += key + sym
            expr += value
        return expr

    def gen_args_expr_for_p(self, args, p_attrs):
        """
        Generate str assignment statements from the given dict, split into primitive attrs and the rest.

        Args:
            args (OrderedDict): Key, value pairs for the assignment source.
            p_attrs (set): Exclusive params for the operator.

        Returns:
            tuple, generated str for the primitive attrs, generated str for the remaining params.
        """
        args_attrs = OrderedDict([(k, v) for k, v in args.items() if k in p_attrs])
        args_ios = OrderedDict([(k, v) for k, v in args.items() if k not in p_attrs])
        return self.gen_args_expr(args_attrs), self.gen_args_expr(args_ios)

    def convert(self, call_name_pt: str, args_str_pt: str):
        """
        Convert a code sentence to a MindSpore code sentence.

        Args:
            call_name_pt (str): str of the call function.
            args_str_pt (str): str of args for the function, which starts with '(' and ends with ')'.

        Returns:
            str, converted code sentence for MindSpore.
        """
        # all values in args_pt are str
        args_pt = self.pt_api.parse_args(call_name_pt, args_str_pt)

        # all values in args_ms are str
        explicit_map = self.gen_explicit_map(self.pt_api.params, args_pt)
        args_ms = self.ms_api.create_args(self.pt_api.params, args_pt, self.ms2pt_mapping, explicit_map)

        if self.ms_api.is_primitive:
            if self.pt_api.name == '.size' and 'idx' in args_pt:
                args_expr = self.gen_args_expr(args_ms)
                expr_ms = "%s()(%s)[%s]" % (self.ms_api.name, args_expr, args_pt['idx'])
            else:
                expr_attrs, expr_ios = self.gen_args_expr_for_p(args_ms, self.ms_api.p_attrs)
                expr_ms = "%s(%s)(%s)" % (self.ms_api.name, expr_attrs, expr_ios)
        else:
            ms_expr = self.gen_args_expr(args_ms)
            expr_ms = "%s(%s)" % (self.ms_api.name, ms_expr)
        return expr_ms
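
# Comment-only sketch of what convert() produces with two of the mappings defined
# further below (illustrative calls, not executed here):
#
#   NN_MAPPING['nn.Linear'].convert('nn.Linear', '(512, 10)')
#   # -> "nn.Dense(in_channels=512, out_channels=10)"
#
#   F_MAPPING['F.relu'].convert('F.relu', '(x)')
#   # -> "P.ReLU()(x)"   (primitive ops are instantiated first, then called on their inputs)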


def gen_explicit_map_nn_sequential(_, args_pt):
    """
    Generate explicit_map for nn.Sequential.

    Args:
        args_pt (dict): Args for APIPt.

    Returns:
        dict, map between frames.
    """
    args = args_pt['*args']
    return {"*args": "[{}]".format(args)}


def gen_explicit_map_nn_maxpool2d(params_pt, args_pt):
    """
    Generate explicit_map for nn.MaxPool2d.

    Args:
        params_pt (dict): Params for APIPt.
        args_pt (dict): Args for APIPt.

    Returns:
        dict, map between frames.
    """
    if 'padding' in args_pt:
        padding = args_pt['padding']
    else:
        padding = params_pt['padding']

    if padding.strip() in ("0", "(0,0)", "(0, 0)"):
        pad_mode = "'valid'"
    else:
        pad_mode = "'same'"
    return {"pad_mode": pad_mode}


def gen_explicit_map_f_max_pool2d(params_pt, args_pt):
    """
    Generate explicit_map for F.max_pool2d.

    Args:
        params_pt (dict): Params for APIPt.
        args_pt (dict): Args for APIPt.

    Returns:
        dict, map between frames.
    """
    if 'padding' in args_pt:
        padding = args_pt['padding']
    else:
        padding = params_pt['padding']

    if padding.strip() in ("0", "(0,0)", "(0, 0)"):
        padding = "'valid'"
    else:
        padding = "'same'"
    return {"padding": padding}


def gen_explicit_map_one_delta(params_pt, args_pt, k_ms, k_pt):
    """
    Generate explicit_map for params whose mapping relationship is `k_ms = 1 - k_pt`.

    Args:
        params_pt (dict): Params for APIPt.
        args_pt (dict): Args for APIPt.
        k_ms (str): Name of the MindSpore param.
        k_pt (str): Name of the other frame's param.

    Returns:
        dict, map between frames.
    """
    value = args_pt[k_pt] if k_pt in args_pt else params_pt[k_pt]
    value = value.strip()

    def is_number(string):
        try:
            float(string)
            return True
        except ValueError:
            return False

    if is_number(value):
        return {k_ms: str(1 - float(value))}
    return {k_ms: "1.0 - " + value}
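
# For example, this is used below (via functools.partial) for nn.BatchNorm2d and nn.Dropout:
# PyTorch's default BatchNorm2d momentum of 0.1 maps to {"momentum": "0.9"}, a user-supplied
# momentum=0.2 maps to {"momentum": "0.8"}, and a non-numeric value such as a variable name
# is emitted as the string "1.0 - <value>".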


def load_json_file(file_path):
    """
    Load data from the given json file path.

    Args:
        file_path (str): The file to load json data from.

    Returns:
        list, the list data stored in file_path.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        info = json.loads(file.read())
    return info


# ---------------------------- mappings ----------------------------
NN_MAPPING = {
    'nn.Sequential': MappingHelper(**{
        "ms_api": APIMs('nn.SequentialCell', OrderedDict([('*args', REQUIRED)])),
        "pt_api": APIPt('nn.Sequential', OrderedDict([('*args', REQUIRED)])),
        "gen_explicit_map": gen_explicit_map_nn_sequential,
        "export_key": False
    }),
    'nn.Conv2d': MappingHelper(**{
        "ms_api": APIMs('nn.Conv2d', OrderedDict(in_channels=REQUIRED,
                                                 out_channels=REQUIRED,
                                                 kernel_size=REQUIRED,
                                                 stride=1,
                                                 pad_mode='same',
                                                 padding=0,
                                                 dilation=1,
                                                 group=1,
                                                 has_bias=False,
                                                 weight_init='normal',
                                                 bias_init='zeros')),
        "pt_api": APIPt('nn.Conv2d', OrderedDict(in_channels=REQUIRED,
                                                 out_channels=REQUIRED,
                                                 kernel_size=REQUIRED,
                                                 stride=1,
                                                 padding=0,
                                                 dilation=1,
                                                 groups=1,
                                                 bias=True,
                                                 padding_mode='zeros')),
        "ms2pt_mapping": {'in_channels': 'in_channels',
                          'out_channels': 'out_channels',
                          'kernel_size': 'kernel_size',
                          'stride': 'stride',
                          'padding': 'padding',
                          'dilation': 'dilation',
                          'group': 'groups',
                          'has_bias': 'bias'},
        "gen_explicit_map": (lambda params_pt, args_pt: {"pad_mode": "'pad'"})
    }),
    'nn.BatchNorm2d': MappingHelper(**{
        "ms_api": APIMs('nn.BatchNorm2d', OrderedDict(num_features=REQUIRED,
                                                      eps=1e-5,
                                                      momentum=0.9,
                                                      affine=True,
                                                      gamma_init='ones',
                                                      beta_init='zeros',
                                                      moving_mean_init='zeros',
                                                      moving_var_init='ones',
                                                      use_batch_statistics=True)),
        "pt_api": APIPt('nn.BatchNorm2d', OrderedDict(num_features=REQUIRED,
                                                      eps=1e-5,
                                                      momentum=0.1,
                                                      affine=True,
                                                      track_running_stats=True)),
        "ms2pt_mapping": {"num_features": "num_features",
                          "eps": "eps",
                          "affine": "affine",
                          "use_batch_statistics": "track_running_stats"},
        "gen_explicit_map": partial(gen_explicit_map_one_delta, k_ms="momentum", k_pt="momentum")
    }),
    'nn.ReLU': MappingHelper(**{
        "ms_api": APIMs('nn.ReLU', OrderedDict()),
        "pt_api": APIPt('nn.ReLU', OrderedDict(inplace=False)),
        "ms2pt_mapping": {}
    }),
    'nn.ReLU6': MappingHelper(**{
        "ms_api": APIMs('nn.ReLU6', OrderedDict()),
        "pt_api": APIPt('nn.ReLU6', OrderedDict(inplace=False)),
        "ms2pt_mapping": {}
    }),
    'nn.Linear': MappingHelper(**{
        "ms_api": APIMs('nn.Dense', OrderedDict(in_channels=REQUIRED,
                                                out_channels=REQUIRED,
                                                weight_init='normal',
                                                bias_init='zeros',
                                                has_bias=True,
                                                activation=None)),
        "pt_api": APIPt('nn.Linear', OrderedDict(in_features=REQUIRED,
                                                 out_features=REQUIRED,
                                                 bias=True)),
        "ms2pt_mapping": {"in_channels": "in_features",
                          "out_channels": "out_features",
                          "has_bias": "bias"}
    }),
    'nn.MaxPool2d': MappingHelper(**{
        "ms_api": APIMs('nn.MaxPool2d', OrderedDict(kernel_size=1,
                                                    stride=1,
                                                    pad_mode="valid")),
        "pt_api": APIPt('nn.MaxPool2d', OrderedDict(kernel_size=REQUIRED,
                                                    stride=None,
                                                    padding=0,
                                                    dilation=1,
                                                    return_indices=False,
                                                    ceil_mode="False")),
        "ms2pt_mapping": {"kernel_size": "kernel_size",
                          "stride": "stride"},
        "gen_explicit_map": gen_explicit_map_nn_maxpool2d
    }),
    'nn.AvgPool2d': MappingHelper(**{
        "ms_api": APIMs('nn.AvgPool2d', OrderedDict(kernel_size=1,
                                                    stride=1,
                                                    pad_mode="valid")),
        "pt_api": APIPt('nn.AvgPool2d', OrderedDict(kernel_size=REQUIRED,
                                                    stride=None,
                                                    padding=0,
                                                    dilation=1,
                                                    return_indices=False,
                                                    ceil_mode="False")),
        "ms2pt_mapping": {"kernel_size": "kernel_size",
                          "stride": "stride"},
        "gen_explicit_map": gen_explicit_map_nn_maxpool2d
    }),
    'nn.Dropout': MappingHelper(**{
        "ms_api": APIMs('nn.Dropout', OrderedDict(keep_prob=0.5,
                                                  seed0=0,
                                                  seed1=0,
                                                  dtype="mstype.float32")),
        "pt_api": APIPt('nn.Dropout', OrderedDict(p=0.5,
                                                  inplace=False)),
        "ms2pt_mapping": {"keep_prob": "p"},
        "gen_explicit_map": partial(gen_explicit_map_one_delta, k_ms="keep_prob", k_pt="p")
    })
}

# set alias: nn. = torch.nn.
NN_MAPPING.update({"torch." + k: v for k, v in NN_MAPPING.items()})


F_MAPPING = {
    'F.relu': MappingHelper(**{
        "ms_api": APIMs('P.ReLU', OrderedDict(input=REQUIRED)),
        "pt_api": APIPt('F.relu', OrderedDict(input=REQUIRED, inplace=False)),
        "ms2pt_mapping": {"input": "input"},
    }),
    'F.relu6': MappingHelper(**{
        "ms_api": APIMs('P.ReLU6', OrderedDict(input=REQUIRED)),
        "pt_api": APIPt('F.relu6', OrderedDict(input=REQUIRED, inplace=False)),
        "ms2pt_mapping": {"input": "input"},
    }),
    'F.max_pool2d': MappingHelper(**{
        "ms_api": APIMs('P.MaxPool', OrderedDict(ksize=1,
                                                 strides=1,
                                                 padding="valid",
                                                 input=REQUIRED),
                        p_attrs={"ksize", "strides", "padding"}),
        "pt_api": APIPt('F.max_pool2d', OrderedDict(input=REQUIRED,
                                                    kernel_size=REQUIRED,
                                                    stride=None,
                                                    padding=0,
                                                    dilation=1,
                                                    ceil_mode=False,
                                                    return_indices=False)),
        "ms2pt_mapping": {"ksize": "kernel_size",
                          "strides": "stride",
                          "input": "input"},
        "gen_explicit_map": gen_explicit_map_f_max_pool2d
    }),
    'F.avg_pool2d': MappingHelper(**{
        "ms_api": APIMs('P.AvgPool', OrderedDict(ksize=1,
                                                 strides=1,
                                                 padding="valid",
                                                 input=REQUIRED),
                        p_attrs={"ksize", "strides", "padding"}),
        "pt_api": APIPt('F.avg_pool2d', OrderedDict(input=REQUIRED,
                                                    kernel_size=REQUIRED,
                                                    stride=None,
                                                    padding=0,
                                                    dilation=1,
                                                    ceil_mode=False,
                                                    return_indices=False)),
        "ms2pt_mapping": {"ksize": "kernel_size",
                          "strides": "stride",
                          "input": "input"},
        "gen_explicit_map": gen_explicit_map_f_max_pool2d
    }),
}

# set alias: F = nn.functional = torch.nn.functional
nn_functional_d = {"nn.functional." + k[2:]: v for k, v in F_MAPPING.items()}
torch_nn_functional_d = {"torch.nn.functional." + k[2:]: v for k, v in F_MAPPING.items()}
F_MAPPING.update(nn_functional_d)
F_MAPPING.update(torch_nn_functional_d)

TORCH_DOT_MAPPING = {
    'torch.flatten': MappingHelper(**{
        "ms_api": APIMs('P.Flatten', OrderedDict(input=REQUIRED)),
        "pt_api": APIPt('torch.flatten', OrderedDict(input=REQUIRED,
                                                     start_dim=0,
                                                     end_dim=-1)),
        "ms2pt_mapping": {"input": "input"}
    }),
    'torch.cat': MappingHelper(**{
        "ms_api": APIMs('P.Concat', OrderedDict(axis=0, input=REQUIRED), p_attrs={"axis"}),
        "pt_api": APIPt('torch.cat', OrderedDict(tensors=REQUIRED, dim=0, out=None)),
        "ms2pt_mapping": {"input": "tensors",
                          "axis": "dim"}
    }),
}

TENSOR_DOT_MAPPING = {
    '.view': MappingHelper(**{
        "ms_api": APIMs('P.Reshape', OrderedDict(x=REQUIRED, shape=REQUIRED)),
        "pt_api": APIPt('.view', OrderedDict([('*shape', REQUIRED)])),
        "ms2pt_mapping": {"x": "call_name"},
        "gen_explicit_map": (lambda params_pt, args_pt: {"shape": "(" + args_pt["*shape"] + ",)"})
    }),
    '.size': MappingHelper(**{
        "ms_api": APIMs('P.Shape', OrderedDict(x=REQUIRED)),
        "pt_api": APIPt('.size', OrderedDict([('idx', REQUIRED)])),
        "ms2pt_mapping": {"x": "call_name"}
    }),
    '.flatten': MappingHelper(**{
        "ms_api": APIMs('P.Flatten', OrderedDict(input=REQUIRED)),
        "pt_api": APIPt('.flatten', OrderedDict(start_dim=0,
                                                end_dim=-1)),
        "ms2pt_mapping": {"input": "call_name"}
    }),
    '.reshape': MappingHelper(**{
        "ms_api": APIMs('P.Reshape', OrderedDict(x=REQUIRED, shape=REQUIRED)),
        "pt_api": APIPt('.reshape', OrderedDict([('*shape', REQUIRED)])),
        "ms2pt_mapping": {"x": "call_name"},
        "gen_explicit_map": (lambda params_pt, args_pt: {"shape": "(" + args_pt["*shape"] + ",)"})
    }),
    '.mean': MappingHelper(**{
        "ms_api": APIMs('P.ReduceMean', OrderedDict(keep_dims=False,
                                                    input=REQUIRED,
                                                    axis=()),
                        p_attrs={"keep_dims"}),
        "pt_api": APIPt('.mean', OrderedDict(dim=None,
                                             keepdim=False)),
        "ms2pt_mapping": {"keep_dims": "keepdim",
                          "axis": "dim",
                          "input": "call_name"},
    }),
    '.squeeze': MappingHelper(**{
        "ms_api": APIMs('P.ReduceMean', OrderedDict(input=REQUIRED,
                                                    axis=()),
                        p_attrs={"axis"}),
        "pt_api": APIPt('.squeeze', OrderedDict(dim=None)),
        "ms2pt_mapping": {"axis": "dim",
                          "input": "call_name"},
    }),
}

ALL_MAPPING = {**NN_MAPPING, **F_MAPPING, **TORCH_DOT_MAPPING, **TENSOR_DOT_MAPPING}

# ---------------------------- api list support or not support ----------------------------
NN_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'nn_list.json'))
NN_LIST = load_json_file(NN_LIST_PATH)
# set alias: nn. = torch.nn.
NN_LIST += ["torch." + name for name in NN_LIST]
NN_SUPPORTED = [x for x in NN_LIST if x in ALL_MAPPING]
NN_UNSUPPORTED = [x for x in NN_LIST if x not in ALL_MAPPING]

F_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'f_list.json'))
F_LIST = load_json_file(F_LIST_PATH)
# set alias: F = nn.functional = torch.nn.functional
F_LIST += ["F." + name[len("torch.nn.functional."):] for name in F_LIST] + \
          [name[len("torch."):] for name in F_LIST]
F_SUPPORTED = [x for x in F_LIST if x in ALL_MAPPING]
F_UNSUPPORTED = [x for x in F_LIST if x not in ALL_MAPPING]

TORCH_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'torch_dot_list.json'))
TORCH_DOT_LIST = load_json_file(TORCH_DOT_LIST_PATH)
TORCH_DOT_SUPPORTED = [x for x in TORCH_DOT_LIST if x in ALL_MAPPING]
TORCH_DOT_UNSUPPORTED = [x for x in TORCH_DOT_LIST if x not in ALL_MAPPING]

TENSOR_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'tensor_dot_list.json'))
TENSOR_DOT_LIST = load_json_file(TENSOR_DOT_LIST_PATH)
TENSOR_DOT_SUPPORTED = [x for x in TENSOR_DOT_LIST if x in ALL_MAPPING]
TENSOR_DOT_UNSUPPORTED = [x for x in TENSOR_DOT_LIST if x not in ALL_MAPPING]

ALL_2P_LIST = F_LIST + TORCH_DOT_LIST + TENSOR_DOT_LIST
ALL_TORCH_APIS = NN_LIST + F_LIST + TORCH_DOT_LIST + TENSOR_DOT_LIST
ALL_SUPPORTED = NN_SUPPORTED + F_SUPPORTED + TORCH_DOT_SUPPORTED + TENSOR_DOT_SUPPORTED
ALL_UNSUPPORTED = NN_UNSUPPORTED + F_UNSUPPORTED + TORCH_DOT_UNSUPPORTED + TENSOR_DOT_UNSUPPORTED

UNSUPPORTED_WARN_INFOS = {
    "nn.AdaptiveAvgPool2d": "maybe could convert to P.ReduceMean",
    "F.adaptive_avg_pool2d": "maybe could convert to P.ReduceMean",
    "F.dropout": "please use nn.Dropout in __init__()",
}
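
# End-to-end usage sketch (comment only; illustrative, not executed at import time):
# look up the converted call for a PyTorch API found in scanned source code.
#
#   helper = ALL_MAPPING.get('nn.MaxPool2d')
#   if helper is not None:
#       helper.convert('nn.MaxPool2d', '(kernel_size=2, stride=2)')
#       # -> "nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='valid')"
#
# APIs listed in ALL_TORCH_APIS but absent from ALL_MAPPING fall into the *_UNSUPPORTED
# lists; UNSUPPORTED_WARN_INFOS carries conversion hints for some of them.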

MindInsight provides MindSpore with easy-to-use tuning and debugging capabilities. During training, data such as scalars, tensors, images, computational graphs, model hyperparameters, and training time can be recorded to a file and then viewed and analyzed through the MindInsight visualization page.