# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test nn ops """
import numpy as np

import mindspore
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.common.initializer import initializer
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops import prim_attr_register, PrimitiveWithInfer
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
    import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
from ....mindspore_test_framework.pipeline.forward.verify_exception \
    import pipeline_for_verify_exception_for_case_by_case_config

context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
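# Note: save_graphs=True makes graph mode dump its intermediate compile-stage
# graphs to disk, which helps when debugging a failing compile-only case.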


def conv3x3(in_channels, out_channels, stride=1, padding=1):
    """3x3 convolution"""
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=3, stride=stride, padding=padding)


def conv1x1(in_channels, out_channels, stride=1, padding=0):
    """1x1 convolution"""
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=1, stride=stride, padding=padding)


class ResidualBlock(nn.Cell):
    """
    Residual block definition.
    """
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 down_sample=False):
        super(ResidualBlock, self).__init__()

        out_chls = out_channels // self.expansion
        self.conv1 = conv1x1(in_channels, out_chls, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(out_chls)

        self.conv2 = conv3x3(out_chls, out_chls, stride=stride, padding=0)
        self.bn2 = nn.BatchNorm2d(out_chls)

        self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU()
        self.downsample = down_sample

        self.conv_down_sample = conv1x1(in_channels, out_channels,
                                        stride=stride, padding=0)
        self.bn_down_sample = nn.BatchNorm2d(out_channels)
        self.add = P.TensorAdd()

    def construct(self, x):
        """Bottleneck forward pass with optional projection shortcut."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample:
            identity = self.conv_down_sample(identity)
            identity = self.bn_down_sample(identity)

        out = self.add(out, identity)
        out = self.relu(out)

        return out
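
# A shape sketch (example values of ours, not part of the original tests):
# with expansion = 4, ResidualBlock(64, 256) bottlenecks 64 input channels to
# 256 // 4 = 64 channels for the 3x3 convolution, then expands back to 256, so
# the residual add sees two [N, 256, H, W] tensors once down_sample=True sends
# the identity branch through conv_down_sample.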


class VirtualLossGrad(PrimitiveWithInfer):
    """ VirtualLossGrad definition """

    @prim_attr_register
    def __init__(self):
        """init VirtualLossGrad"""

    def __call__(self, x, out, dout):
        raise NotImplementedError

    def infer_shape(self, x_shape, out_shape, dout_shape):
        return x_shape

    def infer_dtype(self, x_dtype, out_dtype, dout_dtype):
        return x_dtype


class VirtualLoss(PrimitiveWithInfer):
    """ VirtualLoss definition """

    @prim_attr_register
    def __init__(self):
        """init VirtualLoss"""

    def __call__(self, x):
        raise NotImplementedError

    def get_bprop(self):
        loss_grad = VirtualLossGrad()

        def bprop(x, out, dout):
            # pylint: disable=unused-argument
            dx = loss_grad(x, out, dout)
            return (dx,)
        return bprop

    def infer_shape(self, x_shape):
        return []

    def infer_dtype(self, x_dtype):
        return x_dtype
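
# VirtualLoss infers a scalar shape ([]) without computing anything, so the
# grad wrappers below can differentiate a network end to end at compile time
# without a real loss; its bprop just routes the incoming sensitivity through
# VirtualLossGrad, whose infer_shape is shape-preserving.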


class VirtualNetWithLoss(nn.Cell):
    """ VirtualNetWithLoss definition """

    def __init__(self, network):
        super(VirtualNetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x):
        predict = self.network(x)
        return self.loss(predict)


class SoftMaxGrad(nn.Cell):
    """ SoftMaxGrad definition """

    def __init__(self, network):
        super(SoftMaxGrad, self).__init__()
        self.network = network

    def construct(self, x):
        return C.grad(self.network)(x)


class DropoutGrad(nn.Cell):
    """ DropoutGrad definition """

    def __init__(self, network):
        super(DropoutGrad, self).__init__()
        self.network = network

    def construct(self, x):
        return C.grad(self.network)(x)


class ScalarSummaryNet(nn.Cell):
    """ ScalarSummaryNet definition """

    def __init__(self):
        super(ScalarSummaryNet, self).__init__()
        self.summary = P.ScalarSummary()

    def construct(self, scalar):
        string_in = "bias_value"
        out = self.summary(string_in, scalar)
        return out


class L2NormalizeNet(nn.Cell):
    """ L2NormalizeNet definition """

    def __init__(self):
        super(L2NormalizeNet, self).__init__()
        self.l2_normalize = P.L2Normalize()

    def construct(self, x):
        out = self.l2_normalize(x)
        return out


class HistogramSummaryNet(nn.Cell):
    """HistogramSummaryNet definition"""

    def __init__(self):
        super(HistogramSummaryNet, self).__init__()
        self.summary = P.HistogramSummary()

    def construct(self, tensor):
        string_in = "weight_value"
        out = self.summary(string_in, tensor)
        return out


class FusedBatchNormGrad(nn.Cell):
    """ FusedBatchNormGrad definition """

    def __init__(self, network):
        super(FusedBatchNormGrad, self).__init__()
        self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True)
        self.network = network

    def construct(self, inp, output_grad):
        return self.grad(self.network)(inp, output_grad)


class NetWithLoss(nn.Cell):
    """ NetWithLoss definition """

    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = P.SmoothL1Loss()
        self.network = network

    def construct(self, x, label):
        predict = self.network(x)
        return self.loss(predict, label)


class Grad(nn.Cell):
    """ GradWrap definition """

    def __init__(self, network):
        super(Grad, self).__init__()
        self.network = network
        self.network.set_train()

    def construct(self, x, label):
        return C.grad(self.network)(x, label)


class BatchnormNet(nn.Cell):
    """ BatchnormNet definition """

    def __init__(self):
        super(BatchnormNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 4, kernel_size=8, stride=2, pad_mode="pad", padding=3)
        self.bn1 = nn.BatchNorm2d(4)
        self.flatten = P.Flatten()
        self.weight = Parameter(Tensor(np.ones([64, 10], np.float32)), name="weight")
        self.bias = Parameter(Tensor(np.ones([10], np.float32)), name="bias")
        self.fc = P.MatMul()
        self.biasAdd = P.BiasAdd()

    def construct(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.flatten(x)
        x = self.biasAdd(self.fc(x, self.weight), self.bias)
        return x


class NetWithLossClass(nn.Cell):
    """ NetWithLossClass definition """

    def __init__(self, network):
        super(NetWithLossClass, self).__init__(auto_prefix=False)
        self.loss = nn.SoftmaxCrossEntropyWithLogits()
        self.network = network

    def construct(self, x, label):
        predict = self.network(x)
        return self.loss(predict, label)


class BlockNet(nn.Cell):
    """ BlockNet definition """

    def __init__(self):
        super(BlockNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, pad_mode="pad", padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
        self.block_down_sample = ResidualBlock(
            64, 256, stride=1, down_sample=True
        )
        self.flatten = P.Flatten()
        self.weight = Parameter(Tensor(np.ones([1024, 10]).astype(np.float32)), name="weight")
        self.bias = Parameter(Tensor(np.ones([10]).astype(np.float32)), name="bias")
        self.fc = P.MatMul()
        self.biasAdd = P.BiasAdd()

    def construct(self, x):
        x = self.conv1(x)
        return x


class Conv2dWithBiasNet(nn.Cell):
    """ Conv2dWithBiasNet definition """

    def __init__(self):
        super(Conv2dWithBiasNet, self).__init__()
        self.conv = nn.Conv2d(3, 10, 1, bias_init='zeros')
        self.flatten = P.Flatten()

    def construct(self, input_x):
        return self.flatten(self.conv(input_x))


class Conv2dNativeNet(nn.Cell):
    """ Conv2dNativeNet definition """

    def __init__(self):
        super(Conv2dNativeNet, self).__init__()
        self.conv = P.DepthwiseConv2dNative(channel_multiplier=3, kernel_size=(3, 3))
        self.flatten = P.Flatten()
        channel_multipliers = 1
        in_channels = 3
        kernel_size = (3, 3)
        self.weight = Parameter(initializer(
            Tensor(np.ones([channel_multipliers, in_channels, *kernel_size], dtype=np.float32)),
            [channel_multipliers, in_channels, *kernel_size]), name='weight')

    def construct(self, input_x):
        return self.flatten(self.conv(input_x, self.weight))


class StateNet(nn.Cell):
    """ StateTestTensor definition """

    def __init__(self):
        super(StateNet, self).__init__()
        weight = Tensor(np.ones([2, 1, 2, 2], np.float32))
        self.s1 = Parameter(weight, name="s1")
        self.s2 = Parameter(weight, name="s2")
        self.sub = P.Sub()
        self.loss = nn.SoftmaxCrossEntropyWithLogits()
        self.assign = P.Assign()

    def construct(self, x):
        x = F.depend(x, self.assign(self.s1, x + self.s1))
        self.s1 = self.sub(self.s1, x)
        self.s2 = self.sub(self.s2, x)
        return x
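
# F.depend(x, y) returns x while forcing y to be evaluated first, so the
# Assign to s1 above is guaranteed to execute before x is consumed in graph
# mode; the subsequent writes to self.s1 and self.s2 exercise in-place
# parameter updates inside construct.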


def test_conv2d_same_primitive():
    class Conv2DSameNet(nn.Cell):
        def __init__(self):
            super(Conv2DSameNet, self).__init__()
            self.conv1 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True)
            self.conv2 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True)

        def construct(self, x, y):
            r1 = self.conv1(x)
            r2 = self.conv2(y)
            return (r1, r2)

    t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
    t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
    net = Conv2DSameNet()
    net(t1, t2)
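
# With pad_mode "same" the output spatial size depends only on the stride
# (output = ceil(input / stride)), so the two inputs above should produce
# widths of ceil(1918 / 4) = 480 and ceil(3840 / 4) = 960 regardless of the
# 1x41 kernel; these expected widths are our own arithmetic, the test only
# checks that compilation and execution succeed.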


class ComparisonNet(nn.Cell):
    def __init__(self):
        """ ComparisonNet definition """
        super(ComparisonNet, self).__init__()

    def construct(self, x, y):
        ret = x <= y
        return ret


def test_max_pool_with_arg_max():
    class NetMaxPoolWithArgMax(nn.Cell):
        def __init__(self):
            """ NetMaxPoolWithArgMax definition """
            super(NetMaxPoolWithArgMax, self).__init__()
            self.max_pool_with_arg_max = P.MaxPoolWithArgmax(padding="valid", ksize=2, strides=1)

        def construct(self, x):
            ret = self.max_pool_with_arg_max(x)
            return ret

    x = Tensor(np.ones([1, 1, 3, 3], np.float32))
    net = NetMaxPoolWithArgMax()
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    ret = net(x)
    print(ret)
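
# MaxPoolWithArgmax returns a (values, indices) pair; a 2x2 window with
# stride 1 and "valid" padding over a 3x3 input gives (3 - 2) / 1 + 1 = 2
# positions per spatial dimension, i.e. two [1, 1, 2, 2] outputs. This
# arithmetic is ours; the test itself only prints the result.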


class GradWrapUnfold(nn.Cell):
    """ GradWrapUnfold definition """

    def __init__(self, network):
        super(GradWrapUnfold, self).__init__()
        self.network = network
        self.sens = Tensor(np.ones([1, 4, 2, 2], np.float32))

    def construct(self, x):
        return C.grad_all_with_sens(self.network)(x, self.sens)


class UnfoldNetValid(nn.Cell):
    """ UnfoldNetValid definition """

    def __init__(self):
        super(UnfoldNetValid, self).__init__()
        self.unfold = nn.Unfold(ksizes=[1, 2, 2, 1],
                                strides=[1, 1, 1, 1],
                                rates=[1, 1, 1, 1],
                                padding='VALID')

    def construct(self, x):
        return self.unfold(x)


class UnfoldNetSame(nn.Cell):
    """ UnfoldNetSame definition """

    def __init__(self):
        super(UnfoldNetSame, self).__init__()
        self.unfold = nn.Unfold(ksizes=[1, 2, 2, 1],
                                strides=[1, 1, 1, 1],
                                rates=[1, 1, 1, 1],
                                padding='SAME')

    def construct(self, x):
        return self.unfold(x)
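
# Unfold extracts sliding 2x2 patches: a [1, 1, 3, 3] input yields
# [1, 4, 2, 2] under VALID padding (4 = 2*2 patch elements per input channel,
# over 2x2 window positions) and [1, 4, 3, 3] under SAME padding, matching
# the desc_bprop shapes used in test_cases below.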


class FlattenNet(nn.Cell):
    """ FlattenNet definition """

    def __init__(self):
        super(FlattenNet, self).__init__()
        self.flatten = P.Flatten()

    def construct(self, x):
        return self.flatten(x)


class PReLUNet(nn.Cell):
    """ PReLUNet definition """

    def __init__(self):
        super(PReLUNet, self).__init__()
        self.prelu = P.PReLU()
        self.w = Tensor(np.ones(3, np.float32))

    def construct(self, x):
        return self.prelu(x, self.w)


class PReLUGradNet(nn.Cell):
    """ PReLUGradNet definition """

    def __init__(self):
        super(PReLUGradNet, self).__init__()
        self.prelu_grad = G.PReLUGrad()

    def construct(self, dout, x, w):
        return self.prelu_grad(dout, x, w)


class LRNNet(nn.Cell):
    """ LRNNet definition """

    def __init__(self):
        super(LRNNet, self).__init__()
        self.lrn = P.LRN()

    def construct(self, x):
        return self.lrn(x)


class LRNGradNet(nn.Cell):
    """ LRNGradNet definition """

    def __init__(self):
        super(LRNGradNet, self).__init__()
        self.lrn_grad = G.LRNGrad()

    def construct(self, dout, x, out):
        return self.lrn_grad(dout, x, out)
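

# Each entry below maps a case name to a dict consumed by the case-by-case
# pipeline: 'block' is the cell or primitive under test, 'desc_inputs' holds
# input shapes (bare lists) or concrete Tensors, 'desc_bprop' holds the
# sensitivity shapes for the backward pass, and 'skip': ['backward'] limits a
# case to the forward compilation check (our reading of the entries below).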
test_cases = [
    ('SoftMaxGrad', {
        'block': SoftMaxGrad(VirtualNetWithLoss(P.Softmax())),
        'desc_inputs': [[128, 32, 32, 64]],
        'desc_bprop': [[128, 32, 32, 64]],
    }),
    ('DropoutGrad', {
        'block': DropoutGrad(VirtualNetWithLoss(nn.Dropout())),
        'desc_inputs': [[128, 32, 32, 64]],
        'desc_bprop': [[128, 32, 32, 64]],
    }),
    ('ScalarSummary', {
        'block': ScalarSummaryNet(),
        'desc_inputs': [Tensor(2.2)],
    }),
    ('L2Normalize', {
        'block': L2NormalizeNet(),
        'desc_inputs': [Tensor(np.array([[1.0, 2, 3], [4.0, 5, 6], [7.0, 8, 9]]), mindspore.float32)],
    }),
    ('HistogramSummary', {
        'block': HistogramSummaryNet(),
        'desc_inputs': [[1, 2, 3]],
    }),
    ('FusedBatchNormGrad', {
        'block': FusedBatchNormGrad(nn.BatchNorm2d(num_features=512, eps=1e-5, momentum=0.1)),
        'desc_inputs': [[64, 512, 7, 7], [64, 512, 7, 7]],
        'desc_bprop': [[64, 512, 7, 7]],
    }),
    ('BatchnormGrad', {
        'block': Grad(NetWithLoss(BatchnormNet())),
        'desc_inputs': [Tensor(np.ones([1, 3, 8, 8], np.float32)), Tensor(np.zeros([1, 10], np.float32))],
    }),
    ('BlockGrad', {
        'block': Grad(NetWithLossClass(BlockNet())),
        'desc_inputs': [Tensor(np.ones([1, 3, 8, 8], np.float32)), Tensor(np.zeros([1, 64, 4, 4], np.float32))],
    }),
    ('Conv2dWithBiasGrad', {
        'block': Grad(NetWithLossClass(Conv2dWithBiasNet())),
        'desc_inputs': [Tensor(np.ones([1, 3, 16, 16], np.float32)), Tensor(np.zeros([1, 2560], np.float32))],
    }),
    ('Conv2dNativeGrad', {
        'block': Grad(NetWithLossClass(Conv2dNativeNet())),
        'desc_inputs': [Tensor(np.ones([1, 3, 16, 16], np.float32)), Tensor(np.zeros([1, 1764], np.float32))],
    }),
    ('StateTest', {
        'block': StateNet(),
        'desc_inputs': [Tensor(np.ones([2, 1, 2, 2]).astype(np.float32))],
    }),
    ('StateGrad', {
        'block': Grad(NetWithLossClass(StateNet())),
        'desc_inputs': [Tensor(np.ones([2, 1, 2, 2], np.float32)), Tensor(np.ones([2, 1, 2, 2], np.float32))],
    }),
    ('ComparisonTest', {
        'block': ComparisonNet(),
        'desc_inputs': [Tensor(np.ones([6, 9, 10], np.int32)), Tensor(np.ones([6, 9, 10], np.int32))],
    }),
    ('UnfoldValid', {
        'block': UnfoldNetValid(),
        'desc_inputs': [Tensor(np.ones([1, 1, 3, 3], np.float32))],
        'desc_bprop': [Tensor(np.ones([1, 4, 2, 2], np.float32))],
        'skip': ['backward']}),
    ('UnfoldSame', {
        'block': UnfoldNetSame(),
        'desc_inputs': [Tensor(np.ones([1, 1, 3, 3], np.float32))],
        'desc_bprop': [Tensor(np.ones([1, 4, 3, 3], np.float32))],
        'skip': ['backward']}),
    ('UnfoldGrad', {
        'block': GradWrapUnfold(UnfoldNetValid()),
        'desc_inputs': [Tensor(np.ones([1, 1, 3, 3], np.float32))],
        'desc_bprop': [Tensor(np.ones([1, 4, 2, 2], np.float32))],
        'skip': ['backward']}),
    ('LogSigmoid', {
        'block': nn.LogSigmoid(),
        'desc_inputs': [Tensor(np.array([1, 2, 3, 4]).astype(np.float32))],
        'desc_bprop': [Tensor(np.array([1, 2, 3, 4]).astype(np.float32))],
        'skip': ['backward']}),
    ('ReduceLogSumExp', {
        'block': nn.ReduceLogSumExp((0,), False),
        'desc_inputs': [Tensor(np.array([3, 4, 5, 6]).astype(np.float32))],
        'desc_bprop': [Tensor(np.array([1, 2, 3, 4]).astype(np.float32))],
        'skip': ['backward']}),
    ('FlattenNet', {
        'block': FlattenNet(),
        'desc_inputs': [Tensor(np.ones([1, 2, 3, 4], np.float32))],
    }),
    ('PReLUNet', {
        'block': PReLUNet(),
        'desc_inputs': [Tensor(np.ones([1, 3, 4, 4], np.float32))],
    }),
    ('PReLUGradNet', {
        'block': PReLUGradNet(),
        'desc_inputs': [Tensor(np.ones([1, 3, 4, 4], np.float32)),
                        Tensor(np.ones([1, 3, 4, 4], np.float32)),
                        Tensor(np.ones(3, np.float32))],
    }),
    ('MatrixDiag', {
        'block': nn.MatrixDiag(),
        'desc_inputs': [Tensor(np.array([1, 2, 3]).astype(np.float32))],
        'skip': ['backward']
    }),
    ('MatrixDiagPart', {
        'block': nn.MatrixDiagPart(),
        'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32))],
        'skip': ['backward']
    }),
    ('MatrixSetDiag', {
        'block': nn.MatrixSetDiag(),
        'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32)),
                        Tensor(np.array([1, 2]).astype(np.float32))],
        'skip': ['backward']
    }),
    ('LRNNet', {
        'block': LRNNet(),
        'desc_inputs': [Tensor(np.ones([1, 5, 4, 4], np.float32))],
    }),
    ('LRNGradNet', {
        'block': LRNGradNet(),
        'desc_inputs': [Tensor(np.ones([1, 5, 4, 4], np.float32)),
                        Tensor(np.ones([1, 5, 4, 4], np.float32)),
                        Tensor(np.ones([1, 5, 4, 4], np.float32))],
    }),
]
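

# For the exception cases, 'block' is a (callable, {'exception': ExcType})
# pair; the verify-exception pipeline is expected to invoke the lambda (or
# the pre-built cell) and assert that construction or execution raises the
# named exception type.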
test_cases_for_verify_exception = [
    ('ApplyMomentum_Error', {
        'block': (P.ApplyMomentum(), {'exception': TypeError}),
        'desc_inputs': [[2], [128, 32, 32, 64], [128, 32, 32, 64], [128, 32, 32, 64], [128, 32, 32, 64]],
        'desc_bprop': [[128, 32, 32, 64]],
        'skip': ['backward']
    }),
    ('Conv2d_ValueError_1', {
        'block': (lambda _: P.Conv2D(3, 4, mode=-2.0), {'exception': TypeError}),
        'desc_inputs': [0],
    }),
    ('Conv2d_ValueError_2', {
        'block': (lambda _: P.Conv2D(3, 4, mode=-2), {'exception': ValueError}),
        'desc_inputs': [0],
    }),
    ('MaxPoolWithArgmax_ValueError_1', {
        'block': (lambda _: P.MaxPoolWithArgmax(padding='sane'), {'exception': ValueError}),
        'desc_inputs': [0],
    }),
    ('MaxPoolWithArgmax_ValueError_2', {
        'block': (lambda _: P.MaxPoolWithArgmax(ksize='1'), {'exception': TypeError}),
        'desc_inputs': [0],
    }),
    ('MaxPoolWithArgmax_ValueError_3', {
        'block': (lambda _: P.MaxPoolWithArgmax(ksize=-2), {'exception': ValueError}),
        'desc_inputs': [0],
    }),
    ('MaxPoolWithArgmax_ValueError_4', {
        'block': (lambda _: P.MaxPoolWithArgmax(strides=-1), {'exception': ValueError}),
        'desc_inputs': [0],
    }),
    ('FusedBatchNorm_ValueError_1', {
        'block': (lambda _: P.FusedBatchNorm(mode="1", epsilon=1e-5, momentum=0.1), {'exception': TypeError}),
        'desc_inputs': [0],
    }),
    ('FusedBatchNorm_ValueError_2', {
        'block': (lambda _: P.FusedBatchNorm(mode=2, epsilon=1e-5, momentum=0.1), {'exception': ValueError}),
        'desc_inputs': [0],
    }),
    ('FusedBatchNorm_ValueError_3', {
        'block': (lambda _: P.FusedBatchNorm(mode=0, epsilon=-1e-5, momentum=0.1), {'exception': ValueError}),
        'desc_inputs': [0],
    }),
    ('FusedBatchNorm_ValueError_4', {
        'block': (lambda _: P.FusedBatchNorm(mode=0, epsilon=1e-5, momentum=-0.1), {'exception': ValueError}),
        'desc_inputs': [0],
    }),
    ('FusedBatchNorm_ValueError_5', {
        'block': (lambda _: P.FusedBatchNorm(mode=1, epsilon=-0.001, momentum=0.0), {'exception': ValueError}),
        'desc_inputs': [0],
    }),
    ('Softmax_ValueError_1', {
        'block': (lambda _: P.Softmax("1"), {'exception': TypeError}),
        'desc_inputs': [0],
    }),
    ('Softmax_ValueError_2', {
        'block': (lambda _: P.Softmax(1.1), {'exception': TypeError}),
        'desc_inputs': [0],
    }),
    ('Softmax_ValueError_3', {
        'block': (lambda _: P.Softmax(axis="1"), {'exception': TypeError}),
        'desc_inputs': [0],
    }),
    ('DropoutGenMask_ValueError_1', {
        'block': (lambda _: P.DropoutGenMask(Seed0="seed0"), {'exception': TypeError}),
        'desc_inputs': [0],
    }),
    ('DropoutGenMask_ValueError_2', {
        'block': (lambda _: P.DropoutGenMask(Seed0=1.0), {'exception': TypeError}),
        'desc_inputs': [0],
    }),
    ('DropoutGenMask_ValueError_3', {
        'block': (lambda _: P.DropoutGenMask(Seed1="seed1"), {'exception': TypeError}),
        'desc_inputs': [0],
    }),
    ('DropoutGenMask_ValueError_4', {
        'block': (lambda _: P.DropoutGenMask(Seed1=2.0), {'exception': TypeError}),
        'desc_inputs': [0],
    }),
    ('MaxPool2d_ValueError_1', {
        'block': (nn.MaxPool2d(kernel_size=120, stride=1, pad_mode="valid"), {'exception': ValueError}),
        'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))],
    }),
    ('MaxPool2d_ValueError_2', {
        'block': (
            lambda _: nn.MaxPool2d(kernel_size=120, stride=True, pad_mode="valid"),
            {'exception': TypeError},
        ),
        'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))],
    }),
    ('MaxPool2d_ValueError_3', {
        'block': (
            lambda _: nn.MaxPool2d(kernel_size=3, stride=True, pad_mode="valid"),
            {'exception': TypeError},
        ),
        'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))],
    }),
    ('ReduceLogsumexp_TypeError_1', {
        'block': (
            lambda _: nn.ReduceLogSumExp(axis=(0,), keep_dims=2),
            {'exception': TypeError},
        ),
        'desc_inputs': [Tensor(np.array([3, 4, 5, 6]).astype(np.float32))],
    }),
    ('ReduceLogsumexp_TypeError_2', {
        'block': (
            lambda _: nn.ReduceLogSumExp(axis=1.2, keep_dims=True),
            {'exception': TypeError},
        ),
        'desc_inputs': [Tensor(np.array([3, 4, 5, 6]).astype(np.float32))],
    }),
]


@non_graph_engine
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_compile():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    return test_cases


@mindspore_test(pipeline_for_verify_exception_for_case_by_case_config)
def test_check_exception():
    return test_cases_for_verify_exception