You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_ops_reid.py 6.1 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test Activations """
  16. import functools
  17. import numpy as np
  18. import mindspore.nn as nn
  19. from mindspore.ops import operations as P
  20. from ....mindspore_test_framework.mindspore_test import mindspore_test
  21. from ....mindspore_test_framework.pipeline.forward.compile_forward \
  22. import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
  23. from ....mindspore_test_framework.pipeline.gradient.compile_gradient \
  24. import pipeline_for_compile_grad_ge_graph_for_case_by_case_config
  25. from ....ops_common import convert
  26. class SeqConvBnRelu(nn.Cell):
  27. """ SeqConvBnRelu definition """
  28. def __init__(self, in_ch, out_ch):
  29. super(SeqConvBnRelu, self).__init__()
  30. self.conv = nn.Conv2d(in_ch, out_ch, 3)
  31. self.bn = nn.BatchNorm2d(out_ch)
  32. self.relu = P.ReLU()
  33. def construct(self, input_x):
  34. return self.relu(self.bn(self.conv(input_x)))
  35. test_case_reid_ops = [
  36. ('ReduceMax', {
  37. 'block': P.ReduceMax(keep_dims=False),
  38. 'desc_const': [(1,)],
  39. 'desc_inputs': [convert([32, 32], np.float16)],
  40. 'desc_bprop': [convert([32], np.float16)],
  41. 'skip': []}),
  42. ('ReduceMin', {
  43. 'block': P.ReduceMin(),
  44. 'desc_const': [(1,)],
  45. 'desc_inputs': [[32, 32]],
  46. 'desc_bprop': [[32]],
  47. 'skip': []}),
  48. ('ReduceMean', {
  49. 'block': P.ReduceMean(keep_dims=True),
  50. 'desc_const': [(1, 2)],
  51. 'desc_inputs': [[32, 4, 4]],
  52. 'desc_bprop': [[32, 1, 1]]}),
  53. ('Log', {
  54. 'block': P.Log(),
  55. 'desc_inputs': [[4, 128, 1024]],
  56. 'desc_bprop': [[4, 128, 1024]],
  57. 'skip': ['backward']}), # check backward error
  58. ('Reciprocal', {
  59. 'block': P.Reciprocal(),
  60. 'desc_inputs': [[4, 128, 1024]],
  61. 'desc_bprop': [[4, 128, 1024]],
  62. 'skip': ['backward']}),
  63. ('FloorDiv', {
  64. 'block': P.FloorDiv(),
  65. 'desc_inputs': [[4, 128, 1024], [4, 128, 1024]],
  66. 'desc_bprop': [[4, 128, 1024]]}),
  67. ('Sigmoid', {
  68. 'block': P.Sigmoid(),
  69. 'desc_inputs': [[4, 128, 1024]],
  70. 'desc_bprop': [[4, 128, 1024]]}),
  71. ('Softmax', {
  72. 'block': P.Softmax(),
  73. 'desc_inputs': [[1, 16]],
  74. 'desc_bprop': [[1, 16]],
  75. 'skip': ['backward']}), # check backward error
  76. ('Softmax', {
  77. 'block': P.Softmax(axis=(0, 1)),
  78. 'desc_inputs': [[1, 16]],
  79. 'desc_bprop': [[1, 16]],
  80. 'skip': ['backward']}),
  81. ('L2Normalize', {
  82. 'block': P.L2Normalize(),
  83. 'desc_inputs': [[4, 128, 1024]],
  84. 'desc_bprop': [[4, 128, 1024]]}),
  85. ('ReLU', {
  86. 'block': P.ReLU(),
  87. 'desc_inputs': [[64, 64, 112, 112]],
  88. 'desc_bprop': [[64, 64, 112, 112]]}),
  89. ('SeqConvBnRelu', {
  90. 'block': SeqConvBnRelu(3, 64),
  91. 'desc_inputs': [[64, 3, 112, 112]],
  92. 'desc_bprop': [[64, 64, 112, 112]]}),
  93. ('PReluCell', {
  94. 'block': nn.PReLU(1, [np.float32(0.25)]),
  95. 'desc_inputs': [[128, 64, 112, 112]],
  96. 'desc_bprop': [[128, 64, 112, 112]]}),
  97. ('PRelu', {
  98. 'block': P.PReLU(),
  99. 'desc_inputs': [[128, 64, 112, 112], [64,]],
  100. 'desc_bprop': [[128, 64, 112, 112]]}),
  101. ('Cos', {
  102. 'block': P.Cos(),
  103. 'desc_inputs': [[8, 16]],
  104. 'desc_bprop': [[8, 16]]}),
  105. ('ACos', {
  106. 'block': P.ACos(),
  107. 'desc_inputs': [[8, 16]],
  108. 'desc_bprop': [[8, 16]]}),
  109. ('Exp', {
  110. 'block': P.Exp(),
  111. 'desc_inputs': [[256, 8]],
  112. 'desc_bprop': [[256, 8]]}),
  113. ('Pow', {
  114. 'block': P.Pow(), # 输入有标量插件产生了段错误。
  115. 'desc_const': [2.0],
  116. 'desc_inputs': [[1, 512]],
  117. 'desc_bprop': [[1, 512]]}),
  118. ('LogicalNot', {
  119. 'block': P.LogicalNot(),
  120. 'desc_inputs': [convert([256], np.bool_)],
  121. 'desc_bprop': [[256]]}), # 自定义算子 input bool没转换,gongchen提单。
  122. ('Equal', {
  123. 'block': P.Equal(),
  124. 'desc_inputs': [convert([256], np.float16), convert([256], np.float16)],
  125. 'desc_bprop': [[256]]}),
  126. ('Greater', {
  127. 'block': P.Greater(),
  128. 'desc_inputs': [convert([256], np.float16), convert([256], np.float16)],
  129. 'desc_bprop': [[256]]}),
  130. ('Dropout', {
  131. 'block': nn.Dropout(),
  132. 'desc_inputs': [[1, 512, 7, 7]],
  133. 'desc_bprop': [[1, 512, 7, 7]]}), # 输入有标量插件产生了段错误。
  134. ('MatMul', {
  135. 'block': P.MatMul(),
  136. 'desc_inputs': [[64, 512], [512, 64]], # fp16不行。很有问题。
  137. 'desc_bprop': [[64, 64]]}),
  138. ('Maximum', {
  139. 'block': P.Maximum(),
  140. 'desc_inputs': [[64, 1], [64, 1]],
  141. 'desc_bprop': [[64, 1]]}),
  142. ]
  143. test_case_lists = [test_case_reid_ops]
  144. test_case = functools.reduce(lambda x, y: x + y, test_case_lists)
  145. # use -k to select certain testcast
  146. # pytest tests/python/ops/test_ops.py::test_backward -k LayerNorm
  147. test_exec_case = filter(lambda x: 'skip' not in x[1] or
  148. 'exec' not in x[1]['skip'], test_case)
  149. test_backward_exec_case = filter(lambda x: 'skip' not in x[1] or
  150. 'backward' not in x[1]['skip'] and 'backward_exec'
  151. not in x[1]['skip'], test_case)
  152. @mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
  153. def test_exec():
  154. return test_exec_case
  155. @mindspore_test(pipeline_for_compile_grad_ge_graph_for_case_by_case_config)
  156. def test_backward_exec():
  157. return test_backward_exec_case