You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_indexed_slices.py 18 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. @File : test_indexed_slices.py
  17. @Author:
  18. @Date : 2020-06-08
  19. @Desc : test mindspore indexed_slices's operation
  20. """
  21. import numpy as np
  22. import pytest
  23. import mindspore as ms
  24. import mindspore.nn as nn
  25. from mindspore.ops import composite as C
  26. from mindspore.ops import functional as F
  27. from mindspore.ops import operations as P
  28. from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
  29. from mindspore.ops.primitive import constexpr
  30. from mindspore.ops._grad.grad_base import bprop_getters
  31. from mindspore import Tensor, IndexedSlices, context
  32. from mindspore.common.parameter import Parameter, ParameterTuple
  33. from mindspore.common import dtype as mstype
  34. from mindspore._checkparam import Validator as validator
  35. from mindspore._checkparam import Rel
  36. from mindspore.nn import Optimizer
  37. from mindspore.nn import TrainOneStepCell, WithLossCell
  38. from mindspore.nn.optim import Momentum
  39. from mindspore.train import Model
  40. from ....dataset_mock import MindData
  41. context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
  42. reduce_sum = P.ReduceSum()
  43. unsorted_segment_sum = P.UnsortedSegmentSum()
  44. transpose = P.Transpose()
  45. shape_op = P.Shape()
  46. reshape = P.Reshape()
  47. size_op = P.Size()
  48. invert_permutation = P.InvertPermutation()
  49. logical_and = P.LogicalAnd()
  50. def get_axis(x):
  51. shape = shape_op(x)
  52. length = F.tuple_len(shape)
  53. perm = F.make_range(0, length)
  54. return perm
  55. class MSELoss(nn.Cell):
  56. def __init__(self):
  57. super(MSELoss, self).__init__()
  58. self.reduce_sum = P.ReduceSum()
  59. self.square = P.Square()
  60. self.reduce_mean = P.ReduceMean()
  61. def construct(self, data, label):
  62. diff = data - label
  63. return self.reduce_mean(self.square(diff), get_axis(diff))
  64. class MindDataSet(MindData):
  65. def __init__(self, dataset_types, dataset_shapes):
  66. super(MindDataSet, self).__init__(size=2, batch_size=32,
  67. np_types=dataset_types,
  68. output_shapes=dataset_shapes,
  69. input_indexs=(0, 1))
  70. def __next__(self):
  71. if self._size < self._iter_num:
  72. raise StopIteration
  73. self._iter_num += 1
  74. lst = []
  75. for shape_, type_ in zip(self._output_shapes, self._np_types):
  76. lst.append(Tensor(np.ones(shape_).astype(type_)))
  77. return tuple(lst)
  78. @constexpr
  79. def _generate_shape_index(out_shape, indices_shape, axis):
  80. out_rank = len(out_shape)
  81. ind_rank = len(indices_shape)
  82. if axis < 0:
  83. axis += out_rank - ind_rank + 1
  84. perm_part1 = tuple(range(axis, axis + ind_rank))
  85. index = tuple(range(out_rank))
  86. perm = perm_part1 + index[:axis] + index[axis + ind_rank:]
  87. return perm
  88. @constexpr
  89. def _generate_inverse_index(x_shape, axis):
  90. x_rank = len(x_shape)
  91. index = tuple(range(x_rank))
  92. if axis < 0:
  93. axis += x_rank
  94. perm = index[1:1 + axis] + (0,) + index[1 + axis:]
  95. return perm
  96. class MySparseGatherV2(P.GatherV2):
  97. """
  98. For test
  99. """
  100. @bprop_getters.register(MySparseGatherV2)
  101. def get_bprop_sparse_gather_v2(self):
  102. """Generate bprop for MySparseGatherV2"""
  103. def bprop(x, indices, axis, out, dout):
  104. x_shp = shape_op(x)
  105. if axis == 0:
  106. indices_size = (size_op(indices),)
  107. x_tail_shp = x_shp[1:]
  108. values_shape = indices_size + x_tail_shp
  109. values = reshape(dout, values_shape)
  110. indices = reshape(indices, indices_size)
  111. return IndexedSlices(indices, values, x_shp), zeros_like(indices), zeros_like(axis)
  112. if F.rank(dout) == 0:
  113. dout = P.ExpandDims()(dout, -1)
  114. if F.rank(indices) == 0:
  115. indices = P.ExpandDims()(indices, -1)
  116. out_shp = shape_op(dout)
  117. ind_shp = shape_op(indices)
  118. # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
  119. perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
  120. values_transpose = transpose(dout, perm_1)
  121. params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
  122. # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
  123. perm_2 = _generate_inverse_index(x_shp, axis)
  124. params_grad = transpose(params_grad, perm_2)
  125. return params_grad, zeros_like(indices), zeros_like(axis)
  126. return bprop
  127. adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map")
  128. @adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
  129. "Tensor", "Tensor", "Tensor", "IndexedSlices", "Bool")
  130. def _update_run_op_for_map_indexed_slices(beta1, beta2, eps, lr, weight_decay_tensor, param,
  131. m, v, gradient, decay_flag):
  132. return gradient.values()
  133. @adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
  134. "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
  135. def _update_run_op_for_map_tensor(beta1, beta2, eps, lr, weight_decay_tensor, param,
  136. m, v, gradient, decay_flag):
  137. op_mul = P.Mul()
  138. op_square = P.Square()
  139. op_sqrt = P.Sqrt()
  140. op_cast = P.Cast()
  141. op_reshape = P.Reshape()
  142. op_shape = P.Shape()
  143. param_fp32 = op_cast(param, mstype.float32)
  144. m_fp32 = op_cast(m, mstype.float32)
  145. v_fp32 = op_cast(v, mstype.float32)
  146. gradient_fp32 = op_cast(gradient, mstype.float32)
  147. next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)
  148. next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
  149. - beta2, op_square(gradient_fp32))
  150. update = next_m / (op_sqrt(next_v) + eps)
  151. if decay_flag:
  152. update = update + op_mul(weight_decay_tensor, param_fp32)
  153. update_with_lr = op_mul(lr, update)
  154. next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
  155. next_v = F.depend(next_v, F.assign(param, next_param))
  156. next_v = F.depend(next_v, F.assign(m, next_m))
  157. next_v = F.depend(next_v, F.assign(v, next_v))
  158. return next_v
  159. def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
  160. """Check the type of inputs."""
  161. validator.check_value_type("beta1", beta1, [float], prim_name)
  162. validator.check_value_type("beta2", beta2, [float], prim_name)
  163. validator.check_value_type("eps", eps, [float], prim_name)
  164. validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
  165. validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
  166. validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
  167. validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
  168. validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
  169. class AdamWeightDecaySparse(Optimizer):
  170. def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0,
  171. decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
  172. super(AdamWeightDecaySparse, self).__init__(learning_rate, params)
  173. if self.is_group:
  174. raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
  175. _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
  176. self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
  177. self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
  178. self.eps = Tensor(np.array([eps]).astype(np.float32))
  179. self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32))
  180. self.params = self.parameters
  181. self.moments1 = self.params.clone(prefix="adam_m", init='zeros')
  182. self.moments2 = self.params.clone(prefix="adam_v", init='zeros')
  183. self.decay_flag = tuple(decay_filter(x) for x in self.params)
  184. self.map = C.Map()
  185. def construct(self, gradients):
  186. lr = self.get_lr()
  187. updated_velocity = self.map(F.partial(adam_opt_for_map, self.beta1, self.beta2, self.eps, lr,
  188. self.weight_decay_tensor),
  189. self.params, self.moments1, self.moments2, gradients, self.decay_flag)
  190. return updated_velocity
  191. def test_indexed_slices_make_indexed_slices():
  192. class MakeIndexedSlices(nn.Cell):
  193. def __init__(self):
  194. super(MakeIndexedSlices, self).__init__()
  195. self.dense_shape = (3, 2)
  196. def construct(self, indices, values):
  197. ret = (IndexedSlices(indices, values, self.dense_shape),)
  198. return ret[0]
  199. indices = Tensor([1, 2])
  200. values = Tensor([[0, 0], [1, 2]], dtype=ms.float32)
  201. MakeIndexedSlices()(indices, values)
  202. class IndexedSlicesGetAttr(nn.Cell):
  203. def __init__(self, dense_shape):
  204. super(IndexedSlicesGetAttr, self).__init__()
  205. self.dense_shape = dense_shape
  206. def construct(self, indices, values):
  207. x = IndexedSlices(indices, values, self.dense_shape)
  208. return x.values(), x.indices(), x.dense_shape()
  209. def test_indexed_slices_attr():
  210. indices = Tensor([0])
  211. values = Tensor([[1, 2]], dtype=ms.float32)
  212. IndexedSlicesGetAttr((3, 2))(indices, values)
  213. def test_indexed_slices_sparse_gatherv2_grad_all():
  214. grad_all = C.GradOperation('get_all', get_all=True)
  215. class GradWrap(nn.Cell):
  216. def __init__(self, network):
  217. super(GradWrap, self).__init__()
  218. self.network = network
  219. def construct(self, x, y):
  220. grad = grad_all(self.network)(x, y)
  221. return grad[0].indices(), grad[0].values(), grad[0].dense_shape()
  222. class SparseGatherV2(nn.Cell):
  223. def __init__(self):
  224. super(SparseGatherV2, self).__init__()
  225. self.sparse_gatherv2 = MySparseGatherV2()
  226. self.axis = 0
  227. def construct(self, params, indices):
  228. return self.sparse_gatherv2(params, indices, self.axis)
  229. params = Tensor(np.ones([3, 1, 2]).astype(np.int32))
  230. indices = Tensor(np.array([0, 1]).astype(np.int32))
  231. GradWrap(SparseGatherV2())(params, indices)
  232. def test_indexed_slices_sparse_gatherv2_grad_with_pram():
  233. grad_by_list = C.GradOperation('get_by_list', get_by_list=True)
  234. class GradWrap(nn.Cell):
  235. def __init__(self, network):
  236. super(GradWrap, self).__init__()
  237. self.network = network
  238. self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))
  239. def construct(self, x):
  240. weights = self.weights
  241. grad = grad_by_list(self.network, weights)(x)
  242. x = grad[0]
  243. return x.values(), x.indices(), x.dense_shape()
  244. class SparseGatherV2(nn.Cell):
  245. def __init__(self):
  246. super(SparseGatherV2, self).__init__()
  247. self.sparse_gatherv2 = MySparseGatherV2()
  248. self.axis = 0
  249. self.params = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.int32)), name="params")
  250. def construct(self, indices):
  251. return self.sparse_gatherv2(self.params, indices, self.axis)
  252. indices = Tensor(np.array([0, 1]).astype(np.int32))
  253. network = GradWrap(SparseGatherV2())
  254. network(indices)
  255. def test_indexed_slices_env_get():
  256. class Loss(nn.Cell):
  257. def __init__(self):
  258. super(Loss, self).__init__()
  259. def construct(self, base, target):
  260. return base
  261. class NetWithSparseGatherV2(nn.Cell):
  262. def __init__(self):
  263. super(NetWithSparseGatherV2, self).__init__()
  264. self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1")
  265. self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2")
  266. self.gatherv2 = MySparseGatherV2()
  267. self.axis = 0
  268. def construct(self, indices):
  269. return self.gatherv2(self.w1, indices, self.axis) * self.w2
  270. inputs = Tensor(np.array([0, 1]).astype(np.int32))
  271. label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
  272. net = NetWithSparseGatherV2()
  273. net.set_train()
  274. loss = Loss()
  275. optimizer = AdamWeightDecaySparse(net.trainable_params())
  276. net_with_loss = WithLossCell(net, loss)
  277. train_network = TrainOneStepCell(net_with_loss, optimizer)
  278. train_network(inputs, label)
  279. def test_indexed_slices_model_train():
  280. class Net(nn.Cell):
  281. def __init__(self, in_features, out_features):
  282. super(Net, self).__init__()
  283. self.weight = Parameter(Tensor(np.ones([out_features, in_features]).astype(np.float32)), name="weight")
  284. self.add = P.TensorAdd()
  285. self.cast = P.Cast()
  286. self.flag = True
  287. def construct(self, inputs, label):
  288. x = self.add(inputs, self.weight)
  289. if self.flag:
  290. x = self.cast(x, mstype.float32)
  291. return x
  292. dataset_types = (np.float32, np.float32)
  293. dataset_shapes = ((16, 16), (16, 16))
  294. dataset = MindDataSet(dataset_types, dataset_shapes)
  295. net = Net(16, 16)
  296. net.set_train()
  297. optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
  298. model = Model(net, optimizer=optimizer)
  299. model.train(2, dataset, dataset_sink_mode=False)
  300. def test_indexed_slices_values_dim_greater_than_dense_shape_dim():
  301. indices = Tensor(np.array([0, 1], dtype=np.int32))
  302. values = Tensor(np.random.randn(2, 4, 5).astype(np.float32))
  303. dense_shape = (3, 4)
  304. with pytest.raises(TypeError):
  305. IndexedSlicesGetAttr(dense_shape)(indices, values)
  306. def test_indexed_slices_values_dim_less_than_dense_shape_dim():
  307. indices = Tensor(np.array([0, 1], dtype=np.int32))
  308. values = Tensor(np.random.randn(2, 4).astype(np.float32))
  309. dense_shape = (3, 4, 5)
  310. with pytest.raises(TypeError):
  311. IndexedSlicesGetAttr(dense_shape)(indices, values)
  312. def test_indexed_slices_value_and_dense_shape_illegal():
  313. indices = Tensor(np.array([0, 1], dtype=np.int32))
  314. values = Tensor(np.random.randn(2, 4).astype(np.float32))
  315. dense_shape = (3, 5)
  316. with pytest.raises(TypeError):
  317. IndexedSlicesGetAttr(dense_shape)(indices, values)
  318. class IndexedSlicesValuesDouble(nn.Cell):
  319. def __init__(self):
  320. super().__init__()
  321. def construct(self, x):
  322. indices = x.indices()
  323. values = x.values() * 2
  324. dense_shape = x.dense_shape()
  325. return IndexedSlices(indices, values, dense_shape)
  326. class IndexedSlicesValuesAdd2(nn.Cell):
  327. def __init__(self):
  328. super().__init__()
  329. def construct(self, x):
  330. indices = x.indices()
  331. values = x.values() + 2
  332. dense_shape = x.dense_shape()
  333. return IndexedSlices(indices, values, dense_shape)
  334. class IndexedSlicesWithControlIf(nn.Cell):
  335. def __init__(self, dense_shape):
  336. super().__init__()
  337. self.op1 = IndexedSlicesValuesDouble()
  338. self.op2 = IndexedSlicesValuesAdd2()
  339. self.dense_shape = dense_shape
  340. def construct(self, a, b, indices, values):
  341. x = IndexedSlices(indices, values, self.dense_shape)
  342. if a > b:
  343. x = self.op1(x)
  344. else:
  345. x = self.op2(x)
  346. return x.indices(), x.values()
  347. def test_indexed_slices_with_control_flow_if():
  348. a = Tensor(np.array(0).astype(np.int32))
  349. b = Tensor(np.array(2).astype(np.int32))
  350. indices = Tensor(np.array([0, 2]).astype(np.int32))
  351. values = Tensor(np.ones([2, 2]).astype(np.float32))
  352. dense_shape = (5, 2)
  353. net = IndexedSlicesWithControlIf(dense_shape)
  354. net(a, b, indices, values)
  355. class EmbeddingLookUpBnNet(nn.Cell):
  356. def __init__(self, param_np, target='CPU'):
  357. super().__init__()
  358. self.param = Parameter(Tensor(param_np), name="w1")
  359. self.embedding_lookup = nn.EmbeddingLookup(target=target)
  360. self.bn = nn.BatchNorm2d(num_features=3)
  361. self.mul = P.Mul()
  362. self.reshape = P.Reshape()
  363. self.relu = nn.PReLU()
  364. def construct(self, indices):
  365. x = self.embedding_lookup(self.param, indices)
  366. x = self.reshape(x, (2, 3, 2, 2))
  367. x = self.relu(x)
  368. x = self.bn(x)
  369. return x
  370. def test_embedding_lookup_with_mix_precision():
  371. param_np = np.ones([8, 8]).astype(np.float32)
  372. data = Tensor(np.array([0, 1, 2]).astype(np.int32))
  373. label = Tensor(np.random.randn(*(2, 3, 2, 2)).astype(np.float32))
  374. net = EmbeddingLookUpBnNet(param_np, target='CPU')
  375. criterion = nn.SoftmaxCrossEntropyWithLogits(reduction='mean')
  376. optimizer = nn.Adam(params=net.trainable_params(), learning_rate=0.1)
  377. optimizer.sparse_opt.add_prim_attr("primitive_target", "CPU")
  378. train_network = ms.amp.build_train_network(net, optimizer, criterion, level="O2")
  379. train_network.set_train()
  380. for _ in range(2):
  381. train_network(data, label)