You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

funcs.py 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless REQUIRED by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """funcs for gen_explicit_map"""
  16. from functools import partial
  17. def gen_explicit_map_f_max_pool2d(params_pt, args_pt):
  18. """
  19. Generate explicit_map for F.MaxPool2d.
  20. Args:
  21. params_pt (dict): Params for APIPt.
  22. args_pt (dict): Args for APIPt.
  23. Returns:
  24. dict, map between frames.
  25. """
  26. if 'padding' in args_pt:
  27. padding = args_pt['padding']
  28. else:
  29. padding = params_pt['padding']
  30. if padding.strip() in ("0", "(0,0)", "(0, 0)"):
  31. padding = "'valid'"
  32. else:
  33. padding = "'same'"
  34. return {"padding": padding}
  35. def gen_explicit_map_nn_sequential(_, args_pt):
  36. """
  37. Generate explicit_map for nn.Sequential.
  38. Args:
  39. args_pt (dict): Args for APIPt.
  40. Returns:
  41. dict, map between frames.
  42. """
  43. args = args_pt['*args']
  44. return {"*args": "[{}]".format(args)}
  45. def gen_explicit_map_one_delta(params_pt, args_pt, k_ms, k_pt):
  46. """
  47. Generate explicit_map for which include mapping relationship is `1 - k_ms = k_pt`.
  48. Args:
  49. params_pt (dict): Params for APIPt.
  50. args_pt (dict): Args for APIPt.
  51. Returns:
  52. dict, map between frames.
  53. """
  54. value = args_pt[k_pt] if k_pt in args_pt else params_pt[k_pt]
  55. value = value.strip()
  56. def is_number(string):
  57. try:
  58. float(string)
  59. return True
  60. except ValueError:
  61. return False
  62. if is_number(value):
  63. return {k_ms: str(1 - float(value))}
  64. return {k_ms: "1.0 - " + value}
  65. def gen_explicit_map_nn_maxpool2d(params_pt, args_pt):
  66. """
  67. Generate explicit_map for nn.MaxPool2d.
  68. Args:
  69. params_pt (dict): Params for APIPt.
  70. args_pt (dict): Args for APIPt.
  71. Returns:
  72. dict, map between frames.
  73. """
  74. if 'padding' in args_pt:
  75. padding = args_pt['padding']
  76. else:
  77. padding = params_pt['padding']
  78. if padding.strip() in ("0", "(0,0)", "(0, 0)"):
  79. pad_mode = "'valid'"
  80. else:
  81. pad_mode = "'same'"
  82. return {"pad_mode": pad_mode}
  83. def torch_dot_eye_gen_explicit_map(_, args_pt):
  84. """
  85. Generate explicit_map for torch.eye.
  86. Args:
  87. args_pt (dict): Args for APIPt.
  88. Returns:
  89. dict, map between frames.
  90. """
  91. explicit_map = {'t': 'mindspore.int32'}
  92. if args_pt.get('m'):
  93. explicit_map.update({'m': args_pt.get('m')})
  94. else:
  95. explicit_map.update({'m': args_pt.get('n')})
  96. return explicit_map
  97. tensor_dot_permute_gen_explicit_map = lambda params_pt, args_pt: {"input_perm": "(" + args_pt["*dIms"] + ",)"}
  98. tensor_dot_repeat_gen_explicit_map = lambda params_pt, args_pt: {"multiples": "(" + args_pt["*sizes"] + ",)"}
  99. tensor_dot_reshape_gen_explicit_map = lambda params_pt, args_pt: {"shape": "(" + args_pt["*shape"] + ",)"}
  100. tensor_dot_view_gen_explicit_map = lambda params_pt, args_pt: {"shape": "(" + args_pt["*shape"] + ",)"}
  101. nn_conv2d_gen_explicit_map = lambda params_pt, args_pt: {"pad_mode": "'pad'"}
  102. nn_batchnorm2d_gen_explicit_map = partial(gen_explicit_map_one_delta, k_ms="momentum", k_pt="momentum")
  103. nn_batchnorm1d_gen_explicit_map = nn_batchnorm2d_gen_explicit_map
  104. nn_dropout_gen_explicit_map = partial(gen_explicit_map_one_delta, k_ms="keep_prob", k_pt="p")
  105. torch_dot_add_gen_explicit_map = lambda params_pt, args_pt:\
  106. {"input_y": (args_pt['value'] + '*' + args_pt["alpha"]) if args_pt.get("alpha") else args_pt['value']}