|
|
|
@@ -44,7 +44,9 @@ if __name__ == "__main__": |
|
|
|
device_id = 0 |
|
|
|
print_rank0("Training {} on HETU".format(args.model)) |
|
|
|
if args.comm_mode in ('AllReduce', 'Hybrid'): |
|
|
|
comm, device_id = ht.mpi_nccl_init() |
|
|
|
comm = ht.wrapped_mpi_nccl_init() |
|
|
|
device_id = comm.dev_id |
|
|
|
rank = comm.rank |
|
|
|
executor_ctx = ht.gpu(device_id % 8) if args.gpu >= 0 else ht.cpu(0) |
|
|
|
else: |
|
|
|
if args.gpu == -1: |
|
|
|
@@ -197,6 +199,4 @@ if __name__ == "__main__": |
|
|
|
print_rank0("Validation accuracy = %f" % accuracy) |
|
|
|
print_rank0("*"*50) |
|
|
|
print_rank0("Running time of total %d epoch = %fs" % |
|
|
|
(args.num_epochs, running_time)) |
|
|
|
if args.comm_mode in ('AllReduce', 'Hybrid'): |
|
|
|
ht.mpi_nccl_finish(comm) |
|
|
|
(args.num_epochs, running_time)) |