You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_converter.py 12 kB


  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Test Converter"""
  16. from mindinsight.mindconverter.converter import Converter
  17. from mindinsight.mindconverter.config import NN_MAPPING
  18. class TestConverter:
  19. """Test Converter"""
  20. converter_ins = Converter()
  21. # test convert_api with nn ops
  22. def test_convert_api_nn_layernorm(self):
  23. """Test convert_api function work ok when convert api nn.LayerNorm"""
  24. code = "nn.LayerNorm((5, 10, 10), elementwise_affine=False)"
  25. api_name = 'nn.LayerNorm'
  26. layer_norm_info = NN_MAPPING.get(api_name)
  27. expected_ms_api_name = 'nn.LayerNorm'
  28. epsilon = layer_norm_info.pt_api.params.get('eps')
  29. replaced_code = self.converter_ins.convert_api(code)
  30. assert replaced_code == code.replace('nn.LayerNorm((5, 10, 10), elementwise_affine=False)',
  31. '{}(normalized_shape=(5, 10, 10), epsilon={})'.format(
  32. expected_ms_api_name, epsilon))
  33. def test_convert_api_nn_leaky_relu(self):
  34. """Test convert_api function work ok when convert api nn.LeakyReLU"""
  35. code = "nn.LeakyReLU(0.3)"
  36. expected_ms_api_name = 'nn.LeakyReLU'
  37. replaced_code = self.converter_ins.convert_api(code)
  38. assert replaced_code == code.replace('nn.LeakyReLU(0.3)',
  39. '{}(alpha=0.3)'.format(expected_ms_api_name))
  40. def test_convert_api_nn_prelu(self):
  41. """Test convert_api function work ok when convert api nn.PReLU"""
  42. code = "nn.PReLU()(input)"
  43. expected_ms_api_name = 'nn.PReLU'
  44. replaced_code = self.converter_ins.convert_api(code)
  45. assert replaced_code == code.replace('nn.PReLU()(input)',
  46. '{}()(input)'.format(expected_ms_api_name))
  47. def test_convert_api_nn_softmax(self):
  48. """Test convert_api function work ok when convert api nn.Softmax"""
  49. code = "nn.Softmax(dim=1)"
  50. expected_ms_api_name = 'nn.Softmax'
  51. replaced_code = self.converter_ins.convert_api(code)
  52. assert replaced_code == code.replace('nn.Softmax(dim=1)',
  53. '{}(axis=1)'.format(expected_ms_api_name))
  54. def test_convert_api_nn_dropout(self):
  55. """Test convert_api function work ok when convert api nn.Dropout"""
  56. code = """nn.Dropout(0.3)"""
  57. expected_ms_api_name = 'nn.Dropout'
  58. replaced_code = self.converter_ins.convert_api(code)
  59. assert replaced_code == code.replace('nn.Dropout(0.3)',
  60. "{}(keep_prob=0.7)".format(expected_ms_api_name))
  61. # test convert_api with torch dot ops
  62. def test_convert_api_torch_dot_abs(self):
  63. """Test convert_api function work ok when convert api torch.abs"""
  64. code = "torch.abs(input)"
  65. expected_ms_api_name = 'P.Abs'
  66. replaced_code = self.converter_ins.convert_api(code)
  67. assert replaced_code == code.replace('torch.abs(input)',
  68. '{}()(input)'.format(expected_ms_api_name))
  69. def test_convert_api_torch_dot_acos(self):
  70. """Test convert_api function work ok when convert api torch.acos"""
  71. code = "torch.acos(input)"
  72. expected_ms_api_name = 'P.ACos'
  73. replaced_code = self.converter_ins.convert_api(code)
  74. assert replaced_code == code.replace('torch.acos(input)',
  75. '{}()(input)'.format(expected_ms_api_name))
  76. def test_convert_api_torch_dot_cos(self):
  77. """Test convert_api function work ok when convert api torch.cos"""
  78. code = "torch.cos(input)"
  79. expected_ms_api_name = 'P.Cos'
  80. replaced_code = self.converter_ins.convert_api(code)
  81. assert replaced_code == code.replace('torch.cos(input)',
  82. '{}()(input)'.format(expected_ms_api_name))
  83. def test_convert_api_torch_dot_exp(self):
  84. """Test convert_api function work ok when convert api torch.exp"""
  85. code = "torch.exp(input)"
  86. expected_ms_api_name = 'P.Exp'
  87. replaced_code = self.converter_ins.convert_api(code)
  88. assert replaced_code == code.replace('torch.exp(input)',
  89. '{}()(input)'.format(expected_ms_api_name))
  90. def test_convert_api_torch_dot_log(self):
  91. """Test convert_api function work ok when convert api torch.log"""
  92. code = "torch.log(input)"
  93. expected_ms_api_name = 'P.Log'
  94. replaced_code = self.converter_ins.convert_api(code)
  95. assert replaced_code == code.replace('torch.log(input)',
  96. '{}()(input)'.format(expected_ms_api_name))
  97. def test_convert_api_torch_dot_pow(self):
  98. """Test convert_api function work ok when convert api torch.pow"""
  99. code = "torch.pow(a, exp)"
  100. expected_ms_api_name = 'P.Pow'
  101. replaced_code = self.converter_ins.convert_api(code)
  102. assert replaced_code == code.replace('torch.pow(a, exp)',
  103. '{}()(a, exp)'.format(expected_ms_api_name))
  104. def test_convert_api_torch_dot_div(self):
  105. """Test convert_api function work ok when convert api torch.div"""
  106. code = "torch.div(input, other)"
  107. expected_ms_api_name = 'P.Div'
  108. replaced_code = self.converter_ins.convert_api(code)
  109. assert replaced_code == code.replace('torch.div(input, other)',
  110. '{}()(input, other)'.format(expected_ms_api_name))
  111. def test_convert_api_torch_dot_sin(self):
  112. """Test convert_api function work ok when convert api torch.sin"""
  113. code = "torch.sin(input)"
  114. expected_ms_api_name = 'P.Sin'
  115. replaced_code = self.converter_ins.convert_api(code)
  116. assert replaced_code == code.replace('torch.sin(input)',
  117. '{}()(input)'.format(expected_ms_api_name))
  118. def test_convert_api_torch_dot_sqrt(self):
  119. """Test convert_api function work ok when convert api torch.sqrt"""
  120. code = "torch.sqrt(input)"
  121. expected_ms_api_name = 'P.Sqrt'
  122. replaced_code = self.converter_ins.convert_api(code)
  123. assert replaced_code == code.replace('torch.sqrt(input)',
  124. '{}()(input)'.format(expected_ms_api_name))
  125. def test_convert_api_torch_dot_eye_with_n(self):
  126. """Test convert_api function work ok when convert api torch.eye"""
  127. code = "torch.eye(3)"
  128. expected_ms_api_name = 'P.Eye'
  129. replaced_code = self.converter_ins.convert_api(code)
  130. assert replaced_code == code.replace('torch.eye(3)',
  131. '{}()(3, 3, mindspore.int32)'.format(expected_ms_api_name))
  132. def test_convert_api_torch_dot_eye_with_m(self):
  133. """Test convert_api function work ok when convert api torch.eye"""
  134. code = "torch.eye(3, 4)"
  135. expected_ms_api_name = 'P.Eye'
  136. replaced_code = self.converter_ins.convert_api(code)
  137. assert replaced_code == code.replace('torch.eye(3, 4)',
  138. '{}()(3, 4, mindspore.int32)'.format(expected_ms_api_name))
  139. def test_convert_api_torch_dot_add_with_alpha_default(self):
  140. """Test convert_api function work ok when convert api torch.add"""
  141. code = "torch.add(input, value)"
  142. expected_ms_api_name = 'P.TensorAdd'
  143. replaced_code = self.converter_ins.convert_api(code)
  144. assert replaced_code == code.replace('torch.add(input, value)',
  145. '{}()(input, value)'.format(expected_ms_api_name))
  146. def test_convert_api_torch_dot_add_with_alpha_not_default(self):
  147. """Test convert_api function work ok when convert api torch.add"""
  148. code = "torch.add(input, value, 3)"
  149. expected_ms_api_name = 'P.TensorAdd'
  150. replaced_code = self.converter_ins.convert_api(code)
  151. assert replaced_code == code.replace('torch.add(input, value, 3)',
  152. '{}()(input, value*3)'.format(expected_ms_api_name))
  153. # test convert_api with F ops
  154. def test_convert_api_f_normalize(self):
  155. """Test convert_api function work ok when convert api F.normalize"""
  156. code = "F.normalize(input)"
  157. expected_ms_api_name = 'P.L2Normalize'
  158. replaced_code = self.converter_ins.convert_api(code)
  159. assert replaced_code == code.replace('F.normalize(input)',
  160. '{}(1, 1e-12)(input)'.format(expected_ms_api_name))
  161. def test_convert_api_f_sigmoid(self):
  162. """Test convert_api function work ok when convert api F.sigmoid"""
  163. code = "F.sigmoid(input)"
  164. expected_ms_api_name = 'P.Sigmoid'
  165. replaced_code = self.converter_ins.convert_api(code)
  166. assert replaced_code == code.replace('F.sigmoid(input)',
  167. '{}()(input)'.format(expected_ms_api_name))
  168. def test_convert_api_f_max_pool2d(self):
  169. """Test convert_api function work ok when convert api F.max_pool2d"""
  170. code = """F.max_pool2d(out, 2)"""
  171. expected_ms_api_name = 'P.MaxPool'
  172. replaced_code = self.converter_ins.convert_api(code)
  173. assert replaced_code == code.replace('F.max_pool2d(out, 2)',
  174. "{}(2, 2, 'valid')(out)".format(expected_ms_api_name))
  175. def test_convert_api_f_avg_pool2d_without_strides(self):
  176. """Test convert_api function work ok when convert api F.avg_pool2d"""
  177. code = """F.avg_pool2d(out, 2)"""
  178. expected_ms_api_name = 'P.AvgPool'
  179. replaced_code = self.converter_ins.convert_api(code)
  180. assert replaced_code == code.replace('F.avg_pool2d(out, 2)',
  181. "{}(2, 2, 'valid')(out)".format(expected_ms_api_name))
  182. def test_convert_api_f_avg_pool2d_with_strides(self):
  183. """Test convert_api function work ok when convert api F.avg_pool2d"""
  184. code = """F.avg_pool2d(out, 2, 3)"""
  185. expected_ms_api_name = 'P.AvgPool'
  186. replaced_code = self.converter_ins.convert_api(code)
  187. assert replaced_code == code.replace('F.avg_pool2d(out, 2, 3)',
  188. "{}(2, 3, 'valid')(out)".format(expected_ms_api_name))
  189. # test convert_api with tensor dot ops
  190. def test_convert_api_tensor_dot_repeat(self):
  191. """Test convert_api function work ok when convert api .repeat"""
  192. code = "x.repeat(4, 2)"
  193. expected_ms_api_name = 'P.Tile'
  194. replaced_code = self.converter_ins.convert_api(code)
  195. assert replaced_code == code.replace('x.repeat(4, 2)',
  196. '{}()(x, {})'.format(expected_ms_api_name, '(4, 2,)'))
  197. def test_convert_api_tensor_dot_permute(self):
  198. """Test convert_api function work ok when convert api .permute"""
  199. code = "x.permute(2, 0, 1)"
  200. expected_ms_api_name = 'P.Transpose'
  201. replaced_code = self.converter_ins.convert_api(code)
  202. assert replaced_code == code.replace('x.permute(2, 0, 1)',
  203. '{}()(x, (2, 0, 1,))'.format(expected_ms_api_name))