
config.py 23 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """API config"""
  16. import ast
  17. from collections import OrderedDict
  18. from importlib import import_module
  19. import json
  20. import os
  21. import pasta
  22. from mindinsight.mindconverter.common.log import logger
  23. from mindinsight.mindconverter.common.exceptions import CodeSyntaxError
  24. REQUIRED = 'REQUIRED'
  25. UNREQUIRED = 'UNREQUIRED'
  26. FUNC_MODULE = 'mindinsight.mindconverter.funcs'


class APIPt:
    """Base API for args parse, and API for one frame."""

    def __init__(self, name: str, params: dict):
        self.name = name
        self.params = OrderedDict()

        for k, value in params.items():
            self.params[k] = self.to_str(value)

    @staticmethod
    def to_str(value):
        """
        Trans value to str.

        Args:
            value (Union[str, Number, int]): The value to convert.

        Returns:
            str, str type of value.
        """
        if value is REQUIRED:
            return value
        if isinstance(value, str):
            return "'{}'".format(value)
        return str(value)
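    # Illustrative examples of to_str (not executed here); values are rendered as Python source fragments:
    #   APIPt.to_str('same')   -> "'same'"
    #   APIPt.to_str(3)        -> '3'
    #   APIPt.to_str(REQUIRED) -> 'REQUIRED'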

    def parse_args(self, call_name: str, args_str: str):
        """
        Parse call_name and args_str.

        Args:
            call_name (str): str of the call function, etc.
            args_str (str): str of args for function, which starts with '(' and ends with ')'.

        Returns:
            OrderedDict, all args parsed.

        Raises:
            ValueError: If ast can not parse the code, if the parsed node is not of type ast.Call,
                or if the given args_str is not valid.
        """
        # args_str is required to meet the "(...)" format
        if not (len(args_str) >= 2 and args_str[0] == "(" and args_str.strip()[-1] == ")"):
            raise ValueError('"{}" is treated as an args string, it should start with "(" and end with ")" without '
                             'considering spaces'.format(args_str))
        try:
            ast_node = ast.parse("whatever_call_name" + args_str)
            call_node = ast_node.body[0].value
        except SyntaxError as parse_error:
            raise CodeSyntaxError("can't parse code:\n{}".format(args_str)) from parse_error

        # regard all actual parameters as one parameter
        if len(self.params) == 1:
            k = list(self.params.keys())[0]
            if k.startswith('*'):
                value = args_str[1:-1]
                return OrderedDict([(k, value), ("call_name", call_name)])

        args = OrderedDict()

        # params whose names are not assigned
        param_iter = iter(self.params.keys())
        if len(call_node.args) > len(self.params):
            raise ValueError('Parse args of torch in {}, but there is a problem with the params'.format(call_name))
        for arg in call_node.args:
            if isinstance(arg, ast.Starred):
                logger.debug("Find *%s", arg.value.id)
                args['*'] = arg.value.id
            else:
                # remove \n
                args[next(param_iter)] = pasta.dump(arg).strip()

        # params whose names are assigned
        for keyword in call_node.keywords:
            if keyword.arg is None:
                logger.info("Find **%s", keyword.value.id)
                args['**'] = keyword.value.id
            else:
                # remove \n
                args[keyword.arg] = pasta.dump(keyword.value).strip()

        args["call_name"] = call_name
        return args
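    # Illustrative example of parse_args, assuming a hypothetical APIPt built with
    # params = OrderedDict([('in_channels', REQUIRED), ('out_channels', REQUIRED), ('kernel_size', 3)]):
    #   api.parse_args('nn.Conv2d', '(3, 64, kernel_size=5)')
    #   -> OrderedDict([('in_channels', '3'), ('out_channels', '64'),
    #                   ('kernel_size', '5'), ('call_name', 'nn.Conv2d')])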


class APIMs(APIPt):
    """API for MindSpore"""

    def __init__(self, name: str, params: dict, p_attrs=None):
        self.is_primitive = name.startswith('P.')
        if self.is_primitive:
            self.p_attrs = p_attrs if p_attrs else set()

        super(APIMs, self).__init__(name, params)

    def create_args(self, params_pt, args_pt, ms2pt_map, explicit_map):
        """
        Create args for MindSpore according to other frame op info.

        Args:
            params_pt (OrderedDict): Params used for initialize function of APIPt.
            args_pt (OrderedDict): Args parsed from APIPt.
            ms2pt_map (dict): Dict of params mapping relation for ops between frames.
            explicit_map (dict): Dict of params to be assigned explicitly for MindSpore.

        Returns:
            OrderedDict, args for MindSpore.
        """
        args = OrderedDict()

        # traverse MindSpore's params
        for k in self.params.keys():
            # the param has a corresponding param in the other frame
            if k in ms2pt_map:
                if ms2pt_map[k] in args_pt:
                    # user assigned a value
                    args[k] = args_pt[ms2pt_map[k]]
                elif self.params[k] != params_pt[ms2pt_map[k]]:
                    # user didn't assign a value, but the default values differ between the two frames
                    args[k] = params_pt[ms2pt_map[k]]
            # the param has no corresponding param in the other frame
            else:
                # params forced to display
                if k in explicit_map:
                    args[k] = explicit_map[k]
                elif self.params[k] is REQUIRED:
                    args[k] = "<REQUIRED>"

        # find * or ** in frame actual parameters
        for star in ('*', '**'):
            if star in args_pt:
                args[star] = args_pt[star]

        return args
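    # Illustrative example of create_args, with hypothetical inputs:
    #   self.params  = OrderedDict([('out_channels', REQUIRED), ('pad_mode', "'same'")])
    #   params_pt    = OrderedDict([('out_channels', REQUIRED), ('padding', '0')])
    #   args_pt      = OrderedDict([('out_channels', '64'), ('call_name', 'nn.Conv2d')])
    #   ms2pt_map    = {'out_channels': 'out_channels'}
    #   explicit_map = {'pad_mode': "'valid'"}
    # would yield OrderedDict([('out_channels', '64'), ('pad_mode', "'valid'")]).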


class MappingHelper:
    """Mapping from one frame to another frame"""

    def __init__(self, ms_api: APIMs, pt_api: APIPt, **kwargs):
        ms2pt_mapping = kwargs.get('ms2pt_mapping')
        gen_explicit_map = kwargs.get('gen_explicit_map')
        export_key = kwargs.get('export_key')

        if ms2pt_mapping is None:
            ms2pt_mapping = {}
        if gen_explicit_map is None:
            gen_explicit_map = lambda params_pt, args_pt: {}
        self.ms_api = ms_api
        self.pt_api = pt_api
        self.ms2pt_mapping = ms2pt_mapping
        self.gen_explicit_map = gen_explicit_map
        if export_key is not None:
            self.export_key = export_key
        else:
            self.export_key = not ms_api.is_primitive

    def gen_args_expr(self, args):
        """
        Generate str assignment statement from given dict.

        Args:
            args (OrderedDict): Key, value pairs for assignment source.

        Returns:
            str, generated str.
        """
        expr = ''
        for key, value in args.items():
            if expr:
                expr += ', '
            sym = '' if key in ('*', '**') else '='
            if self.export_key:
                expr += key + sym
            expr += value
        return expr
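    # Illustrative example of gen_args_expr for args = OrderedDict([('in_channels', '3'), ('out_channels', '64')]):
    #   with self.export_key True  -> "in_channels=3, out_channels=64"
    #   with self.export_key False -> "3, 64"
    # Keys '*' and '**' are always emitted without '='.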

    def gen_args_expr_for_p(self, args, p_attrs):
        """
        Generate str assignment statements from given dict, split into primitive attrs and the remaining args.

        Args:
            args (OrderedDict): Key, value pairs for assignment source.
            p_attrs (set): Exclusive params for the operator.

        Returns:
            tuple, generated str for the primitive attrs, generated str for the remaining args.
        """
        args_attrs = OrderedDict([(k, v) for k, v in args.items() if k in p_attrs])
        args_ios = OrderedDict([(k, v) for k, v in args.items() if k not in p_attrs])
        return self.gen_args_expr(args_attrs), self.gen_args_expr(args_ios)

    def convert(self, call_name_pt: str, args_str_pt: str):
        """
        Convert code sentence to MindSpore code sentence.

        Args:
            call_name_pt (str): str of the call function, etc.
            args_str_pt (str): str of args for function, which starts with '(' and ends with ')'.

        Returns:
            str, converted code sentence for MindSpore.
        """
        # all values in args_pt are str
        args_pt = self.pt_api.parse_args(call_name_pt, args_str_pt)

        # all values in args_ms are str
        explicit_map = self.gen_explicit_map(self.pt_api.params, args_pt)
        args_ms = self.ms_api.create_args(self.pt_api.params, args_pt, self.ms2pt_mapping, explicit_map)

        if self.ms_api.is_primitive:
            if self.pt_api.name == '.size' and 'idx' in args_pt:
                args_expr = self.gen_args_expr(args_ms)
                expr_ms = "%s()(%s)[%s]" % (self.ms_api.name, args_expr, args_pt['idx'])
            else:
                expr_attrs, expr_ios = self.gen_args_expr_for_p(args_ms, self.ms_api.p_attrs)
                expr_ms = "%s(%s)(%s)" % (self.ms_api.name, expr_attrs, expr_ios)
        else:
            ms_expr = self.gen_args_expr(args_ms)
            expr_ms = "%s(%s)" % (self.ms_api.name, ms_expr)
        return expr_ms
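    # Rough shape of convert's output (actual op names depend on the mapping files):
    #   non-primitive MindSpore API: "nn.SomeCell(<args expr>)"
    #   primitive MindSpore API:     "P.SomeOp(<attrs expr>)(<inputs expr>)"
    #   special-cased '.size' call:  "P.SomeOp()(<args expr>)[<idx>]"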


def get_ms_api(ms_api_info):
    """
    Get APIMs instance from ms_api_info.

    Args:
        ms_api_info (list): Info for creating an APIMs instance. The first value in the list is the name for APIMs,
            the second (if provided) is params for APIMs, and the third (if provided) is p_attrs for APIMs.

    Returns:
        APIMs, instance of APIMs parsed from given info.
    """
    ms_name = ms_api_info[0]
    ms_params = ms_api_info[1] if len(ms_api_info) >= 2 else None
    ms_p_attrs = set(ms_api_info[2]) if len(ms_api_info) >= 3 else None
    ms_api = APIMs(name=ms_name, params=ms_params, p_attrs=ms_p_attrs)
    return ms_api


def get_pt_api(pt_api_info):
    """
    Get APIPt instance from pt_api_info.

    Args:
        pt_api_info (list): Info for creating an APIPt instance. The first value in the list is the name for APIPt,
            the second (if provided) is params for APIPt.

    Returns:
        APIPt, instance of APIPt parsed from given info.
    """
    pt_name = pt_api_info[0]
    pt_params = pt_api_info[1] if len(pt_api_info) >= 2 else None
    pt_api = APIPt(name=pt_name, params=pt_params)
    return pt_api
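# Illustrative shape of the *_api_info lists read from the mapping files (hypothetical values):
#   get_ms_api(["nn.Conv2d", {"in_channels": "REQUIRED", "out_channels": "REQUIRED", "pad_mode": "same"}])
#   get_pt_api(["nn.Conv2d", {"in_channels": "REQUIRED", "out_channels": "REQUIRED", "padding": 0}])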


def get_mapping_from_file(path):
    """
    Parse mapping info from given file.

    Args:
        path (str): The file path.

    Returns:
        dict, key is op name, value is the relevant instance of MappingHelper.
    """
    mapping_info_d = load_json_file(path)
    parse_mapping_dict = {}
    for key, value in mapping_info_d.items():
        ms_api_info = value.pop('ms_api')
        ms_api = get_ms_api(ms_api_info)
        pt_api_info = value.pop('pt_api')
        pt_api = get_pt_api(pt_api_info)
        gen_explicit_map = value.get('gen_explicit_map')
        if gen_explicit_map:
            module_name = import_module(FUNC_MODULE)
            value['gen_explicit_map'] = getattr(module_name, gen_explicit_map)

        parse_mapping_dict.update({key: MappingHelper(**dict(ms_api=ms_api, pt_api=pt_api), **value)})
    return parse_mapping_dict
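# Illustrative shape of one entry in a mapping json file consumed by get_mapping_from_file
# (hypothetical values; 'gen_explicit_map', if present, names a function in mindinsight.mindconverter.funcs):
#   "nn.Conv2d": {
#       "ms_api": ["nn.Conv2d", {"in_channels": "REQUIRED", "out_channels": "REQUIRED"}],
#       "pt_api": ["nn.Conv2d", {"in_channels": "REQUIRED", "out_channels": "REQUIRED"}],
#       "ms2pt_mapping": {"in_channels": "in_channels", "out_channels": "out_channels"}
#   }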


def load_json_file(file_path):
    """
    Load data from given json file path.

    Args:
        file_path (str): The file to load json data from.

    Returns:
        list(str) or dict, the json data stored in file_path.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        info = json.loads(file.read())
    return info


def get_corresponding_ms_name(pt_name):
    """
    Get the corresponding MindSpore op name for a PyTorch op name according to the mappings in mindconverter.

    Args:
        pt_name: PyTorch op name; either the shortened form or the full name is accepted.

    Returns:
        str, full MindSpore op name, or None if the op is not supported in mindconverter.

    Raises:
        ValueError: If the shortened form of the MindSpore name starts with neither 'P.' nor 'nn.', which means
            it is wrong in the mappings file.
    """
    helper = ALL_MAPPING.get(pt_name)
    if helper is None:
        return None
    ms_name = helper.ms_api.name
    if ms_name.startswith('nn.'):
        full_ms_name = 'mindspore.' + ms_name
    elif ms_name.startswith('P.'):
        full_ms_name = 'mindspore.ops.operations.' + ms_name[len('P.'):]
    else:
        raise ValueError('Check your mapping info, the corresponding MindSpore op name may be wrong for torch op: '
                         '{}'.format(pt_name))
    return full_ms_name
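# Illustrative behaviour of get_corresponding_ms_name, assuming the mapping files pair the ops this way:
#   an op mapped to 'nn.Conv2d'    -> 'mindspore.nn.Conv2d'
#   an op mapped to 'P.ReduceMean' -> 'mindspore.ops.operations.ReduceMean'
#   an op with no mapping          -> None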


def get_prompt_info(pt_name):
    """
    Get prompt info for a PyTorch op name.

    Args:
        pt_name: PyTorch op name; either the shortened form or the full name is accepted.

    Returns:
        str, prompt info on the op, or None if there is no prompt info for the op.
    """
    prompt_dict = {**UNSUPPORTED_WARN_INFOS, **SUPPORTED_WARN_INFOS}
    return prompt_dict.get(pt_name)


# ---------------------------- mappings ----------------------------
NN_MAPPING_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'mappings/nn_mappings.json'))
NN_MAPPING = get_mapping_from_file(NN_MAPPING_PATH)
# update to add keys with the full api_name, which starts with 'torch.nn.'
NN_MAPPING.update({"torch." + k: v for k, v in NN_MAPPING.items()})

F_MAPPING_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'mappings/f_mappings.json'))
F_MAPPING = get_mapping_from_file(F_MAPPING_PATH)
# update to add keys starting with 'nn.functional.'
NN_FUNCTIONAL_D = {"nn.functional." + k[len('F.'):]: v for k, v in F_MAPPING.items()}
# update to add keys starting with 'torch.nn.functional.'
TORCH_NN_FUNCTIONAL_D = {"torch.nn.functional." + k[len('F.'):]: v for k, v in F_MAPPING.items()}
F_MAPPING.update(NN_FUNCTIONAL_D)
F_MAPPING.update(TORCH_NN_FUNCTIONAL_D)

TORCH_DOT_MAPPING_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'mappings/torch_dot_mappings.json'))
TORCH_DOT_MAPPING = get_mapping_from_file(TORCH_DOT_MAPPING_PATH)

TENSOR_DOT_MAPPING_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'mappings/tensor_dot_mappings.json'))
TENSOR_DOT_MAPPING = get_mapping_from_file(TENSOR_DOT_MAPPING_PATH)

ALL_MAPPING = {**NN_MAPPING, **F_MAPPING, **TORCH_DOT_MAPPING, **TENSOR_DOT_MAPPING}
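# ALL_MAPPING is keyed by the PyTorch-side op name. A supported op can be converted roughly like this
# (hypothetical key, present only if it appears in the mapping files):
#   helper = ALL_MAPPING.get('nn.Conv2d')
#   if helper:
#       new_call = helper.convert('nn.Conv2d', '(3, 64, kernel_size=3)')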

# ---------------------------- api list support or not support ----------------------------
NN_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'nn_list.json'))
NN_LIST = load_json_file(NN_LIST_PATH)
NN_LIST += ["torch." + name for name in NN_LIST]
NN_SUPPORTED = [x for x in NN_LIST if x in ALL_MAPPING]
NN_UNSUPPORTED = [x for x in NN_LIST if x not in ALL_MAPPING]

F_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'f_list.json'))
F_LIST = load_json_file(F_LIST_PATH)
F_LIST += ["F." + name[len("torch.nn.functional."):] for name in F_LIST] + \
          [name[len("torch."):] for name in F_LIST]
F_SUPPORTED = [x for x in F_LIST if x in ALL_MAPPING]
F_UNSUPPORTED = [x for x in F_LIST if x not in ALL_MAPPING]

TORCH_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'torch_dot_list.json'))
TORCH_DOT_LIST = load_json_file(TORCH_DOT_LIST_PATH)
TORCH_DOT_SUPPORTED = [x for x in TORCH_DOT_LIST if x in ALL_MAPPING]
TORCH_DOT_UNSUPPORTED = [x for x in TORCH_DOT_LIST if x not in ALL_MAPPING]

TENSOR_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'tensor_dot_list.json'))
TENSOR_DOT_LIST = load_json_file(TENSOR_DOT_LIST_PATH)
TENSOR_DOT_SUPPORTED = [x for x in TENSOR_DOT_LIST if x in ALL_MAPPING]
TENSOR_DOT_UNSUPPORTED = [x for x in TENSOR_DOT_LIST if x not in ALL_MAPPING]

ALL_2P_LIST = F_LIST + TORCH_DOT_LIST + TENSOR_DOT_LIST
ALL_TORCH_APIS = NN_LIST + F_LIST + TORCH_DOT_LIST + TENSOR_DOT_LIST
ALL_SUPPORTED = NN_SUPPORTED + F_SUPPORTED + TORCH_DOT_SUPPORTED + TENSOR_DOT_SUPPORTED
ALL_UNSUPPORTED = NN_UNSUPPORTED + F_UNSUPPORTED + TORCH_DOT_UNSUPPORTED + TENSOR_DOT_UNSUPPORTED

UNSUPPORTED_WARN_INFOS = {
    "nn.AdaptiveAvgPool2d": "Maybe could convert to mindspore.ops.operations.ReduceMean.",
    "nn.AvgPool1d": "Maybe could convert to mindspore.nn.AvgPool1d.",
    "nn.ConvTranspose2d": "Maybe could convert to mindspore.nn.Conv2dTranspose.",
    "nn.CrossEntropyLoss": "Maybe could convert to mindspore.nn.SoftmaxCrossEntropyWithLogits.",
    "nn.Embedding": "Maybe could convert to mindspore.nn.Embedding.",
    "nn.GroupNorm": "Maybe could convert to mindspore.nn.GroupNorm.",
    "nn.MSELoss": "Maybe could convert to mindspore.nn.MSELoss.",
    "nn.LSTM": "Maybe could convert to mindspore.nn.LSTM.",
    "nn.LSTMCell": "Maybe could convert to mindspore.nn.LSTMCell.",
    "nn.ModuleList": "Maybe could convert to mindspore.nn.CellList.",
    "nn.SmoothL1Loss": "Maybe could convert to mindspore.nn.SmoothL1Loss.",
    "nn.Tanh": "Maybe could convert to mindspore.nn.Tanh.",
    "nn.Upsample": "Maybe could convert to mindspore.ops.operations.ResizeBilinear.",
    "nn.L1Loss": "Maybe could convert to mindspore.nn.L1Loss.",
    "nn.Parameter": "Maybe could convert to mindspore.Parameter.",
    "nn.ParameterList": "Maybe could convert to mindspore.ParameterTuple.",
    "nn.Unfold": "Maybe could convert to mindspore.nn.Unfold.",
    "nn.PixelShuffle": "Maybe could convert to mindspore.ops.operations.DepthToSpace.",
    "F.adaptive_avg_pool2d": "Maybe could convert to mindspore.ops.operations.ReduceMean.",
    "F.conv2d": "Maybe could convert to mindspore.ops.operations.Conv2D.",
    "F.dropout": "please use mindspore.nn.Dropout in __init__().",
    "F.interpolate": "Maybe could convert to mindspore.ops.operations.ResizeBilinear.",
    "F.one_hot": "Maybe could convert to mindspore.ops.operations.OneHot.",
    "torch.bmm": "Maybe could convert to mindspore.ops.operations.BatchMatMul.",
    "torch.cumsum": "Maybe could convert to mindspore.ops.operations.CumSum.",
    "F.pad": "Maybe could convert to mindspore.ops.operations.Pad.",
    "F.softmax": "Maybe could convert to mindspore.ops.operations.Softmax.",
    "torch.clamp": "Maybe could convert to mindspore.ops.composite.clip_by_value.",
    "torch.eq": "Maybe could convert to mindspore.ops.operations.Equal.",
    "torch.load": "Maybe could convert to mindspore.train.serialization.load_checkpoint.",
    "torch.matmul": "Maybe could convert to mindspore.ops.operations.MatMul.",
    "torch.max": "try to use P.ArgMaxWithValue, notice that two values are returned by mindspore.ops.operations."
                 "ArgMaxWithValue.",
    "torch.mean": "Maybe could convert to mindspore.ops.operations.ReduceMean.",
    "torch.min": "try to use P.ArgMinWithValue, notice that two values are returned by mindspore.ops.operations."
                 "ArgMinWithValue.",
    "torch.mm": "Maybe could convert to mindspore.ops.operations.MatMul.",
    "torch.mul": "Maybe could convert to mindspore.ops.operations.Mul.",
    "torch.norm": "Maybe could convert to mindspore.nn.Norm.",
    "torch.numel": "Maybe could convert to mindspore.ops.operations.Size.",
    "torch.ones_like": "Maybe could convert to mindspore.ops.operations.OnesLike.",
    "torch.randn": "Maybe could convert to mindspore.ops.operations.TruncatedNormal.",
    "torch.round": "Maybe could convert to mindspore.ops.operations.Round.",
    "torch.save": "Maybe could convert to mindspore.train.serialization.save_checkpoint.",
    "torch.sigmoid": "Maybe could convert to mindspore.ops.operations.Sigmoid.",
    "torch.split": "Maybe could convert to mindspore.ops.operations.Split.",
    "torch.squeeze": "Maybe could convert to mindspore.ops.operations.Squeeze.",
    "torch.stack": "Maybe could convert to mindspore.ops.operations.Pack.",
    "torch.sum": "Maybe could convert to mindspore.ops.operations.ReduceSum.",
    "torch.tanh": "Maybe could convert to mindspore.ops.operations.Tanh.",
    "torch.tensor": "Maybe could convert to mindspore.Tensor.",
    "torch.transpose": "Maybe could convert to mindspore.ops.operations.Transpose.",
    "torch.unsqueeze": "Maybe could convert to mindspore.ops.operations.ExpandDims.",
    "torch.zeros_like": "Maybe could convert to mindspore.ops.operations.ZerosLike.",
    ".chunk": "Maybe could convert to mindspore.ops.operations.Split.",
    ".fill_": "Maybe could convert to mindspore.ops.operations.Fill.",
    ".float": "Maybe could convert to mindspore.ops.operations.Cast.",
    ".mm": "Maybe could convert to mindspore.ops.operations.MatMul.",
    ".mul": "Maybe could convert to mindspore.ops.operations.Mul.",
    ".pow": "Maybe could convert to mindspore.ops.operations.Pow.",
    ".round": "Maybe could convert to mindspore.ops.operations.Round.",
    ".scatter": "Maybe could convert to mindspore.ops.operations.ScatterNd.",
    ".sigmoid": "Maybe could convert to mindspore.nn.Sigmoid.",
    ".sign": "Maybe could convert to mindspore.ops.operations.Sign.",
    ".sqrt": "Maybe could convert to mindspore.ops.operations.Sqrt.",
    ".sub": "Maybe could convert to mindspore.ops.operations.Sub.",
    ".transpose": "Maybe could convert to mindspore.ops.operations.Transpose.",
    ".unsqueeze": "Maybe could convert to mindspore.ops.operations.ExpandDims.",
    ".zero_": "Maybe could convert to mindspore.ops.operations.ZerosLike.",
}

NN_UNSUPPORTED_INFOS = {k: v for k, v in UNSUPPORTED_WARN_INFOS.items() if k.startswith('nn.')}
TORCH_NN_UNSUPPORTED_INFOS = {('torch.' + k): v for k, v in NN_UNSUPPORTED_INFOS.items()}

F_UNSUPPORTED_INFOS = {k: v for k, v in UNSUPPORTED_WARN_INFOS.items() if k.startswith('F.')}
NN_FUNCTIONAL_UNSUPPORTED_INFOS = {'nn.functional.' + k[len('F.'):]: v for k, v in F_UNSUPPORTED_INFOS.items()}
TORCH_NN_FUNCTIONAL_UNSUPPORTED_INFOS = {'torch.nn.functional.' + k[len('F.'):]: v
                                         for k, v in F_UNSUPPORTED_INFOS.items()}

UNSUPPORTED_WARN_INFOS.update(TORCH_NN_UNSUPPORTED_INFOS)
UNSUPPORTED_WARN_INFOS.update(NN_FUNCTIONAL_UNSUPPORTED_INFOS)
UNSUPPORTED_WARN_INFOS.update(TORCH_NN_FUNCTIONAL_UNSUPPORTED_INFOS)

SUPPORTED_WARN_INFOS = {
    "torch.eye": "Pay attention to using the right mindspore data type.",
    "nn.Linear": "Pay attention to reshape the input to 2 dims if it is 3 dims before, because MindSpore.nn.Dense "
                 "only supports 2-dim input.",
    ".view": "Only float Tensor is supported in mindspore.ops.operations.Reshape.",
    ".reshape": "Only float Tensor is supported in mindspore.ops.operations.Reshape."
}

NN_SUPPORTED_INFOS = {k: v for k, v in SUPPORTED_WARN_INFOS.items() if k.startswith('nn.')}
TORCH_NN_SUPPORTED_INFOS = {('torch.' + k): v for k, v in NN_SUPPORTED_INFOS.items()}
SUPPORTED_WARN_INFOS.update(TORCH_NN_SUPPORTED_INFOS)
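
# Minimal usage sketch (an assumption about how a caller might consult these tables; not part of this module's API):
#   for a PyTorch call name `name` with argument string `args_str`:
#       if name in ALL_MAPPING:
#           converted = ALL_MAPPING[name].convert(name, args_str)   # rewrite the call for MindSpore
#       elif name in ALL_TORCH_APIS:
#           hint = get_prompt_info(name)                            # warn with a conversion hint, if any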