You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_dataset_util.py 2.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from mindspore.train._utils import _to_full_shapes, _to_full_tensor
  16. from mindspore import Tensor
  17. import mindspore as ms
  18. def test_to_full_shapes():
  19. device_num = 16
  20. shapes = [[32, 128], [12], [24, 1, 12]]
  21. full_shapes = _to_full_shapes(shapes, device_num)
  22. assert full_shapes == [(512, 128), (192,), (384, 1, 12)]
  23. def test_to_full_tensor_1():
  24. elem = Tensor([[1,2,3], [4,5,6]], dtype=ms.float32)
  25. device_num = 4
  26. global_rank = 2
  27. full_tensor = _to_full_tensor(elem, device_num, global_rank, scaling_sens=None)
  28. expect = ([[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [1,2,3], [4,5,6], [0, 0, 0], [0, 0, 0]])
  29. expect_tensor = Tensor(expect, dtype=ms.float32)
  30. assert (full_tensor[0] == expect_tensor)
  31. def test_to_full_tensor_2():
  32. elem0 = Tensor([[1,2,3], [4,5,6]], dtype=ms.float32)
  33. elem1 = Tensor([[1], [4]], dtype=ms.int32)
  34. elem = (elem0, elem1,)
  35. device_num = 4
  36. global_rank = 2
  37. full_tensor = _to_full_tensor(elem, device_num, global_rank, scaling_sens=None)
  38. expect0 = ([[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [1,2,3], [4,5,6], [0, 0, 0], [0, 0, 0]])
  39. expect_tensor0 = Tensor(expect0, dtype=ms.float32)
  40. expect1 = ([[0], [0], [0], [0], [1], [4], [0], [0]])
  41. expect_tensor1 = Tensor(expect1, dtype=ms.int32)
  42. expect_tensors = (expect_tensor0, expect_tensor1)
  43. assert (full_tensor == expect_tensors)
  44. def test_to_full_tensor_sens_2():
  45. elem0 = Tensor([[1,2,3], [4,5,6]], dtype=ms.float32)
  46. elem1 = Tensor([[1], [4]], dtype=ms.int32)
  47. elem = (elem0, elem1,)
  48. device_num = 4
  49. global_rank = 2
  50. full_tensor = _to_full_tensor(elem, device_num, global_rank, scaling_sens=0.1)
  51. expect0 = ([[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [1,2,3], [4,5,6], [0, 0, 0], [0, 0, 0]])
  52. expect_tensor0 = Tensor(expect0, dtype=ms.float32)
  53. expect1 = ([[0], [0], [0], [0], [1], [4], [0], [0]])
  54. expect_tensor1 = Tensor(expect1, dtype=ms.int32)
  55. expect_tensor_sens = Tensor(0.1, dtype=ms.float32)
  56. expect_tensors = (expect_tensor0, expect_tensor1, expect_tensor_sens)
  57. assert (full_tensor == expect_tensors)