You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_dataset_interface.py 5.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from mindspore.train import Model, ParallelMode
  15. from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
  16. from mindspore.nn.optim.momentum import Momentum
  17. from mindspore import Tensor
  18. import mindspore as ms
  19. import numpy as np
  20. import mindspore.nn as nn
  21. from tests.dataset_mock import MindData
  22. from mindspore import context
  23. from mindspore.train.loss_scale_manager import DynamicLossScaleManager
  24. from mindspore.ops import composite as C, functional as F, operations as P
  25. from mindspore.common.parameter import Parameter, ParameterTuple
  # Select graph mode at import time, before any of the networks below are built.
  context.set_context(mode=context.GRAPH_MODE)
  class Dataset(MindData):
      """Fixed-length mock dataset that yields the same (predict, label) pair.

      Args:
          predict: Input sample returned on every step.
          label: Label returned together with ``predict``.
          length: Number of pairs produced before iteration stops; it is also
              forwarded to ``MindData`` as ``size``.
      """

      def __init__(self, predict, label, length=3):
          super().__init__(size=length)
          self.predict, self.label = predict, label
          self.index = 0
          self.length = length

      def __iter__(self):
          """Return the dataset itself; it acts as its own iterator."""
          return self

      def __next__(self):
          """Return the stored pair, or stop once ``length`` pairs were served."""
          if self.index < self.length:
              self.index += 1
              return self.predict, self.label
          raise StopIteration

      def reset(self):
          """Rewind the cursor so the dataset can be iterated again."""
          self.index = 0
  class AllToAllNet(nn.Cell):
      """MatMul followed by a Transpose, each carrying an explicit strategy.

      Args:
          strategy1: Strategy applied to the Transpose operator.
      """

      def __init__(self, strategy1):
          super().__init__()
          # The MatMul strategy is hard-coded; only the Transpose strategy is configurable.
          matmul_strategy = ((1, 1), (1, 8))
          self.matmul = P.MatMul().set_strategy(matmul_strategy)
          initial_weight = Tensor(np.ones([128, 256]), dtype=ms.float32)
          self.matmul_weight = Parameter(initial_weight, name="weight")
          self.transpose1 = P.Transpose().set_strategy(strategy1)

      def construct(self, x):
          """Multiply ``x`` by the weight, then transpose the result with perm (1, 0)."""
          product = self.matmul(x, self.matmul_weight)
          return self.transpose1(product, (1, 0))
  def all_to_all_net(strategy1):
      """Build an ``AllToAllNet`` whose Transpose uses ``strategy1``."""
      network = AllToAllNet(strategy1=strategy1)
      return network
  def loss_scale_manager_common(strategy1):
      """Train an ``AllToAllNet`` under auto-parallel with a dynamic loss-scale manager.

      Args:
          strategy1: Strategy forwarded to the network's Transpose operator.

      Raises:
          AssertionError: If ``model.train`` completes without raising ``TypeError``.
      """
      learning_rate = 0.1
      momentum = 0.9
      epoch_size = 2

      context.reset_auto_parallel_context()
      context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=8)

      predict = Tensor(np.ones([32, 128]), dtype=ms.float32)
      label = Tensor(np.ones([32]), dtype=ms.int32)
      dataset = Dataset(predict, label, 2)

      net = all_to_all_net(strategy1)
      loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
      loss.softmax_cross_entropy.set_strategy(((8, 1), (8, 1)))
      opt = Momentum(net.trainable_params(), learning_rate, momentum)
      scale_manager = DynamicLossScaleManager(32, 2, 2000)
      model = Model(net, loss, opt, loss_scale_manager=scale_manager)

      # Original note: if no GE exists, outputs = self._train_network(*next_element)
      # is None, so a TypeError is the expected outcome of training here.
      try:
          model.train(epoch_size, dataset, dataset_sink_mode=False)
      except TypeError:
          return
      # Raise explicitly instead of `assert False`: asserts are stripped under
      # `python -O`, and an explicit raise carries a diagnostic message.
      raise AssertionError("model.train was expected to raise TypeError but completed normally")
  def test_dataset_interface_sens_scalar():
      """Run the common loss-scale-manager scenario with Transpose strategy ((8, 1),)."""
      loss_scale_manager_common(((8, 1),))
  class TrainOneStepCell(nn.Cell):
      """Wrap a network and an optimizer into a cell that runs one update step.

      Args:
          network: Cell whose trainable parameters are updated.
          optimizer: Callable applied to the computed gradients.
          sens: Accepted but not used by this class; the sens value is
              supplied to ``construct`` on each call instead.
      """

      def __init__(self, network, optimizer, sens=1.0):
          super().__init__(auto_prefix=False)
          self.network = network
          self.network.add_flags(defer_inline=True)
          self.optimizer = optimizer
          self.weights = ParameterTuple(network.trainable_params())
          self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)

      def construct(self, data, sens):
          """Return the network output, made dependent on the optimizer update."""
          params = self.weights
          loss = self.network(data)
          grads = self.grad(self.network, params)(data, sens)
          update = self.optimizer(grads)
          return F.depend(loss, update)
  def loss_scale_manager_sens(strategy1, sens):
      """Run one ``TrainOneStepCell`` step on an ``AllToAllNet`` in semi-auto-parallel mode.

      Args:
          strategy1: Strategy forwarded to the network's Transpose operator.
          sens: Sens tensor passed to the training cell together with the input.
      """
      device_num = 8
      context.reset_auto_parallel_context()
      context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_num)

      net = all_to_all_net(strategy1)
      # Momentum optimizer with learning rate 0.1 and momentum 0.9.
      opt = Momentum(net.trainable_params(), 0.1, 0.9)
      train_net = TrainOneStepCell(net, opt)
      train_net.set_train()

      predict = Tensor(np.ones([32 * device_num, 128]), dtype=ms.float32)
      train_net(predict, sens)
  def test_dataset_interface_sens_shape_not_equal_loss():
      """Run a training step with a [256, 1024] sens, tolerating any ordinary error.

      The test passes whether or not the step raises; only interpreter-level
      exits (``KeyboardInterrupt``, ``SystemExit``) are allowed to propagate.
      """
      strategy1 = ((8, 1), )
      sens = Tensor(np.ones([256, 1024]), dtype=ms.float32)
      try:
          loss_scale_manager_sens(strategy1, sens)
      except Exception:
          # Narrowed from a bare `except:` so Ctrl-C and SystemExit are no
          # longer swallowed; the best-effort behaviour is otherwise unchanged.
          pass
  def test_dataset_interface_sens_shape_equal_loss():
      """Run a training step with a [256, 256] sens and Transpose strategy ((4, 2),)."""
      sens_shape = [256, 256]
      loss_scale_manager_sens(((4, 2),), Tensor(np.ones(sens_shape), dtype=ms.float32))
  def test_input_not_in_parameter_layotu_dict():
      """Call a semi-auto-parallel net with an extra input that ``construct`` never uses."""

      class Net(nn.Cell):
          """MatMul + Transpose cell whose second input ``b`` is ignored."""

          def __init__(self, strategy1):
              super().__init__()
              matmul_strategy = ((1, 1), (1, 8))
              self.matmul = P.MatMul().set_strategy(matmul_strategy)
              initial_weight = Tensor(np.ones([128, 256]), dtype=ms.float32)
              self.matmul_weight = Parameter(initial_weight, name="weight")
              self.transpose1 = P.Transpose().set_strategy(strategy1)

          def construct(self, x, b):
              product = self.matmul(x, self.matmul_weight)
              return self.transpose1(product, (1, 0))

      device_num = 8
      context.reset_auto_parallel_context()
      context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_num)

      # Both inputs share the same shape and are filled with ones.
      input_shape = [32 * device_num, 128]
      predict = Tensor(np.ones(input_shape), dtype=ms.float32)
      b = Tensor(np.ones(input_shape), dtype=ms.float32)

      net = Net(((8, 1),))
      net.set_train()
      net(predict, b)