
config.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""API config"""

import ast
from collections import OrderedDict
from importlib import import_module
import json
import os

import pasta

from mindinsight.mindconverter.common.log import logger

REQUIRED = 'REQUIRED'
UNREQUIRED = 'UNREQUIRED'
FUNC_MODULE = 'mindinsight.mindconverter.funcs'


class APIPt:
    """Base API for args parse, and API for one frame."""

    def __init__(self, name: str, params: OrderedDict):
        self.name = name
        self.params = OrderedDict()

        for k, value in params.items():
            self.params[k] = self.to_str(value)

    @staticmethod
    def to_str(value):
        """
        Trans value to str.

        Args:
            value (Union[str, Number, int]): Each value for params of OrderedDict.

        Returns:
            str, str type of value.
        """
        if value is REQUIRED:
            return value
        if isinstance(value, str):
            return "'{}'".format(value)
        return str(value)

    def parse_args(self, call_name: str, args_str: str):
        """
        Parse call_name and args_str.

        Args:
            call_name (str): Str of the call function.
            args_str (str): Str of args for the function, which starts with '(' and ends with ')'.

        Returns:
            OrderedDict, all args parsed.

        Raises:
            ValueError: If the code can not be parsed by ast, the parsed node is not an
                ast.Call, or the given args_str is not valid.
        """
        # args_str is required to match the "(...)" format
        if not (len(args_str) >= 2 and args_str[0] == "(" and args_str.strip()[-1] == ")"):
            raise ValueError('"{}" is treated as an args string, it should start with "(" and end with ")" '
                             'without considering spaces'.format(args_str))
        try:
            ast_node = ast.parse("whatever_call_name" + args_str)
            call_node = ast_node.body[0].value
        except Exception as parse_error:
            raise ValueError("can't parse code:\n{}".format(args_str)) from parse_error
        if not isinstance(call_node, ast.Call):
            raise ValueError('call name with args str [{}] is not an instance of ast.Call'.format(args_str))

        # regard all actual parameters as one parameter
        if len(self.params) == 1:
            k = list(self.params.keys())[0]
            if k.startswith('*'):
                value = args_str[1:-1]
                return OrderedDict([(k, value), ("call_name", call_name)])

        args = OrderedDict()

        # params whose name is not assigned
        param_iter = iter(self.params.keys())
        if len(call_node.args) > len(self.params):
            raise ValueError('Parse args of torch in {}, but more positional args are given than '
                             'declared params'.format(call_name))
        for arg in call_node.args:
            if isinstance(arg, ast.Starred):
                logger.debug("Find *%s", arg.value.id)
                args['*'] = arg.value.id
            else:
                # remove \n
                args[next(param_iter)] = pasta.dump(arg).strip()

        # params whose name is assigned
        for keyword in call_node.keywords:
            if keyword.arg is None:
                logger.info("Find **%s", keyword.value.id)
                args['**'] = keyword.value.id
            else:
                # remove \n
                args[keyword.arg] = pasta.dump(keyword.value).strip()

        args["call_name"] = call_name
        return args
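
    # Example (a sketch with illustrative params, assuming an APIPt constructed with
    # OrderedDict([('in_features', REQUIRED), ('out_features', REQUIRED), ('bias', True)])):
    #     api.parse_args('nn.Linear', '(128, 10, bias=False)')
    # returns OrderedDict([('in_features', '128'), ('out_features', '10'),
    #                      ('bias', 'False'), ('call_name', 'nn.Linear')]).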


class APIMs(APIPt):
    """API for MindSpore"""

    def __init__(self, name: str, params: OrderedDict, p_attrs=None):
        self.is_primitive = name.startswith('P.')
        if self.is_primitive:
            self.p_attrs = p_attrs if p_attrs else set()
        super(APIMs, self).__init__(name, params)

    def create_args(self, params_pt, args_pt, ms2pt_map, explicit_map):
        """
        Create args for MindSpore according to other frame op info.

        Args:
            params_pt (OrderedDict): Params used for initialize function of APIPt.
            args_pt (OrderedDict): Args parsed from APIPt.
            ms2pt_map (dict): Dict of params mapping relation for ops between frames.
            explicit_map (dict): Explicit mapping relation for ops between frames.

        Returns:
            OrderedDict, args for MindSpore.
        """
        args = OrderedDict()

        # traverse MindSpore's params
        for k in self.params.keys():
            # has a relevant param in the other frame
            if k in ms2pt_map:
                if ms2pt_map[k] in args_pt:
                    # user assigned a value
                    args[k] = args_pt[ms2pt_map[k]]
                elif self.params[k] != params_pt[ms2pt_map[k]]:
                    # user didn't assign a value, but the default values differ between the two frames
                    args[k] = params_pt[ms2pt_map[k]]
            # has no relevant param in the other frame
            else:
                # params forced to display
                if k in explicit_map:
                    args[k] = explicit_map[k]
                elif self.params[k] is REQUIRED:
                    args[k] = "<REQUIRED>"

        # find * or ** in frame actual parameters
        for star in ('*', '**'):
            if star in args_pt:
                args[star] = args_pt[star]

        return args
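
    # Example (a sketch with illustrative param names): if this APIMs declares
    # OrderedDict([('in_channels', REQUIRED), ('out_channels', REQUIRED), ('has_bias', False)])
    # and ms2pt_map is {'in_channels': 'in_features', 'out_channels': 'out_features',
    # 'has_bias': 'bias'}, then values the user assigned in args_pt are copied over,
    # 'has_bias' is emitted only when the PyTorch default differs from the MindSpore default,
    # and any REQUIRED param still missing becomes the placeholder "<REQUIRED>".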


class MappingHelper:
    """Mapping from one frame to another frame"""

    def __init__(self, ms_api: APIMs, pt_api: APIPt, **kwargs):
        ms2pt_mapping = kwargs.get('ms2pt_mapping')
        gen_explicit_map = kwargs.get('gen_explicit_map')
        export_key = kwargs.get('export_key')

        if ms2pt_mapping is None:
            ms2pt_mapping = {}
        if gen_explicit_map is None:
            gen_explicit_map = lambda params_pt, args_pt: {}
        self.ms_api = ms_api
        self.pt_api = pt_api
        self.ms2pt_mapping = ms2pt_mapping
        self.gen_explicit_map = gen_explicit_map

        if export_key is not None:
            self.export_key = export_key
        else:
            self.export_key = not ms_api.is_primitive

    def gen_args_expr(self, args):
        """
        Generate str assignment statement from given dict.

        Args:
            args (OrderedDict): Key, value pairs for assignment source.

        Returns:
            str, generated str.
        """
        expr = ''
        for key, value in args.items():
            if expr:
                expr += ', '
            sym = '' if key in ('*', '**') else '='
            if self.export_key:
                expr += key + sym
            expr += value
        return expr
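
    # Example: with export_key=True, OrderedDict([('kernel_size', '3'), ('stride', '2')]) is
    # rendered as "kernel_size=3, stride=2"; with export_key=False (the primitive case) it is
    # rendered as "3, 2".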

    def gen_args_expr_for_p(self, args, p_attrs):
        """
        Generate str assignment statements from given dict, split into primitive attrs and the rest.

        Args:
            args (OrderedDict): Key, value pairs for assignment source.
            p_attrs (set): Exclusive params for operator.

        Returns:
            tuple, generated str for primitive attrs, generated str for the remaining args.
        """
        args_attrs = OrderedDict([(k, v) for k, v in args.items() if k in p_attrs])
        args_ios = OrderedDict([(k, v) for k, v in args.items() if k not in p_attrs])
        return self.gen_args_expr(args_attrs), self.gen_args_expr(args_ios)

    def convert(self, call_name_pt: str, args_str_pt: str):
        """
        Convert code sentence to MindSpore code sentence.

        Args:
            call_name_pt (str): Str of the call function.
            args_str_pt (str): Str of args for the function, which starts with '(' and ends with ')'.

        Returns:
            str, converted code sentence for MindSpore.
        """
        # all values for args_pt are str
        args_pt = self.pt_api.parse_args(call_name_pt, args_str_pt)

        # all values for args_ms are str
        explicit_map = self.gen_explicit_map(self.pt_api.params, args_pt)
        args_ms = self.ms_api.create_args(self.pt_api.params, args_pt, self.ms2pt_mapping, explicit_map)

        if self.ms_api.is_primitive:
            if self.pt_api.name == '.size' and 'idx' in args_pt:
                args_expr = self.gen_args_expr(args_ms)
                expr_ms = "%s()(%s)[%s]" % (self.ms_api.name, args_expr, args_pt['idx'])
            else:
                expr_attrs, expr_ios = self.gen_args_expr_for_p(args_ms, self.ms_api.p_attrs)
                expr_ms = "%s(%s)(%s)" % (self.ms_api.name, expr_attrs, expr_ios)
        else:
            ms_expr = self.gen_args_expr(args_ms)
            expr_ms = "%s(%s)" % (self.ms_api.name, ms_expr)
        return expr_ms
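
    # Example (a sketch with an assumed mapping): for a MappingHelper whose ms_api is the
    # non-primitive 'nn.Dense' mapped from 'nn.Linear', convert('nn.Linear', '(128, 10)')
    # would yield a string of the form "nn.Dense(in_channels=128, out_channels=10)".
    # For a primitive ms_api the attrs go into the constructor call and the inputs into the
    # second call, e.g. something like "P.ReLU()(x)".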


def get_ms_api(ms_api_info):
    """
    Get APIMs instance from ms_api_info.

    Args:
        ms_api_info (list): Info to create an APIMs instance, the first value in the list is the
            name for APIMs, the second (if provided) is the params for APIMs, the third (if
            provided) is the p_attrs for APIMs.

    Returns:
        APIMs, instance of APIMs parsed from given info.
    """
    ms_name = ms_api_info[0]
    ms_params = ms_api_info[1] if len(ms_api_info) >= 2 else None
    ms_p_attrs = set(ms_api_info[2]) if len(ms_api_info) >= 3 else None
    ms_api = APIMs(name=ms_name, params=ms_params, p_attrs=ms_p_attrs)
    return ms_api


def get_pt_api(pt_api_info):
    """
    Get APIPt instance from pt_api_info.

    Args:
        pt_api_info (list): Info to create an APIPt instance, the first value in the list is the
            name for APIPt, the second (if provided) is the params for APIPt.

    Returns:
        APIPt, instance of APIPt parsed from given info.
    """
    pt_name = pt_api_info[0]
    pt_params = pt_api_info[1] if len(pt_api_info) >= 2 else None
    pt_api = APIPt(name=pt_name, params=pt_params)
    return pt_api


def get_mapping_from_file(path):
    """
    Parse mapping info from given file.

    Args:
        path (str): The file path.

    Returns:
        dict, key is op name, value is a relevant instance of MappingHelper.
    """
    mapping_info_d = load_json_file(path)
    parse_mapping_dict = {}
    for key, value in mapping_info_d.items():
        ms_api_info = value.pop('ms_api')
        ms_api = get_ms_api(ms_api_info)
        pt_api_info = value.pop('pt_api')
        pt_api = get_pt_api(pt_api_info)
        gen_explicit_map = value.get('gen_explicit_map')
        if gen_explicit_map:
            module_name = import_module(FUNC_MODULE)
            value['gen_explicit_map'] = getattr(module_name, gen_explicit_map)

        parse_mapping_dict.update({key: MappingHelper(**dict(ms_api=ms_api, pt_api=pt_api), **value)})
    return parse_mapping_dict


def load_json_file(file_path):
    """
    Load data from given json file path.

    Args:
        file_path (str): The file to load json data from.

    Returns:
        Union[dict, list], the data stored in file_path.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        info = json.loads(file.read())
    return info


# ---------------------------- mappings ----------------------------
NN_MAPPING_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'mappings/nn_mappings.json'))
NN_MAPPING = get_mapping_from_file(NN_MAPPING_PATH)
# update to add keys with the full api name, which starts with 'torch.nn.'
NN_MAPPING.update({"torch." + k: v for k, v in NN_MAPPING.items()})

F_MAPPING_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'mappings/f_mappings.json'))
F_MAPPING = get_mapping_from_file(F_MAPPING_PATH)
# update to add keys starting with 'nn.functional.'
NN_FUNCTIONAL_D = {"nn.functional." + k[len('F.'):]: v for k, v in F_MAPPING.items()}
# update to add keys starting with 'torch.nn.functional.'
TORCH_NN_FUNCTIONAL_D = {"torch.nn.functional." + k[len('F.'):]: v for k, v in F_MAPPING.items()}
F_MAPPING.update(NN_FUNCTIONAL_D)
F_MAPPING.update(TORCH_NN_FUNCTIONAL_D)

TORCH_DOT_MAPPING_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'mappings/torch_dot_mappings.json'))
TORCH_DOT_MAPPING = get_mapping_from_file(TORCH_DOT_MAPPING_PATH)

TENSOR_DOT_MAPPING_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'mappings/tensor_dot_mappings.json'))
TENSOR_DOT_MAPPING = get_mapping_from_file(TENSOR_DOT_MAPPING_PATH)

ALL_MAPPING = {**NN_MAPPING, **F_MAPPING, **TORCH_DOT_MAPPING, **TENSOR_DOT_MAPPING}

# ---------------------------- api list support or not support ----------------------------
NN_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'nn_list.json'))
NN_LIST = load_json_file(NN_LIST_PATH)
NN_LIST += ["torch." + name for name in NN_LIST]
NN_SUPPORTED = [x for x in NN_LIST if x in ALL_MAPPING]
NN_UNSUPPORTED = [x for x in NN_LIST if x not in ALL_MAPPING]

F_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'f_list.json'))
F_LIST = load_json_file(F_LIST_PATH)
F_LIST += ["F." + name[len("torch.nn.functional."):] for name in F_LIST] + \
          [name[len("torch."):] for name in F_LIST]
F_SUPPORTED = [x for x in F_LIST if x in ALL_MAPPING]
F_UNSUPPORTED = [x for x in F_LIST if x not in ALL_MAPPING]

TORCH_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'torch_dot_list.json'))
TORCH_DOT_LIST = load_json_file(TORCH_DOT_LIST_PATH)
TORCH_DOT_SUPPORTED = [x for x in TORCH_DOT_LIST if x in ALL_MAPPING]
TORCH_DOT_UNSUPPORTED = [x for x in TORCH_DOT_LIST if x not in ALL_MAPPING]

TENSOR_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'tensor_dot_list.json'))
TENSOR_DOT_LIST = load_json_file(TENSOR_DOT_LIST_PATH)
TENSOR_DOT_SUPPORTED = [x for x in TENSOR_DOT_LIST if x in ALL_MAPPING]
TENSOR_DOT_UNSUPPORTED = [x for x in TENSOR_DOT_LIST if x not in ALL_MAPPING]

ALL_2P_LIST = F_LIST + TORCH_DOT_LIST + TENSOR_DOT_LIST
ALL_TORCH_APIS = NN_LIST + F_LIST + TORCH_DOT_LIST + TENSOR_DOT_LIST
ALL_SUPPORTED = NN_SUPPORTED + F_SUPPORTED + TORCH_DOT_SUPPORTED + TENSOR_DOT_SUPPORTED
ALL_UNSUPPORTED = NN_UNSUPPORTED + F_UNSUPPORTED + TORCH_DOT_UNSUPPORTED + TENSOR_DOT_UNSUPPORTED

UNSUPPORTED_WARN_INFOS = {
    "nn.AdaptiveAvgPool2d": "maybe could convert to P.ReduceMean",
    "nn.AvgPool1d": "maybe could convert to nn.AvgPool1d",
    "nn.ConvTranspose2d": "maybe could convert to nn.Conv2dTranspose",
    "nn.CrossEntropyLoss": "maybe could convert to nn.SoftmaxCrossEntropyWithLogits",
    "nn.Embedding": "maybe could convert to nn.Embedding",
    "nn.GroupNorm": "maybe could convert to nn.GroupNorm",
    "nn.MSELoss": "maybe could convert to nn.MSELoss",
    "nn.LSTM": "maybe could convert to nn.LSTM",
    "nn.LSTMCell": "maybe could convert to nn.LSTMCell",
    "nn.ModuleList": "maybe could convert to nn.CellList",
    "nn.SmoothL1Loss": "maybe could convert to nn.SmoothL1Loss",
    "nn.Tanh": "maybe could convert to nn.Tanh",
    "nn.Upsample": "maybe could convert to P.ResizeBilinear",
    "nn.L1Loss": "maybe could convert to nn.L1Loss",
    "nn.Parameter": "maybe could convert to mindspore.Parameter",
    "nn.ParameterList": "maybe could convert to mindspore.ParameterTuple",
    "nn.Unfold": "maybe could convert to nn.Unfold",
    "nn.PixelShuffle": "maybe could convert to P.DepthToSpace",
    "F.adaptive_avg_pool2d": "maybe could convert to P.ReduceMean",
    "F.conv2d": "maybe could convert to mindspore.ops.operations.Conv2D",
    "F.dropout": "please use nn.Dropout in __init__()",
    "F.interpolate": "maybe could convert to P.ResizeBilinear",
    "torch.bmm": "maybe could convert to P.BatchMatMul",
    "torch.cumsum": "maybe could convert to P.CumSum",
    "F.relu": "maybe could convert to P.ReLU",
    "F.pad": "maybe could convert to P.Pad",
    "F.softmax": "maybe could convert to P.Softmax",
    "torch.clamp": "maybe could convert to mindspore.ops.composite.clip_by_value",
    "torch.eq": "maybe could convert to P.Equal",
    "torch.load": "maybe could convert to mindspore.train.serialization.load_checkpoint",
    "torch.matmul": "maybe could convert to P.MatMul",
    "torch.max": "try to use P.ArgMaxWithValue, notice that two values are returned by P.ArgMaxWithValue",
    "torch.mean": "maybe could convert to P.ReduceMean",
    "torch.min": "try to use P.ArgMinWithValue, notice that two values are returned by P.ArgMinWithValue",
    "torch.mm": "maybe could convert to P.MatMul",
    "torch.mul": "maybe could convert to P.Mul",
    "torch.norm": "maybe could convert to nn.Norm",
    "torch.numel": "maybe could convert to P.Size",
    "F.one_hot": "maybe could convert to P.OneHot",
    "torch.ones_like": "maybe could convert to P.OnesLike",
    "torch.randn": "maybe could convert to P.TruncatedNormal",
    "torch.round": "maybe could convert to P.Round",
    "torch.save": "maybe could convert to mindspore.train.serialization.save_checkpoint",
    "torch.sigmoid": "maybe could convert to P.Sigmoid",
    "torch.split": "maybe could convert to P.Split",
    "torch.squeeze": "maybe could convert to P.Squeeze",
    "torch.stack": "maybe could convert to P.Pack",
    "torch.sum": "maybe could convert to mindspore.ops.operations.ReduceSum",
    "torch.tanh": "maybe could convert to mindspore.ops.operations.Tanh",
    "torch.tensor": "maybe could convert to mindspore.Tensor",
    "torch.transpose": "maybe could convert to P.Transpose",
    "torch.unsqueeze": "maybe could convert to P.ExpandDims",
    "torch.zeros_like": "maybe could convert to P.ZerosLike",
    ".chunk": "maybe could convert to P.Split",
    ".fill_": "maybe could convert to P.Fill",
    ".float": "maybe could convert to P.Cast",
    ".mm": "maybe could convert to P.MatMul",
    ".mul": "maybe could convert to P.Mul",
    ".pow": "maybe could convert to P.Pow",
    ".round": "maybe could convert to P.Round",
    ".scatter": "maybe could convert to P.ScatterNd",
    ".sigmoid": "maybe could convert to nn.Sigmoid",
    ".sign": "maybe could convert to P.Sign",
    ".sqrt": "maybe could convert to P.Sqrt",
    ".sub": "maybe could convert to P.Sub",
    ".transpose": "maybe could convert to P.Transpose",
    ".unsqueeze": "maybe could convert to P.ExpandDims",
    ".zero_": "maybe could convert to P.ZerosLike",
}
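

# A minimal usage sketch (an illustration, assuming 'nn.Conv2d' is one of the ops defined in
# mappings/nn_mappings.json): look up a PyTorch call in ALL_MAPPING and convert it, falling
# back to the hint in UNSUPPORTED_WARN_INFOS when no mapping exists.
if __name__ == '__main__':
    _call_name = 'nn.Conv2d'
    _args_str = '(3, 64, kernel_size=7, stride=2)'
    if _call_name in ALL_MAPPING:
        # prints the converted MindSpore call produced by MappingHelper.convert
        print(ALL_MAPPING[_call_name].convert(_call_name, _args_str))
    elif _call_name in UNSUPPORTED_WARN_INFOS:
        logger.warning(UNSUPPORTED_WARN_INFOS[_call_name])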