You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_parse.py 7.4 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. @File : test_parse.py
  17. @Author:
  18. @Date : 2019-01-23 17:13
  19. @Desc :
  20. """
  21. import logging
  22. import pytest
  23. import numpy as np
  24. import mindspore as ms
  25. import mindspore.nn as nn
  26. from mindspore import Tensor
  27. from mindspore import context
  28. from mindspore.ops import composite as C
  29. from mindspore.common.api import ms_function, _executor
  30. from mindspore.ops._grad.grad_base import bprop_getters
  31. from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
  32. from mindspore.ops.functional import tensor_add
  33. from ...ut_filter import non_graph_engine
  34. # pylint: disable=W0613,W0612
  35. # W0613: unused-argument
  36. log = logging.getLogger("test")
  37. log.setLevel(level=logging.ERROR)
  38. context.set_context(mode=context.GRAPH_MODE)
  39. # Test case: use the parse obj interface use default parameter
  40. class Net(nn.Cell):
  41. """ Net definition """
  42. def __init__(self, dim):
  43. super(Net, self).__init__()
  44. self.softmax1 = nn.Softmax(dim)
  45. self.softmax2 = nn.Softmax(dim + 1)
  46. def construct(self, input_data, input1=ms.Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))):
  47. return self.softmax1(input_data)
  48. @non_graph_engine
  49. def test_parse_defalut_parameter_case2():
  50. """ test_parse_defalut_parameter_case2 """
  51. log.debug("begin test_parse_defalut_parameter_case2")
  52. net = Net(0)
  53. npd = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')
  54. log.debug("input value is: %r", npd)
  55. input_data = ms.Tensor(npd)
  56. input_data.set_dtype(ms.float32)
  57. log.debug("start run")
  58. output = net(input_data)
  59. value = output.asnumpy()
  60. log.debug("output value = %r", value)
  61. # Test case: use the variable parameter for parse object
  62. class Net1(nn.Cell):
  63. """ Net1 definition """
  64. def __init__(self):
  65. super(Net1, self).__init__()
  66. def construct(self, *args):
  67. x = args[0]
  68. return x
  69. def test_var_parameter_case2():
  70. """ test_var_parameter_case2 """
  71. log.debug("begin test_var_parameter_case2")
  72. net = Net1()
  73. npd = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')
  74. log.debug("input value is: %r", npd)
  75. input_data = ms.Tensor(npd)
  76. input_data.set_dtype(ms.float32)
  77. np1 = np.random.randn(2, 3, 4, 5).astype(np.float32)
  78. input1 = ms.Tensor(np1)
  79. np2 = np.random.randn(2, 3, 4, 5).astype(np.float32)
  80. input2 = ms.Tensor(np2)
  81. _executor.compile(net, input_data, input1, input2)
  82. # Test case: test the global flag
  83. g_x = Tensor(np.ones([3, 3]).astype(np.float32))
  84. @ms_function
  85. def tensor_add_global(x):
  86. """ tensor_add_global """
  87. global g_x
  88. res = tensor_add(x, g_x)
  89. return res
  90. @non_graph_engine
  91. def test_global_flag():
  92. """ test_global_flag """
  93. log.debug("begin test_global_flag")
  94. x = Tensor(np.ones([3, 3]).astype(np.float32))
  95. res = tensor_add_global(x)
  96. log.debug("finished test_global_flag, ret = %r", res)
  97. class NetWithNDarray(nn.Cell):
  98. """ NetWithNDarray definition """
  99. def __init__(self, dim):
  100. super(NetWithNDarray, self).__init__()
  101. self.softmax = nn.Softmax(dim)
  102. self.x = ms.Tensor(np.ones(shape=(1)).astype(np.float32))
  103. def construct(self, input_data):
  104. return self.softmax(input_data) * self.x
  105. @non_graph_engine
  106. def test_net_with_ndarray():
  107. """ test_net_with_ndarray """
  108. net = NetWithNDarray(0)
  109. input_data = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')
  110. net(ms.Tensor(input_data))
  111. def test_bprop_with_wrong_output_num():
  112. context.set_context(check_bprop=True)
  113. class BpropWithWrongOutputNum(PrimitiveWithInfer):
  114. @prim_attr_register
  115. def __init__(self):
  116. super(BpropWithWrongOutputNum, self).__init__('BpropWithWrongOutputNum')
  117. def __call__(self, x, y):
  118. return x
  119. def infer_shape(self, x_shape, yshape):
  120. return x_shape
  121. def infer_dtype(self, x_type, y_type):
  122. return x_type
  123. @bprop_getters.register(BpropWithWrongOutputNum)
  124. def get_bprop_with_wrong_output_num(self):
  125. """Generate bprop for BpropWithWrongOutputNum"""
  126. def bprop(x, y, out, dout):
  127. return (dout,)
  128. return bprop
  129. class BpropWithWrongOutputNumCell(nn.Cell):
  130. def __init__(self):
  131. super(BpropWithWrongOutputNumCell, self).__init__()
  132. def construct(self, x, y):
  133. return BpropWithWrongOutputNum()(x, y)
  134. with pytest.raises(TypeError):
  135. C.grad_all(BpropWithWrongOutputNumCell())(1, 2)
  136. def test_bprop_with_wrong_output_type():
  137. context.set_context(check_bprop=True)
  138. class BpropWithWrongOutputType(PrimitiveWithInfer):
  139. @prim_attr_register
  140. def __init__(self):
  141. super(BpropWithWrongOutputType, self).__init__('BpropWithWrongOutputType')
  142. def __call__(self, x):
  143. return x
  144. def infer_shape(self, x_shape):
  145. return x_shape
  146. def infer_dtype(self, x_type):
  147. return x_type
  148. @bprop_getters.register(BpropWithWrongOutputType)
  149. def get_bprop_with_wrong_output_type(self):
  150. """Generate bprop for BpropWithWrongOutputType"""
  151. def bprop(x, out, dout):
  152. return (1,)
  153. return bprop
  154. class BpropWithWrongOutputTypeCell(nn.Cell):
  155. def __init__(self):
  156. super(BpropWithWrongOutputTypeCell, self).__init__()
  157. def construct(self, x):
  158. return BpropWithWrongOutputType()(x)
  159. with pytest.raises(TypeError):
  160. C.grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))
  161. def test_bprop_with_wrong_output_shape():
  162. context.set_context(check_bprop=True)
  163. class BpropWithWrongOutputShape(PrimitiveWithInfer):
  164. @prim_attr_register
  165. def __init__(self):
  166. super(BpropWithWrongOutputShape, self).__init__('BpropWithWrongOutputShape')
  167. def __call__(self, x):
  168. return x
  169. def infer_shape(self, x_shape):
  170. return x_shape
  171. def infer_dtype(self, x_type):
  172. return x_type
  173. @bprop_getters.register(BpropWithWrongOutputShape)
  174. def get_bprop_with_wrong_output_shape(self):
  175. """Generate bprop for BpropWithWrongOutputShape"""
  176. ones = Tensor(np.ones([2,]).astype(np.int32))
  177. def bprop(x, out, dout):
  178. return (ones,)
  179. return bprop
  180. class BpropWithWrongOutputShapeCell(nn.Cell):
  181. def __init__(self):
  182. super(BpropWithWrongOutputShapeCell, self).__init__()
  183. def construct(self, x):
  184. return BpropWithWrongOutputShape()(x)
  185. with pytest.raises(TypeError):
  186. net = BpropWithWrongOutputShapeCell()
  187. net.set_grad()
  188. C.grad_all(net)(Tensor(np.ones([64, 10]).astype(np.int32)))