You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_nccl_bandwidth.py 1.6 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. from hetu.communicator.mpi_nccl_comm import ncclDataType_t, ncclRedOp_t, mpi_communicator
  2. from hetu import ndarray
  3. import numpy as np
  4. import time
  5. def test_allreduce(comm=None):
  6. shape = (24, 24)
  7. size = 4
  8. for val in shape:
  9. size *= val
  10. input_arr = np.ones(shape)*comm.localRank.value
  11. input_arr = ndarray.array(input_arr, ctx=ndarray.gpu(comm.localRank.value))
  12. start = time.time()
  13. comm.dlarrayNcclAllReduce(input_arr, input_arr,
  14. ncclDataType_t.ncclFloat32, ncclRedOp_t.ncclSum)
  15. comm.stream.sync()
  16. end = time.time()
  17. secs = end - start
  18. return size, secs
  19. def test_p2p(comm=None, src=0, target=1):
  20. shape = (1000, 30, 224, 224)
  21. size = 4
  22. for val in shape:
  23. size *= val
  24. print("MyRank: ", comm.myRank.value)
  25. arr = np.ones(shape)*comm.localRank.value
  26. arr = ndarray.array(arr, ctx=ndarray.gpu(comm.localRank.value))
  27. start = time.time()
  28. if comm.myRank.value == 0:
  29. comm.dlarraySend(arr, ncclDataType_t.ncclFloat32, 1)
  30. else:
  31. comm.dlarrayRecv(arr, ncclDataType_t.ncclFloat32, 0)
  32. comm.stream.sync()
  33. end = time.time()
  34. secs = end - start
  35. # size: /Bytes
  36. # dur_time: /s
  37. return size, secs
  38. # mpirun --allow-run-as-root --tag-output -np 2 python test_nccl_bandwidth.py
  39. if __name__ == "__main__":
  40. comm = mpi_communicator()
  41. comm = comm.ncclInit()
  42. size, secs = test_p2p(comm)
  43. print("band width: %.2f MB/s" % (size/(2**20)/secs))
  44. size, secs = test_allreduce(comm)
  45. print("band width: %.2f MB/s" % (size/(2**20)/secs))

分布式深度学习系统