You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

funcs.py 4.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless REQUIRED by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """funcs for gen_explicit_map"""
  16. from functools import partial
  17. def gen_explicit_map_f_max_pool2d(params_pt, args_pt):
  18. """
  19. Generate explicit_map for F.MaxPool2d.
  20. Args:
  21. params_pt (dict): Params for APIPt.
  22. args_pt (dict): Args for APIPt.
  23. Returns:
  24. dict, map between frames.
  25. """
  26. if 'padding' in args_pt:
  27. padding = args_pt['padding']
  28. else:
  29. padding = params_pt['padding']
  30. if padding.strip() in ("0", "(0,0)", "(0, 0)"):
  31. padding = "'valid'"
  32. else:
  33. padding = "'same'"
  34. if 'stride' in args_pt:
  35. strides = args_pt['stride']
  36. else:
  37. strides = args_pt['kernel_size']
  38. return {"padding": padding,
  39. "strides": strides}
  40. def gen_explicit_map_nn_sequential(_, args_pt):
  41. """
  42. Generate explicit_map for nn.Sequential.
  43. Args:
  44. args_pt (dict): Args for APIPt.
  45. Returns:
  46. dict, map between frames.
  47. """
  48. args = args_pt['*args']
  49. return {"*args": "[{}]".format(args)}
  50. def gen_explicit_map_one_delta(params_pt, args_pt, k_ms, k_pt):
  51. """
  52. Generate explicit_map for which include mapping relationship is `1 - k_ms = k_pt`.
  53. Args:
  54. params_pt (dict): Params for APIPt.
  55. args_pt (dict): Args for APIPt.
  56. Returns:
  57. dict, map between frames.
  58. """
  59. value = args_pt[k_pt] if k_pt in args_pt else params_pt[k_pt]
  60. value = value.strip()
  61. def is_number(string):
  62. try:
  63. float(string)
  64. return True
  65. except ValueError:
  66. return False
  67. if is_number(value):
  68. return {k_ms: str(1 - float(value))}
  69. return {k_ms: "1.0 - " + value}
  70. def gen_explicit_map_nn_maxpool2d(params_pt, args_pt):
  71. """
  72. Generate explicit_map for nn.MaxPool2d.
  73. Args:
  74. params_pt (dict): Params for APIPt.
  75. args_pt (dict): Args for APIPt.
  76. Returns:
  77. dict, map between frames.
  78. """
  79. if 'padding' in args_pt:
  80. padding = args_pt['padding']
  81. else:
  82. padding = params_pt['padding']
  83. if padding.strip() in ("0", "(0,0)", "(0, 0)"):
  84. pad_mode = "'valid'"
  85. else:
  86. pad_mode = "'same'"
  87. if 'stride' in args_pt:
  88. stride = args_pt['stride']
  89. else:
  90. stride = args_pt['kernel_size']
  91. return {"pad_mode": pad_mode,
  92. "stride": stride}
  93. def torch_dot_eye_gen_explicit_map(_, args_pt):
  94. """
  95. Generate explicit_map for torch.eye.
  96. Args:
  97. args_pt (dict): Args for APIPt.
  98. Returns:
  99. dict, map between frames.
  100. """
  101. explicit_map = {'t': 'mindspore.int32'}
  102. if args_pt.get('m'):
  103. explicit_map.update({'m': args_pt.get('m')})
  104. else:
  105. explicit_map.update({'m': args_pt.get('n')})
  106. return explicit_map
  107. tensor_dot_permute_gen_explicit_map = lambda params_pt, args_pt: {"input_perm": "(" + args_pt["*dIms"] + ",)"}
  108. tensor_dot_repeat_gen_explicit_map = lambda params_pt, args_pt: {"multiples": "(" + args_pt["*sizes"] + ",)"}
  109. tensor_dot_reshape_gen_explicit_map = lambda params_pt, args_pt: {"shape": "(" + args_pt["*shape"] + ",)"}
  110. tensor_dot_view_gen_explicit_map = lambda params_pt, args_pt: {"shape": "(" + args_pt["*shape"] + ",)"}
  111. nn_conv2d_gen_explicit_map = lambda params_pt, args_pt: {"pad_mode": "'pad'"}
  112. nn_batchnorm2d_gen_explicit_map = partial(gen_explicit_map_one_delta, k_ms="momentum", k_pt="momentum")
  113. nn_batchnorm1d_gen_explicit_map = nn_batchnorm2d_gen_explicit_map
  114. nn_dropout_gen_explicit_map = partial(gen_explicit_map_one_delta, k_ms="keep_prob", k_pt="p")
  115. torch_dot_add_gen_explicit_map = lambda params_pt, args_pt:\
  116. {"input_y": (args_pt['value'] + '*' + args_pt["alpha"]) if args_pt.get("alpha") else args_pt['value']}