You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_strategy_checkpoint.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from mindspore import context
  16. from mindspore.context import set_auto_parallel_context
  17. import mindspore.nn as nn
  18. from mindspore.ops import operations as P
  19. from mindspore import Tensor
  20. from tests.ut.python.ops.test_math_ops import VirtualLoss
  21. import mindspore as ms
  22. from mindspore.common.api import _executor
  23. from mindspore.ops import composite as C
  24. # model_parallel test
  25. # export PARALLEL_CHECKPOINT_ON=on
  26. # export PARALLEL_TRAIN_TIMES=4
  def test_six_matmul():
      """Compile the gradient graph of six chained MatMul ops.

      Each MatMul gets its own strategy through ``set_strategy``; the network
      is wrapped with ``VirtualLoss`` and ``C.grad_all`` and then handed to
      ``_executor.compile`` with the parallel mode set to
      ``semi_auto_parallel``.

      NOTE(review): the file name suggests this exercises strategy
      checkpointing, so the nested cell names and attribute names are
      presumably significant -- confirm before renaming any of them.
      """

      class NetWithLoss(nn.Cell):
          """Feed the wrapped network's prediction into ``VirtualLoss``."""

          def __init__(self, network):
              super().__init__()
              self.loss = VirtualLoss()
              self.network = network

          def construct(self, x1, x2, x3, x4, x5, x6, x7):
              return self.loss(self.network(x1, x2, x3, x4, x5, x6, x7))

      class GradWrap(nn.Cell):
          """Apply ``C.grad_all`` to the wrapped network."""

          def __init__(self, network):
              super().__init__()
              self.network = network

          def construct(self, x1, x2, x3, x4, x5, x6, x7):
              return C.grad_all(self.network)(x1, x2, x3, x4, x5, x6, x7)

      class Net(nn.Cell):
          """Six MatMul primitives applied one after another."""

          def __init__(self, strategy1, strategy2, strategy3, strategy4, strategy5, strategy6):
              super(Net, self).__init__()
              self.matmul1 = P.MatMul().set_strategy(strategy1)
              self.matmul2 = P.MatMul().set_strategy(strategy2)
              self.matmul3 = P.MatMul().set_strategy(strategy3)
              self.matmul4 = P.MatMul().set_strategy(strategy4)
              self.matmul5 = P.MatMul().set_strategy(strategy5)
              self.matmul6 = P.MatMul().set_strategy(strategy6)

          def construct(self, x1, x2, x3, x4, x5, x6, x7):
              # Left operand of every step is the previous product.
              acc = self.matmul1(x1, x2)
              acc = self.matmul2(acc, x3)
              acc = self.matmul3(acc, x4)
              acc = self.matmul4(acc, x5)
              acc = self.matmul5(acc, x6)
              acc = self.matmul6(acc, x7)
              return acc

      set_auto_parallel_context(device_num=512, global_rank=0)

      # One strategy per MatMul, listed in the order matmul1..matmul6.
      strategies = (
          ((8, 1), (1, 1)),
          ((1, 8), (8, 1)),
          ((2, 2), (2, 2)),
          ((4, 2), (2, 4)),
          ((2, 4), (4, 2)),
          ((4, 4), (4, 4)),
      )
      net = GradWrap(NetWithLoss(Net(*strategies)))
      context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

      # Shapes of x1..x7; every input is an all-ones float32 tensor.
      input_shapes = (
          [128, 32],
          [32, 64],
          [64, 64],
          [64, 128],
          [128, 64],
          [64, 32],
          [32, 32],
      )
      inputs = [Tensor(np.ones(shape), dtype=ms.float32) for shape in input_shapes]
      _executor.compile(net, *inputs)
  76. # remove matmul2
  def test_six_matmul_repeated1():
      """Compile the gradient graph of five chained MatMul ops.

      The network has ``matmul1`` and ``matmul3``..``matmul6`` (no
      ``matmul2``), and every MatMul receives the same strategy
      ``((8, 1), (1, 1))``. The wrapped network is compiled with
      ``_executor.compile`` under ``semi_auto_parallel``.

      NOTE(review): the nested cell names and attribute names are presumably
      significant for strategy checkpointing -- confirm before renaming.
      """

      class NetWithLoss(nn.Cell):
          """Feed the wrapped network's prediction into ``VirtualLoss``."""

          def __init__(self, network):
              super().__init__()
              self.loss = VirtualLoss()
              self.network = network

          def construct(self, x1, x2, x4, x5, x6, x7):
              return self.loss(self.network(x1, x2, x4, x5, x6, x7))

      class GradWrap(nn.Cell):
          """Apply ``C.grad_all`` to the wrapped network."""

          def __init__(self, network):
              super().__init__()
              self.network = network

          def construct(self, x1, x2, x4, x5, x6, x7):
              return C.grad_all(self.network)(x1, x2, x4, x5, x6, x7)

      class Net(nn.Cell):
          """Five MatMul primitives applied one after another."""

          def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6):
              super(Net, self).__init__()
              self.matmul1 = P.MatMul().set_strategy(strategy1)
              self.matmul3 = P.MatMul().set_strategy(strategy3)
              self.matmul4 = P.MatMul().set_strategy(strategy4)
              self.matmul5 = P.MatMul().set_strategy(strategy5)
              self.matmul6 = P.MatMul().set_strategy(strategy6)

          def construct(self, x1, x2, x4, x5, x6, x7):
              # Left operand of every step is the previous product.
              acc = self.matmul1(x1, x2)
              acc = self.matmul3(acc, x4)
              acc = self.matmul4(acc, x5)
              acc = self.matmul5(acc, x6)
              acc = self.matmul6(acc, x7)
              return acc

      set_auto_parallel_context(device_num=512, global_rank=0)

      # All five MatMuls use the same strategy tuple.
      shared_strategy = ((8, 1), (1, 1))
      net = GradWrap(NetWithLoss(Net(*((shared_strategy,) * 5))))
      context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

      # Shapes of x1, x2, x4, x5, x6, x7; all-ones float32 tensors.
      input_shapes = (
          [128, 32],
          [32, 64],
          [64, 128],
          [128, 64],
          [64, 32],
          [32, 32],
      )
      inputs = [Tensor(np.ones(shape), dtype=ms.float32) for shape in input_shapes]
      _executor.compile(net, *inputs)
  122. # add matmul7
  def test_six_matmul_repeated2():
      """Compile the gradient graph of six chained MatMul ops.

      The network has ``matmul1`` and ``matmul3``..``matmul7`` (no
      ``matmul2``), and every MatMul receives the same strategy
      ``((8, 1), (1, 1))``. The wrapped network is compiled with
      ``_executor.compile`` under ``semi_auto_parallel``.

      NOTE(review): the nested cell names and attribute names are presumably
      significant for strategy checkpointing -- confirm before renaming.
      """

      class NetWithLoss(nn.Cell):
          """Feed the wrapped network's prediction into ``VirtualLoss``."""

          def __init__(self, network):
              super().__init__()
              self.loss = VirtualLoss()
              self.network = network

          def construct(self, x1, x2, x4, x5, x6, x7, x8):
              return self.loss(self.network(x1, x2, x4, x5, x6, x7, x8))

      class GradWrap(nn.Cell):
          """Apply ``C.grad_all`` to the wrapped network."""

          def __init__(self, network):
              super().__init__()
              self.network = network

          def construct(self, x1, x2, x4, x5, x6, x7, x8):
              return C.grad_all(self.network)(x1, x2, x4, x5, x6, x7, x8)

      class Net(nn.Cell):
          """Six MatMul primitives applied one after another."""

          def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6, strategy7):
              super(Net, self).__init__()
              self.matmul1 = P.MatMul().set_strategy(strategy1)
              self.matmul3 = P.MatMul().set_strategy(strategy3)
              self.matmul4 = P.MatMul().set_strategy(strategy4)
              self.matmul5 = P.MatMul().set_strategy(strategy5)
              self.matmul6 = P.MatMul().set_strategy(strategy6)
              self.matmul7 = P.MatMul().set_strategy(strategy7)

          def construct(self, x1, x2, x4, x5, x6, x7, x8):
              # Left operand of every step is the previous product.
              acc = self.matmul1(x1, x2)
              acc = self.matmul3(acc, x4)
              acc = self.matmul4(acc, x5)
              acc = self.matmul5(acc, x6)
              acc = self.matmul6(acc, x7)
              acc = self.matmul7(acc, x8)
              return acc

      set_auto_parallel_context(device_num=512, global_rank=0)

      # All six MatMuls use the same strategy tuple.
      shared_strategy = ((8, 1), (1, 1))
      net = GradWrap(NetWithLoss(Net(*((shared_strategy,) * 6))))
      context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

      # Shapes of x1, x2, x4, x5, x6, x7, x8; all-ones float32 tensors.
      input_shapes = (
          [128, 32],
          [32, 64],
          [64, 128],
          [128, 64],
          [64, 32],
          [32, 32],
          [32, 128],
      )
      inputs = [Tensor(np.ones(shape), dtype=ms.float32) for shape in input_shapes]
      _executor.compile(net, *inputs)
  172. # add scope2
  def test_six_matmul_repeated3():
      """Compile the gradient graph of two chained sub-networks.

      ``Net`` applies six MatMuls (``matmul1``, ``matmul3``..``matmul7``) and
      ``Net1`` applies two more to its result. Every MatMul receives the same
      strategy ``((8, 1), (1, 1))``. The combined network is wrapped with
      ``VirtualLoss`` and ``C.grad_all`` and compiled with
      ``_executor.compile`` under ``semi_auto_parallel``.

      NOTE(review): the nested cell names and attribute names are presumably
      significant for strategy checkpointing -- confirm before renaming.
      """

      class NetWithLoss(nn.Cell):
          """Run ``network1``, then ``network2``, then ``VirtualLoss``."""

          def __init__(self, network1, network2):
              super().__init__()
              self.loss = VirtualLoss()
              self.network = network1
              self.network2 = network2

          def construct(self, x1, x2, x4, x5, x6, x7, x8, x9, x10):
              first_stage = self.network(x1, x2, x4, x5, x6, x7, x8)
              second_stage = self.network2(first_stage, x9, x10)
              return self.loss(second_stage)

      class GradWrap(nn.Cell):
          """Apply ``C.grad_all`` to the wrapped network."""

          def __init__(self, network):
              super().__init__()
              self.network = network

          def construct(self, x1, x2, x4, x5, x6, x7, x8, x9, x10):
              return C.grad_all(self.network)(x1, x2, x4, x5, x6, x7, x8, x9, x10)

      class Net(nn.Cell):
          """Six MatMul primitives applied one after another."""

          def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6, strategy7):
              super(Net, self).__init__()
              self.matmul1 = P.MatMul().set_strategy(strategy1)
              self.matmul3 = P.MatMul().set_strategy(strategy3)
              self.matmul4 = P.MatMul().set_strategy(strategy4)
              self.matmul5 = P.MatMul().set_strategy(strategy5)
              self.matmul6 = P.MatMul().set_strategy(strategy6)
              self.matmul7 = P.MatMul().set_strategy(strategy7)

          def construct(self, x1, x2, x4, x5, x6, x7, x8):
              # Left operand of every step is the previous product.
              acc = self.matmul1(x1, x2)
              acc = self.matmul3(acc, x4)
              acc = self.matmul4(acc, x5)
              acc = self.matmul5(acc, x6)
              acc = self.matmul6(acc, x7)
              acc = self.matmul7(acc, x8)
              return acc

      class Net1(nn.Cell):
          """Two MatMul primitives applied one after another."""

          def __init__(self, strategy1, strategy2):
              super(Net1, self).__init__()
              self.matmul1 = P.MatMul().set_strategy(strategy1)
              self.matmul2 = P.MatMul().set_strategy(strategy2)

          def construct(self, x1, x2, x3):
              acc = self.matmul1(x1, x2)
              acc = self.matmul2(acc, x3)
              return acc

      set_auto_parallel_context(device_num=512, global_rank=0)

      # Every MatMul in both sub-networks uses the same strategy tuple.
      shared_strategy = ((8, 1), (1, 1))
      net1 = Net(*((shared_strategy,) * 6))
      net2 = Net1(shared_strategy, shared_strategy)
      net = GradWrap(NetWithLoss(net1, net2))
      context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

      # Shapes of x1, x2, x4..x10; all-ones float32 tensors.
      input_shapes = (
          [128, 32],
          [32, 64],
          [64, 128],
          [128, 64],
          [64, 32],
          [32, 32],
          [32, 128],
          [128, 64],
          [64, 64],
      )
      inputs = [Tensor(np.ones(shape), dtype=ms.float32) for shape in input_shapes]
      _executor.compile(net, *inputs)