
test_nn_ops.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test nn ops """
import functools
import numpy as np
import mindspore
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor, Parameter
from mindspore.common.initializer import initializer
from mindspore.ops import Primitive
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import prim_attr_register, PrimitiveWithInfer
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
    import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
from ....mindspore_test_framework.pipeline.forward.verify_exception \
    import pipeline_for_verify_exception_for_case_by_case_config

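# The cells below are small test fixtures: plain building blocks (the conv
# helpers and ResidualBlock), graph-termination primitives (VirtualLoss and
# VirtualLossGrad), and thin gradient wrappers, all consumed by the case
# tables at the bottom of this file.
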
def conv3x3(in_channels, out_channels, stride=1, padding=1):
    """3x3 convolution"""
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=3, stride=stride, padding=padding)


def conv1x1(in_channels, out_channels, stride=1, padding=0):
    """1x1 convolution"""
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=1, stride=stride, padding=padding)

class ResidualBlock(nn.Cell):
    """
    Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions with an
    optional down-sampled identity branch.
    """
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 down_sample=False):
        super(ResidualBlock, self).__init__()

        out_chls = out_channels // self.expansion
        self.conv1 = conv1x1(in_channels, out_chls, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(out_chls)

        self.conv2 = conv3x3(out_chls, out_chls, stride=stride, padding=0)
        self.bn2 = nn.BatchNorm2d(out_chls)

        self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU()
        self.downsample = down_sample

        self.conv_down_sample = conv1x1(in_channels, out_channels,
                                        stride=stride, padding=0)
        self.bn_down_sample = nn.BatchNorm2d(out_channels)
        self.add = P.TensorAdd()

    def construct(self, x):
        """Apply the three conv-bn-relu stages and add the (optionally
        down-sampled) identity."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample:
            identity = self.conv_down_sample(identity)
            identity = self.bn_down_sample(identity)

        out = self.add(out, identity)
        out = self.relu(out)

        return out

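# The Virtual* primitives below stand in for a real loss so that gradient
# graphs can be compiled without computing an actual loss value: __call__
# raises if ever executed, infer_shape/infer_dtype give the compiler the
# output signature, and get_bprop wires VirtualLossGrad in as the backward
# rule.
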
class VirtualLossGrad(PrimitiveWithInfer):
    """ VirtualLossGrad definition """

    @prim_attr_register
    def __init__(self):
        """init VirtualLossGrad"""

    def __call__(self, x, out, dout):
        raise NotImplementedError

    def infer_shape(self, x_shape, out_shape, dout_shape):
        return x_shape

    def infer_dtype(self, x_dtype, out_dtype, dout_dtype):
        return x_dtype


class VirtualLoss(PrimitiveWithInfer):
    """ VirtualLoss definition """

    @prim_attr_register
    def __init__(self):
        """init VirtualLoss"""

    def __call__(self, x):
        raise NotImplementedError

    def get_bprop(self):
        loss_grad = VirtualLossGrad()

        def bprop(x, out, dout):
            # pylint: disable=unused-argument
            dx = loss_grad(x, out, dout)
            return (dx,)
        return bprop

    def infer_shape(self, x_shape):
        return []

    def infer_dtype(self, x_dtype):
        return x_dtype

class VirtualNetWithLoss(nn.Cell):
    """ VirtualNetWithLoss definition """

    def __init__(self, network):
        super(VirtualNetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x):
        predict = self.network(x)
        return self.loss(predict)


class SoftMaxGrad(nn.Cell):
    """ SoftMaxGrad definition """

    def __init__(self, network):
        super(SoftMaxGrad, self).__init__()
        self.network = network

    def construct(self, x):
        return C.grad(self.network)(x)


class DropoutGrad(nn.Cell):
    """ DropoutGrad definition """

    def __init__(self, network):
        super(DropoutGrad, self).__init__()
        self.network = network

    def construct(self, x):
        return C.grad(self.network)(x)

class ScalarSummaryNet(nn.Cell):
    """ ScalarSummaryNet definition """

    def __init__(self):
        super(ScalarSummaryNet, self).__init__()
        self.summary = P.ScalarSummary()

    def construct(self, scalar):
        string_in = "bias_value"
        out = self.summary(string_in, scalar)
        return out


class HistogramSummaryNet(nn.Cell):
    """HistogramSummaryNet definition"""

    def __init__(self):
        super(HistogramSummaryNet, self).__init__()
        self.summary = P.HistogramSummary()

    def construct(self, tensor):
        string_in = "weight_value"
        out = self.summary(string_in, tensor)
        return out

class FusedBatchNormGrad(nn.Cell):
    """ FusedBatchNormGrad definition """

    def __init__(self, network):
        super(FusedBatchNormGrad, self).__init__()
        self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True)
        self.network = network

    def construct(self, inp, output_grad):
        return self.grad(self.network)(inp, output_grad)


class NetWithLoss(nn.Cell):
    """ NetWithLoss definition """

    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = P.SmoothL1Loss()
        self.network = network

    def construct(self, x, label):
        predict = self.network(x)
        return self.loss(predict, label)


class Grad(nn.Cell):
    """ Grad definition """

    def __init__(self, network):
        super(Grad, self).__init__()
        self.network = network
        self.network.set_train()

    def construct(self, x, label):
        return C.grad(self.network)(x, label)

class BatchnormNet(nn.Cell):
    """ BatchnormNet definition """

    def __init__(self):
        super(BatchnormNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 4, kernel_size=8, stride=2, pad_mode="pad", padding=3)
        self.bn1 = nn.BatchNorm2d(4)
        self.flatten = P.Flatten()
        self.weight = Parameter(Tensor(np.ones([64, 10], np.float32)), name="weight")
        self.bias = Parameter(Tensor(np.ones([10], np.float32)), name="bias")
        self.fc = P.MatMul()
        self.biasAdd = P.BiasAdd()

    def construct(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.flatten(x)
        x = self.biasAdd(self.fc(x, self.weight), self.bias)
        return x

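# Shape check for the 'BatchnormGrad' case below: a 1x3x8x8 input through
# conv1 (kernel 8, stride 2, padding 3) yields 4 channels of 4x4, so Flatten
# produces 64 features, matching the [64, 10] weight and the [1, 10] label.
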
class NetWithLossClass(nn.Cell):
    """ NetWithLossClass definition """

    def __init__(self, network):
        super(NetWithLossClass, self).__init__(auto_prefix=False)
        self.loss = nn.SoftmaxCrossEntropyWithLogits()
        self.network = network

    def construct(self, x, label):
        predict = self.network(x)
        return self.loss(predict, label)


class BlockNet(nn.Cell):
    """ BlockNet definition """

    def __init__(self):
        super(BlockNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, pad_mode="pad", padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
        self.block_down_sample = ResidualBlock(
            64, 256, stride=1, down_sample=True
        )
        self.flatten = P.Flatten()
        self.weight = Parameter(Tensor(np.ones([1024, 10]).astype(np.float32)), name="weight")
        self.bias = Parameter(Tensor(np.ones([10]).astype(np.float32)), name="bias")
        self.fc = P.MatMul()
        self.biasAdd = P.BiasAdd()

    def construct(self, x):
        x = self.conv1(x)
        return x

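# Note that BlockNet declares a full stem (bn1, relu, maxpool, a down-sampling
# ResidualBlock, and an fc head) but construct only applies conv1, so a
# 1x3x8x8 input yields a 1x64x4x4 output, which is exactly the label shape
# used by the 'BlockGrad' case below.
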
class Conv2dWithBiasNet(nn.Cell):
    """ Conv2dWithBiasNet definition """

    def __init__(self):
        super(Conv2dWithBiasNet, self).__init__()
        self.conv = nn.Conv2d(3, 10, 1, bias_init='zeros')
        self.flatten = P.Flatten()

    def construct(self, input_x):
        return self.flatten(self.conv(input_x))


class Conv2dNativeNet(nn.Cell):
    """ Conv2dNativeNet definition """

    def __init__(self):
        super(Conv2dNativeNet, self).__init__()
        self.conv = P.DepthwiseConv2dNative(channel_multiplier=3, kernel_size=(3, 3))
        self.flatten = P.Flatten()
        channel_multipliers = 1
        in_channels = 3
        kernel_size = (3, 3)
        self.weight = Parameter(initializer(
            Tensor(np.ones([channel_multipliers, in_channels, *kernel_size], dtype=np.float32)),
            [channel_multipliers, in_channels, *kernel_size]), name='weight')

    def construct(self, input_x):
        return self.flatten(self.conv(input_x, self.weight))

class MakeRefKeyNet(nn.Cell):
    """ MakeRefKeyNet definition """

    def __init__(self):
        super(MakeRefKeyNet, self).__init__()
        self.y = Parameter(Tensor([1.0], mindspore.float32), name="y")

    def construct(self, x):
        key = P.MakeRefKey("y")()
        P.Assign()(key, x)
        return x

class StateNet(nn.Cell):
    """ StateNet definition """

    def __init__(self):
        super(StateNet, self).__init__()
        weight = Tensor(np.ones([2, 1, 2, 2], np.float32))
        self.s1 = Parameter(weight, name="s1")
        self.s2 = Parameter(weight, name="s2")
        self.sub = P.Sub()
        self.loss = nn.SoftmaxCrossEntropyWithLogits()
        self.assign = P.Assign()

    def construct(self, x):
        x = Primitive('depend')(x, self.assign(self.s1, x + self.s1))
        self.s1 = self.sub(self.s1, x)
        self.s2 = self.sub(self.s2, x)
        return x

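# StateNet exercises in-graph parameter updates: the 'depend' primitive orders
# the Assign side effect before x is consumed, and the assignments to self.s1
# and self.s2 rewrite parameter state inside construct.
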
class ComparisonNet(nn.Cell):
    """ ComparisonNet definition """

    def __init__(self):
        super(ComparisonNet, self).__init__()

    def construct(self, x, y):
        ret = x <= y
        return ret

def test_max_pool_with_arg_max():
    class NetMaxPoolWithArgMax(nn.Cell):
        """ NetMaxPoolWithArgMax definition """

        def __init__(self):
            super(NetMaxPoolWithArgMax, self).__init__()
            self.max_pool_with_arg_max = P.MaxPoolWithArgmax(padding="valid", ksize=2, strides=1)

        def construct(self, x):
            ret = self.max_pool_with_arg_max(x)
            return ret

    x = Tensor(np.ones([1, 1, 3, 3], np.float32))
    net = NetMaxPoolWithArgMax()
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    ret = net(x)
    print(ret)

class GradWrapUnfold(nn.Cell):
    """ GradWrapUnfold definition """

    def __init__(self, network):
        super(GradWrapUnfold, self).__init__()
        self.network = network
        self.sens = Tensor(np.ones([1, 4, 2, 2], np.float32))

    def construct(self, x):
        return C.grad_all_with_sens(self.network)(x, self.sens)


class UnfoldNetValid(nn.Cell):
    """ UnfoldNetValid definition """

    def __init__(self):
        super(UnfoldNetValid, self).__init__()
        self.unfold = nn.Unfold(ksizes=[1, 2, 2, 1],
                                strides=[1, 1, 1, 1],
                                rates=[1, 1, 1, 1],
                                padding='VALID')

    def construct(self, x):
        return self.unfold(x)


class UnfoldNetSame(nn.Cell):
    """ UnfoldNetSame definition """

    def __init__(self):
        super(UnfoldNetSame, self).__init__()
        self.unfold = nn.Unfold(ksizes=[1, 2, 2, 1],
                                strides=[1, 1, 1, 1],
                                rates=[1, 1, 1, 1],
                                padding='SAME')

    def construct(self, x):
        return self.unfold(x)

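# Each test_cases entry is a (name, config) pair for the compile pipeline:
# 'block' is the cell or wrapped net under test, 'desc_inputs' lists its
# forward inputs (the framework materializes bare shape lists as tensors),
# 'desc_bprop' gives the backward sensitivities, and 'skip' disables phases.
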
test_cases = [
    ('SoftMaxGrad', {
        'block': SoftMaxGrad(VirtualNetWithLoss(P.Softmax())),
        'desc_inputs': [[128, 32, 32, 64]],
        'desc_bprop': [[128, 32, 32, 64]],
    }),
    ('DropoutGrad', {
        'block': DropoutGrad(VirtualNetWithLoss(nn.Dropout())),
        'desc_inputs': [[128, 32, 32, 64]],
        'desc_bprop': [[128, 32, 32, 64]],
    }),
    ('ScalarSummary', {
        'block': ScalarSummaryNet(),
        'desc_inputs': [2.2],
    }),
    ('HistogramSummary', {
        'block': HistogramSummaryNet(),
        'desc_inputs': [[1, 2, 3]],
    }),
    ('FusedBatchNormGrad', {
        'block': FusedBatchNormGrad(nn.BatchNorm2d(num_features=512, eps=1e-5, momentum=0.1)),
        'desc_inputs': [[64, 512, 7, 7], [64, 512, 7, 7]],
        'desc_bprop': [[64, 512, 7, 7]],
    }),
    ('BatchnormGrad', {
        'block': Grad(NetWithLoss(BatchnormNet())),
        'desc_inputs': [Tensor(np.ones([1, 3, 8, 8], np.float32)), Tensor(np.zeros([1, 10], np.float32))],
    }),
    ('BlockGrad', {
        'block': Grad(NetWithLossClass(BlockNet())),
        'desc_inputs': [Tensor(np.ones([1, 3, 8, 8], np.float32)), Tensor(np.zeros([1, 64, 4, 4], np.float32))],
    }),
    ('Conv2dWithBiasGrad', {
        'block': Grad(NetWithLossClass(Conv2dWithBiasNet())),
        'desc_inputs': [Tensor(np.ones([1, 3, 16, 16], np.float32)), Tensor(np.zeros([1, 2560], np.float32))],
    }),
    ('Conv2dNativeGrad', {
        'block': Grad(NetWithLossClass(Conv2dNativeNet())),
        'desc_inputs': [Tensor(np.ones([1, 3, 16, 16], np.float32)), Tensor(np.zeros([1, 1764], np.float32))],
    }),
    ('MakeRefKey', {
        'block': MakeRefKeyNet(),
        'desc_inputs': [Tensor([2.0], mindspore.float32)],
    }),
    ('StateTest', {
        'block': StateNet(),
        'desc_inputs': [Tensor(np.ones([2, 1, 2, 2]).astype(np.float32))],
    }),
    ('StateGrad', {
        'block': Grad(NetWithLossClass(StateNet())),
        'desc_inputs': [Tensor(np.ones([2, 1, 2, 2], np.float32)), Tensor(np.ones([2, 1, 2, 2], np.float32))],
    }),
    ('ComparisonTest', {
        'block': ComparisonNet(),
        'desc_inputs': [Tensor(np.ones([6, 9, 10], np.int32)), Tensor(np.ones([6, 9, 10], np.int32))],
    }),
    ('UnfoldValid', {
        'block': UnfoldNetValid(),
        'desc_inputs': [Tensor(np.ones([1, 1, 3, 3], np.float32))],
        'desc_bprop': [Tensor(np.ones([1, 4, 2, 2], np.float32))],
        'skip': ['backward']}),
    ('UnfoldSame', {
        'block': UnfoldNetSame(),
        'desc_inputs': [Tensor(np.ones([1, 1, 3, 3], np.float32))],
        'desc_bprop': [Tensor(np.ones([1, 4, 3, 3], np.float32))],
        'skip': ['backward']}),
    ('UnfoldGrad', {
        'block': GradWrapUnfold(UnfoldNetValid()),
        'desc_inputs': [Tensor(np.ones([1, 1, 3, 3], np.float32))],
        'desc_bprop': [Tensor(np.ones([1, 4, 2, 2], np.float32))],
        'skip': ['backward']}),
]

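# In the exception table, each 'block' is a (callable, {'exception': type})
# pair: the verify-exception pipeline invokes the callable on the described
# inputs and asserts that the declared exception type is raised.
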
test_cases_for_verify_exception = [
    ('ApplyMomentum_Error', {
        'block': (P.ApplyMomentum(), {'exception': TypeError}),
        'desc_inputs': [[2], [128, 32, 32, 64], [128, 32, 32, 64], [128, 32, 32, 64], [128, 32, 32, 64]],
        'desc_bprop': [[128, 32, 32, 64]],
        'skip': ['backward']
    }),
    ('Conv2d_ValueError_1', {
        'block': (lambda _: P.Conv2D(3, 4, mode=-2.0), {'exception': TypeError}),
        'desc_inputs': [0],
    }),
    ('Conv2d_ValueError_2', {
        'block': (lambda _: P.Conv2D(3, 4, mode=-2), {'exception': ValueError}),
        'desc_inputs': [0],
    }),
    ('MaxPoolWithArgmax_ValueError_1', {
        'block': (lambda _: P.MaxPoolWithArgmax(padding='sane'), {'exception': ValueError}),
        'desc_inputs': [0],
    }),
    ('MaxPoolWithArgmax_ValueError_2', {
        'block': (lambda _: P.MaxPoolWithArgmax(ksize='1'), {'exception': TypeError}),
        'desc_inputs': [0],
    }),
    ('MaxPoolWithArgmax_ValueError_3', {
        'block': (lambda _: P.MaxPoolWithArgmax(ksize=-2), {'exception': ValueError}),
        'desc_inputs': [0],
    }),
    ('MaxPoolWithArgmax_ValueError_4', {
        'block': (lambda _: P.MaxPoolWithArgmax(strides=-1), {'exception': ValueError}),
        'desc_inputs': [0],
    }),
    ('FusedBatchNorm_ValueError_1', {
        'block': (lambda _: P.FusedBatchNorm(mode="1", epsilon=1e-5, momentum=0.1), {'exception': TypeError}),
        'desc_inputs': [0],
    }),
    ('FusedBatchNorm_ValueError_2', {
        'block': (lambda _: P.FusedBatchNorm(mode=2, epsilon=1e-5, momentum=0.1), {'exception': ValueError}),
        'desc_inputs': [0],
    }),
    ('FusedBatchNorm_ValueError_3', {
        'block': (lambda _: P.FusedBatchNorm(mode=0, epsilon=-1e-5, momentum=0.1), {'exception': ValueError}),
        'desc_inputs': [0],
    }),
    ('FusedBatchNorm_ValueError_4', {
        'block': (lambda _: P.FusedBatchNorm(mode=0, epsilon=1e-5, momentum=-0.1), {'exception': ValueError}),
        'desc_inputs': [0],
    }),
    ('FusedBatchNorm_ValueError_5', {
        'block': (lambda _: P.FusedBatchNorm(mode=1, epsilon=-0.001, momentum=0.0), {'exception': ValueError}),
        'desc_inputs': [0],
    }),
    ('Softmax_ValueError_1', {
        'block': (lambda _: P.Softmax("1"), {'exception': TypeError}),
        'desc_inputs': [0],
    }),
    ('Softmax_ValueError_2', {
        'block': (lambda _: P.Softmax(1.1), {'exception': TypeError}),
        'desc_inputs': [0],
    }),
    ('Softmax_ValueError_3', {
        'block': (lambda _: P.Softmax(axis="1"), {'exception': TypeError}),
        'desc_inputs': [0],
    }),
    ('DropoutGenMask_ValueError_1', {
        'block': (lambda _: P.DropoutGenMask(Seed0="seed0"), {'exception': TypeError}),
        'desc_inputs': [0],
    }),
    ('DropoutGenMask_ValueError_2', {
        'block': (lambda _: P.DropoutGenMask(Seed0=1.0), {'exception': TypeError}),
        'desc_inputs': [0],
    }),
    ('DropoutGenMask_ValueError_3', {
        'block': (lambda _: P.DropoutGenMask(Seed1="seed1"), {'exception': TypeError}),
        'desc_inputs': [0],
    }),
    ('DropoutGenMask_ValueError_4', {
        'block': (lambda _: P.DropoutGenMask(Seed1=2.0), {'exception': TypeError}),
        'desc_inputs': [0],
    }),
    ('MaxPool2d_ValueError_1', {
        'block': (nn.MaxPool2d(kernel_size=120, stride=1, pad_mode="valid"), {'exception': ValueError}),
        'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))],
    }),
    ('MaxPool2d_ValueError_2', {
        'block': (
            lambda _: nn.MaxPool2d(kernel_size=120, stride=True, pad_mode="valid"),
            {'exception': TypeError},
        ),
        'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))],
    }),
    ('MaxPool2d_ValueError_3', {
        'block': (
            lambda _: nn.MaxPool2d(kernel_size=3, stride=True, pad_mode="valid"),
            {'exception': TypeError},
        ),
        'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))],
    }),
]

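# Entry points: test_compile runs every test_cases entry through the GE
# graph-compilation pipeline with the Ascend target, while test_check_exception
# drives the invalid configurations above and expects each declared exception.
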
@non_graph_engine
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_compile():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    return test_cases


@mindspore_test(pipeline_for_verify_exception_for_case_by_case_config)
def test_check_exception():
    return test_cases_for_verify_exception