You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_dataset_util.py 2.7 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.train._utils import _to_full_shapes, _to_full_tensor
  17. def test_to_full_shapes():
  18. device_num = 16
  19. shapes = [[32, 128], [12], [24, 1, 12]]
  20. full_shapes = _to_full_shapes(shapes, device_num)
  21. assert full_shapes == [(512, 128), (192,), (384, 1, 12)]
  22. def test_to_full_tensor_1():
  23. elem = Tensor([[1, 2, 3], [4, 5, 6]], dtype=ms.float32)
  24. device_num = 4
  25. global_rank = 2
  26. full_tensor = _to_full_tensor(elem, device_num, global_rank, scaling_sens=None)
  27. expect = ([[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [1, 2, 3], [4, 5, 6], [0, 0, 0], [0, 0, 0]])
  28. expect_tensor = Tensor(expect, dtype=ms.float32)
  29. assert full_tensor[0] == expect_tensor
  30. def test_to_full_tensor_2():
  31. elem0 = Tensor([[1, 2, 3], [4, 5, 6]], dtype=ms.float32)
  32. elem1 = Tensor([[1], [4]], dtype=ms.int32)
  33. elem = (elem0, elem1,)
  34. device_num = 4
  35. global_rank = 2
  36. full_tensor = _to_full_tensor(elem, device_num, global_rank, scaling_sens=None)
  37. expect0 = ([[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [1, 2, 3], [4, 5, 6], [0, 0, 0], [0, 0, 0]])
  38. expect_tensor0 = Tensor(expect0, dtype=ms.float32)
  39. expect1 = ([[0], [0], [0], [0], [1], [4], [0], [0]])
  40. expect_tensor1 = Tensor(expect1, dtype=ms.int32)
  41. expect_tensors = (expect_tensor0, expect_tensor1)
  42. assert full_tensor == expect_tensors
  43. def test_to_full_tensor_sens_2():
  44. elem0 = Tensor([[1, 2, 3], [4, 5, 6]], dtype=ms.float32)
  45. elem1 = Tensor([[1], [4]], dtype=ms.int32)
  46. elem = (elem0, elem1,)
  47. device_num = 4
  48. global_rank = 2
  49. full_tensor = _to_full_tensor(elem, device_num, global_rank, scaling_sens=0.1)
  50. expect0 = ([[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [1, 2, 3], [4, 5, 6], [0, 0, 0], [0, 0, 0]])
  51. expect_tensor0 = Tensor(expect0, dtype=ms.float32)
  52. expect1 = ([[0], [0], [0], [0], [1], [4], [0], [0]])
  53. expect_tensor1 = Tensor(expect1, dtype=ms.int32)
  54. expect_tensor_sens = Tensor(0.1, dtype=ms.float32)
  55. expect_tensors = (expect_tensor0, expect_tensor1, expect_tensor_sens)
  56. assert full_tensor == expect_tensors