You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_reshape.py 22 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from mindspore.train import Model, ParallelMode
  15. from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
  16. from mindspore.nn.optim.momentum import Momentum
  17. from mindspore import Tensor
  18. import mindspore as ms
  19. import numpy as np
  20. from mindspore.ops import operations as P
  21. import mindspore.nn as nn
  22. from mindspore.common.parameter import Parameter
  23. from tests.dataset_mock import MindData
  24. from mindspore import context
  25. from tests.ut.python.ops.test_math_ops import VirtualLoss
  26. from mindspore.common.api import _executor
  27. from mindspore.ops import composite as C
  28. from mindspore.ops.operations.comm_ops import _VirtualDataset
  29. from mindspore.ops import functional as F
  30. from mindspore.common.parameter import ParameterTuple
  31. from mindspore.common import dtype as mstype
  32. from mindspore.parallel import set_algo_parameters
# Module-level setup executed at import time: select graph mode and reset
# the auto-parallel context before any test in this file runs.
context.set_context(mode=context.GRAPH_MODE)
context.reset_auto_parallel_context()
  35. class Dataset(MindData):
  36. def __init__(self, predict, label, length=3, input_num=2):
  37. super(Dataset, self).__init__(size=length)
  38. self.predict = predict
  39. self.label = label
  40. self.index = 0
  41. self.length = length
  42. self.input_num = input_num
  43. def __iter__(self):
  44. return self
  45. def __next__(self):
  46. if self.index >= self.length:
  47. raise StopIteration
  48. self.index += 1
  49. if self.input_num == 2:
  50. return self.predict, self.label
  51. else:
  52. return self.predict,
  53. def reset(self):
  54. self.index = 0
  55. class ReshapeNet(nn.Cell):
  56. def __init__(self, strategy0, strategy1, strategy2):
  57. super(ReshapeNet, self).__init__()
  58. self.relu = P.ReLU().set_strategy(strategy0)
  59. self.reshape = P.Reshape().set_strategy(strategy1)
  60. self.matmul = P.MatMul().set_strategy(strategy2)
  61. self.matmul_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight")
  62. def construct(self, x):
  63. x = self.relu(x)
  64. x = self.reshape(x, (256, 25088))
  65. x = self.matmul(x, self.matmul_weight)
  66. return x
  67. def reshape_net(strategy0, strategy1, strategy2):
  68. return ReshapeNet(strategy0=strategy0, strategy1=strategy1, strategy2=strategy2)
  69. def reshape_common(parallel_mode, strategy0, strategy1, strategy2, strategy_loss):
  70. batch_size = 32
  71. learning_rate = 0.1
  72. momentum = 0.9
  73. epoch_size = 2
  74. context.reset_auto_parallel_context()
  75. context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)
  76. predict = Tensor(np.ones([32, 512, 7, 7]), dtype=ms.float32)
  77. label = Tensor(np.ones([32]), dtype=ms.int32)
  78. dataset = Dataset(predict, label, 2)
  79. net = reshape_net(strategy0, strategy1, strategy2)
  80. loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
  81. loss.softmax_cross_entropy.set_strategy(strategy_loss)
  82. loss.one_hot.set_strategy(((8,1), (), ()))
  83. opt = Momentum(net.trainable_params(), learning_rate, momentum)
  84. model = Model(net, loss, opt)
  85. model.train(epoch_size, dataset, dataset_sink_mode=False)
  86. def test_reshape1():
  87. strategy0 = ((8, 1, 1, 1), )
  88. strategy1 = None
  89. strategy2 = ((8, 1), (1, 1))
  90. strategy_loss = ((8, 1), (8, 1))
  91. reshape_common(ParallelMode.SEMI_AUTO_PARALLEL, strategy0, strategy1, strategy2, strategy_loss)
  92. def test_reshape1_strategy_1():
  93. strategy0 = ((8, 1, 1, 1), )
  94. strategy1 = ((8, 1, 1, 1), )
  95. strategy2 = ((8, 1), (1, 1))
  96. strategy_loss = ((8, 1), (8, 1))
  97. try:
  98. reshape_common(ParallelMode.SEMI_AUTO_PARALLEL, strategy0, strategy1, strategy2, strategy_loss)
  99. except:
  100. pass
  101. def test_reshape1_strategy_2():
  102. strategy0 = ((8, 1, 1, 1), )
  103. strategy1 = ((8, 1, 1, 1), )
  104. strategy2 = ((8, 1), (1, 1))
  105. strategy_loss = ((8, 1), (8, 1))
  106. try:
  107. reshape_common(ParallelMode.AUTO_PARALLEL, strategy0, strategy1, strategy2, strategy_loss)
  108. except:
  109. pass
  110. def test_reshape2():
  111. strategy0 = ((8, 1, 1, 1), )
  112. strategy1 = None
  113. strategy2 = ((8, 1), (1, 1))
  114. strategy_loss = ((8, 1), (8, 1))
  115. reshape_common(ParallelMode.SEMI_AUTO_PARALLEL, strategy0, strategy1, strategy2, strategy_loss)
  116. def test_reshape3():
  117. strategy0 = ((2, 1, 1, 1), )
  118. strategy1 = None
  119. strategy2 = ((8, 1), (1, 1))
  120. strategy_loss = ((8, 1), (8, 1))
  121. reshape_common(ParallelMode.SEMI_AUTO_PARALLEL, strategy0, strategy1, strategy2, strategy_loss)
  122. def test_reshape4():
  123. strategy0 = ((1, 1, 1, 1), )
  124. strategy1 = None
  125. strategy2 = ((8, 1), (1, 1))
  126. strategy_loss = ((8, 1), (8, 1))
  127. reshape_common(ParallelMode.SEMI_AUTO_PARALLEL, strategy0, strategy1, strategy2, strategy_loss)
  128. def test_reshape5():
  129. strategy0 = ((2, 1, 1, 1), )
  130. strategy1 = None
  131. strategy2 = ((1, 8), (8, 1))
  132. strategy_loss = ((8, 1), (8, 1))
  133. reshape_common(ParallelMode.SEMI_AUTO_PARALLEL, strategy0, strategy1, strategy2, strategy_loss)
  134. def test_reshape_auto():
  135. strategy0 = None
  136. strategy1 = None
  137. strategy2 = None
  138. strategy_loss = None
  139. reshape_common(ParallelMode.AUTO_PARALLEL, strategy0, strategy1, strategy2, strategy_loss)
  140. class NetWithLoss(nn.Cell):
  141. def __init__(self, network):
  142. super(NetWithLoss, self).__init__()
  143. self.loss = VirtualLoss()
  144. self.network = network
  145. def construct(self, x):
  146. predict = self.network(x)
  147. return self.loss(predict)
  148. class GradWrap(nn.Cell):
  149. def __init__(self, network):
  150. super(GradWrap, self).__init__()
  151. self.network = network
  152. def construct(self, x):
  153. return C.grad_all(self.network)(x)
  154. class ReshapeNet1(nn.Cell):
  155. def __init__(self, strategy0):
  156. super(ReshapeNet1, self).__init__()
  157. self.virtual_dataset = _VirtualDataset()
  158. self.reshape = P.Reshape()
  159. self.matmul = P.MatMul().set_strategy(strategy0)
  160. self.matmul_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight")
  161. self.reshape2 = P.Reshape()
  162. def construct(self, x):
  163. x = self.virtual_dataset(x)
  164. x = self.reshape(x, (256, 25088))
  165. x = self.matmul(x, self.matmul_weight)
  166. x = self.reshape2(x, (256 * 256,))
  167. return x
  168. class ReshapeNet2(nn.Cell):
  169. def __init__(self, strategy0):
  170. super(ReshapeNet2, self).__init__()
  171. self.virtual_dataset = _VirtualDataset()
  172. self.reshape = P.Reshape()
  173. self.matmul = P.MatMul().set_strategy(strategy0)
  174. self.matmul_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight")
  175. self.reshape2 = P.Reshape()
  176. self.reduce_sum = P.ReduceSum(keep_dims=True)
  177. self.reshape3 = P.Reshape()
  178. def construct(self, x):
  179. x = self.virtual_dataset(x)
  180. x = self.reshape(x, (256, 25088))
  181. x = self.matmul(x, self.matmul_weight)
  182. x = self.reshape2(x, (256 * 256,))
  183. x = self.reduce_sum(x, -1)
  184. x = self.reshape3(x, ())
  185. return x
  186. class ReshapeNet3(nn.Cell):
  187. def __init__(self, strategy0):
  188. super(ReshapeNet3, self).__init__()
  189. self.virtual_dataset = _VirtualDataset()
  190. self.reshape = P.Reshape()
  191. self.matmul = P.MatMul().set_strategy(strategy0)
  192. self.matmul_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight")
  193. self.reshape2 = P.Reshape()
  194. self.reduce_sum = P.ReduceSum(keep_dims=False)
  195. self.reshape3 = P.Reshape()
  196. def construct(self, x):
  197. x = self.virtual_dataset(x)
  198. x = self.reshape(x, (256, 25088))
  199. x = self.matmul(x, self.matmul_weight)
  200. x = self.reshape2(x, (256 * 256,))
  201. x = self.reduce_sum(x, -1)
  202. x = self.reshape3(x, (1, 1))
  203. return x
  204. class ReshapeNet4(nn.Cell):
  205. def __init__(self, strategy0):
  206. super(ReshapeNet4, self).__init__()
  207. self.virtual_dataset = _VirtualDataset()
  208. self.reshape = P.Reshape()
  209. self.reshape2 = P.Reshape()
  210. self.matmul = P.MatMul().set_strategy(strategy0)
  211. self.matmul_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight")
  212. def construct(self, x):
  213. x = self.virtual_dataset(x)
  214. x = self.reshape(x, (256, 25088))
  215. w = self.reshape2(self.matmul_weight, (25088, 256))
  216. x = self.matmul(x, w)
  217. return x
  218. class ReshapeNet5(nn.Cell):
  219. def __init__(self, strategy0):
  220. super(ReshapeNet5, self).__init__()
  221. self.virtual_dataset = _VirtualDataset()
  222. self.reshape = P.Reshape()
  223. self.matmul1 = P.MatMul().set_strategy(strategy0)
  224. self.matmul1_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight")
  225. self.matmul2 = P.MatMul().set_strategy(strategy0)
  226. def construct(self, x):
  227. x = self.virtual_dataset(x)
  228. x = self.reshape(x, (256, 25088))
  229. matmul1_o = self.matmul1(x, self.matmul1_weight)
  230. matmul2_o = self.matmul2(matmul1_o, x)
  231. return matmul2_o
  232. class ReshapeNet6(nn.Cell):
  233. def __init__(self, strategy0):
  234. super(ReshapeNet6, self).__init__()
  235. self.virtual_dataset = _VirtualDataset()
  236. self.reshape = P.Reshape()
  237. self.matmul1_1 = P.MatMul().set_strategy(strategy0)
  238. self.matmul1_2 = P.MatMul().set_strategy(strategy0)
  239. self.matmul1_weight = Parameter(Tensor(np.ones([25088, 256]), dtype=ms.float32), name="weight")
  240. self.matmul2 = P.MatMul().set_strategy(strategy0)
  241. self.add = P.TensorAdd()
  242. def construct(self, x):
  243. x = self.virtual_dataset(x)
  244. x = self.reshape(x, (256, 25088))
  245. matmul1_1_o = self.matmul1_1(x, self.matmul1_weight)
  246. matmul1_2_o = self.matmul1_2(x, self.matmul1_weight)
  247. matmul1_o = self.add(matmul1_1_o, matmul1_2_o)
  248. matmul2_o = self.matmul2(matmul1_o, x)
  249. return matmul2_o
  250. def reshape_net2(backbone):
  251. batch_size = 16
  252. device_num = 16
  253. context.set_auto_parallel_context(device_num=device_num, global_rank=0)
  254. input = Tensor(np.ones([batch_size * device_num, 512, 7, 7]).astype(np.float32) * 0.01)
  255. net = GradWrap(NetWithLoss(backbone))
  256. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  257. _executor.compile(net, input)
  258. def test_reshape_net1_1():
  259. reshape_net2(ReshapeNet1(((1, 8), (8, 1))))
  260. def test_reshape_net1_2():
  261. reshape_net2(ReshapeNet1(((1, 8), (8, 2))))
  262. def test_reshape_net2_1():
  263. reshape_net2(ReshapeNet2(((1, 8), (8, 1))))
  264. def test_reshape_net2_2():
  265. reshape_net2(ReshapeNet2(((1, 8), (8, 2))))
  266. def test_reshape_net3_1():
  267. reshape_net2(ReshapeNet3(((1, 8), (8, 1))))
  268. def test_reshape_net3_2():
  269. reshape_net2(ReshapeNet3(((1, 8), (8, 2))))
  270. def test_reshape_net4_1():
  271. try:
  272. reshape_net2(ReshapeNet4(((1, 8), (8, 1))))
  273. except:
  274. pass
  275. def test_reshape_net4_2():
  276. try:
  277. reshape_net2(ReshapeNet4(((1, 8), (8, 2))))
  278. except:
  279. pass
  280. def test_reshape_net5_1():
  281. reshape_net2(ReshapeNet5(((1, 8), (8, 1))))
  282. def test_reshape_net5_2():
  283. reshape_net2(ReshapeNet5(((1, 8), (8, 2))))
  284. def test_reshape_net6_1():
  285. reshape_net2(ReshapeNet6(((1, 8), (8, 1))))
  286. def test_reshape_net6_2():
  287. reshape_net2(ReshapeNet6(((1, 8), (8, 2))))
  288. class TrainOneStepCell(nn.Cell):
  289. """
  290. Network training package class.
  291. Append an optimizer to the training network after that the construct function
  292. can be called to create the backward graph.
  293. Args:
  294. network (Cell): The training network.
  295. optimizer (Cell): Optimizer for updating the weights.
  296. sens (Number): The adjust parameter. Default: 1.0.
  297. Examples:
  298. >>> net = Net()
  299. >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
  300. >>> optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
  301. >>> loss_net = WithLossCell(net, loss_fn)
  302. >>> train_net = TrainOneStepCell(loss_net, optim)
  303. """
  304. def __init__(self, network, optimizer, sens=1.0):
  305. super(TrainOneStepCell, self).__init__(auto_prefix=False)
  306. self.network = network
  307. self.network.add_flags(defer_inline=True)
  308. self.weights = ParameterTuple(network.trainable_params())
  309. self.optimizer = optimizer
  310. self.grad = C.GradOperation('grad',
  311. get_by_list=True,
  312. sens_param=True)
  313. self.sens = sens
  314. def construct(self, data):
  315. weights = self.weights
  316. loss = self.network(data)
  317. sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
  318. grads = self.grad(self.network, weights)(data, sens)
  319. return F.depend(loss, self.optimizer(grads))
  320. def reshape_common2(parallel_mode, net):
  321. batch_size = 16
  322. learning_rate = 0.1
  323. momentum = 0.9
  324. epoch_size = 2
  325. predict = Tensor(np.ones([batch_size, 512, 7, 7]), dtype=ms.float32)
  326. label = Tensor(np.ones([batch_size]), dtype=ms.int32)
  327. dataset = Dataset(predict, label, 2, input_num=1)
  328. context.reset_auto_parallel_context()
  329. context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=16)
  330. opt = Momentum(net.trainable_params(), learning_rate, momentum)
  331. train_net = TrainOneStepCell(net, opt).set_train()
  332. model = Model(train_net)
  333. model.train(epoch_size, dataset, dataset_sink_mode=False)
  334. def test_reshape_common2_0():
  335. reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet1(((1, 8), (8, 1))))
  336. def test_reshape_common2_1():
  337. reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet1(((1, 8), (8, 2))))
  338. def test_reshape_common2_2():
  339. reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet2(((1, 8), (8, 1))))
  340. def test_reshape_common2_3():
  341. reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet2(((1, 8), (8, 2))))
  342. def test_reshape_common2_4():
  343. reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet3(((1, 8), (8, 1))))
  344. def test_reshape_common2_5():
  345. reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet3(((1, 8), (8, 2))))
  346. class BatchNormReshapeNet(nn.Cell):
  347. def __init__(self):
  348. super(BatchNormReshapeNet, self).__init__()
  349. self.vd = P._VirtualDataset()
  350. self.batch_norm = nn.BatchNorm1d(512, affine=False)
  351. self.reshape = P.Reshape()
  352. self.prelu = nn.PReLU(channel=256)
  353. def construct(self, x):
  354. x = self.vd(x)
  355. x = self.batch_norm(x)
  356. x = self.reshape(x, (512, 256))
  357. x = self.prelu(x)
  358. return x
  359. def test_batchnorm_reshape_train():
  360. batch_size = 16
  361. device_num = 16
  362. context.set_auto_parallel_context(device_num=device_num, global_rank=0)
  363. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  364. input = Tensor(np.ones([batch_size * device_num, 512]).astype(np.float32) * 0.01)
  365. net = GradWrap(NetWithLoss(BatchNormReshapeNet()))
  366. _executor.compile(net, input)
  367. def bn_with_initialize(out_channels):
  368. bn = nn.BatchNorm2d(out_channels, momentum=0.3, eps=1e-5).add_flags_recursive(fp32=True)
  369. return bn
  370. def fc_with_initialize(input_channels, out_channels):
  371. return nn.Dense(input_channels, out_channels).add_flags_recursive(fp16=True)
  372. class BNReshapeDenseBNNet(nn.Cell):
  373. def __init__(self):
  374. super(BNReshapeDenseBNNet, self).__init__()
  375. self.batch_norm = bn_with_initialize(2)
  376. self.reshape = P.Reshape()
  377. self.cast = P.Cast()
  378. self.batch_norm2 = nn.BatchNorm1d(512, affine=False)
  379. self.fc = fc_with_initialize(2 * 32 * 32, 512)
  380. def construct(self, x):
  381. x = self.batch_norm(x)
  382. x = self.reshape(x, (16, 2*32*32))
  383. x = self.fc(x)
  384. x = self.batch_norm2(x)
  385. return x
  386. def test_bn_reshape_dense_bn_train():
  387. batch_size = 16
  388. device_num = 16
  389. context.set_auto_parallel_context(device_num=device_num, global_rank=0)
  390. input = Tensor(np.ones([batch_size, 2, 32, 32]).astype(np.float32) * 0.01)
  391. net = GradWrap(NetWithLoss(BNReshapeDenseBNNet()))
  392. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  393. _executor.compile(net, input)
  394. class ParallelReduceMeanNet(nn.Cell):
  395. def __init__(self, conv_in_channel, conv_out_channel,
  396. reducemean_keep_dims=False, reducemean_axis=-1, strategy=None):
  397. super().__init__()
  398. self.conv = nn.Conv2d(in_channels=conv_in_channel, out_channels=conv_out_channel,
  399. kernel_size=1, stride=1, pad_mode='valid', has_bias=True,
  400. weight_init='ones', bias_init='ones')
  401. self.reduce_mean = P.ReduceMean(keep_dims=reducemean_keep_dims)
  402. self.flat = nn.Flatten()
  403. self.reducemean_axis = reducemean_axis
  404. if strategy is not None:
  405. self.reduce_mean.set_strategy(strategy)
  406. def construct(self, inputs):
  407. x = self.conv(inputs)
  408. x = self.reduce_mean(x, self.reducemean_axis)
  409. x = self.flat(x)
  410. return x
  411. class CrossEntropyLoss(nn.Cell):
  412. def __init__(self, reduction='mean'):
  413. super(CrossEntropyLoss, self).__init__()
  414. self.reduce_mean = P.ReduceMean()
  415. self.cross_entropy = SoftmaxCrossEntropyWithLogits()
  416. self.reduction = reduction
  417. def construct(self, logits, label):
  418. loss = self.cross_entropy(logits, label)
  419. if self.reduction == 'mean':
  420. loss = self.reduce_mean(loss, (-1,))
  421. return loss
  422. def test_flatten_reshape(parallel_mode="auto_parallel"):
  423. batch_size = 16
  424. learning_rate = 0.1
  425. momentum = 0.9
  426. epoch_size = 2
  427. context.reset_auto_parallel_context()
  428. context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)
  429. net = ParallelReduceMeanNet(conv_in_channel=3, conv_out_channel=64, reducemean_axis=(2, 3), strategy=((4, 2, 1, 1),))
  430. loss = CrossEntropyLoss()
  431. predict = Tensor(np.ones([batch_size, 3, 32, 32]), dtype=ms.float32)
  432. label = Tensor(np.ones([batch_size, 64]), dtype=ms.float32)
  433. dataset = Dataset(predict, label, 2, input_num=2)
  434. opt = Momentum(net.trainable_params(), learning_rate, momentum)
  435. model = Model(net, loss_fn = loss, optimizer=opt)
  436. model.train(epoch_size, dataset, dataset_sink_mode=False)
  437. def test_flatten_reshape2(parallel_mode="auto_parallel"):
  438. batch_size = 16
  439. learning_rate = 0.1
  440. momentum = 0.9
  441. epoch_size = 2
  442. context.reset_auto_parallel_context()
  443. context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)
  444. set_algo_parameters(not_fully_use_devices=True)
  445. net = ParallelReduceMeanNet(conv_in_channel=3, conv_out_channel=64, reducemean_axis=(2, 3), strategy=((4, 1, 1, 1),))
  446. loss = CrossEntropyLoss()
  447. predict = Tensor(np.ones([batch_size, 3, 32, 32]), dtype=ms.float32)
  448. label = Tensor(np.ones([batch_size, 64]), dtype=ms.float32)
  449. dataset = Dataset(predict, label, 2, input_num=2)
  450. opt = Momentum(net.trainable_params(), learning_rate, momentum)
  451. model = Model(net, loss_fn = loss, optimizer=opt)
  452. model.train(epoch_size, dataset, dataset_sink_mode=False)
  453. class ParallelReshapeNet(nn.Cell):
  454. def __init__(self, dense_in_channel, dense_out_channel, shape, strategy=None):
  455. super().__init__()
  456. self.flat = nn.Flatten()
  457. self.dense = nn.Dense(in_channels=dense_in_channel,
  458. out_channels=dense_out_channel,
  459. weight_init='ones',
  460. bias_init='ones',
  461. has_bias=True)
  462. self.reshape = P.Reshape()
  463. self.shape = shape
  464. self.reshape.set_strategy(strategy)
  465. def construct(self, inputs):
  466. x = self.flat(inputs)
  467. x = self.dense(x)
  468. x = self.reshape(x, self.shape)
  469. return x
# The input and output shapes of the reshape are the same, so the
# reshape is optimized before step_parallel.
  472. def test_flatten_reshape3(parallel_mode="auto_parallel"):
  473. batch_size = 16
  474. learning_rate = 0.1
  475. momentum = 0.9
  476. epoch_size = 2
  477. context.reset_auto_parallel_context()
  478. context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)
  479. set_algo_parameters(not_fully_use_devices=True)
  480. net = ParallelReshapeNet(dense_in_channel=2048, dense_out_channel=1000, shape=(128, 1000), strategy=((16, 1),))
  481. loss = CrossEntropyLoss()
  482. predict = Tensor(np.ones([batch_size, 1, 2, 1024]), dtype=ms.float32)
  483. label = Tensor(np.ones([batch_size, 1000]), dtype=ms.float32)
  484. dataset = Dataset(predict, label, 2, input_num=2)
  485. opt = Momentum(net.trainable_params(), learning_rate, momentum)
  486. model = Model(net, loss_fn = loss, optimizer=opt)
  487. model.train(epoch_size, dataset, dataset_sink_mode=False)
  488. class CrossEntropyLoss2(nn.Cell):
  489. def __init__(self, reduction='mean'):
  490. super(CrossEntropyLoss2, self).__init__()
  491. self.cross_entropy = SoftmaxCrossEntropyWithLogits(reduction=reduction)
  492. def construct(self, logits, label):
  493. loss = self.cross_entropy(logits, label)
  494. return loss
  495. def test_flatten_reshape4(parallel_mode="semi_auto_parallel"):
  496. batch_size = 16
  497. learning_rate = 0.1
  498. momentum = 0.9
  499. epoch_size = 2
  500. context.reset_auto_parallel_context()
  501. context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)
  502. set_algo_parameters(not_fully_use_devices=True)
  503. net = ParallelReduceMeanNet(conv_in_channel=3, conv_out_channel=64, reducemean_keep_dims=True, strategy=((4, 1, 1, 1),))
  504. loss = CrossEntropyLoss2()
  505. predict = Tensor(np.ones([batch_size, 3, 32, 32]), dtype=ms.float32)
  506. label = Tensor(np.ones([batch_size, 2048]), dtype=ms.float32)
  507. dataset = Dataset(predict, label, 2, input_num=2)
  508. opt = Momentum(net.trainable_params(), learning_rate, momentum)
  509. model = Model(net, loss_fn=loss, optimizer=opt)
  510. model.train(epoch_size, dataset, dataset_sink_mode=False)