
torch_main.py 8.3 kB

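# Trains an MLP / ResNet-18/34 / VGG-16/19 / RNN model on MNIST, CIFAR10, or
# CIFAR100 (data loaded through hetu's ht.data helpers), on CPU, a single GPU,
# or multiple GPUs via DistributedDataParallel over NCCL.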
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from pytorch_models import *
import hetu as ht
import numpy as np
import argparse
from time import time
import os
import logging

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def print_rank0(msg):
    # Log from only one process per 8-GPU node to avoid duplicated output.
    if local_rank % 8 == 0:
        logger.info(msg)


def train(epoch=-1, net=None, data=None, label=None, batch_size=-1, criterion=None, optimizer=None):
    print_rank0('Epoch: %d' % epoch)
    # Integer division drops the last partial minibatch.
    n_train_batches = data.shape[0] // batch_size
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for minibatch_index in range(n_train_batches):
        minibatch_start = minibatch_index * batch_size
        minibatch_end = (minibatch_index + 1) * batch_size
        inputs = torch.Tensor(data[minibatch_start:minibatch_end])
        targets = torch.Tensor(label[minibatch_start:minibatch_end]).long()
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    print_rank0("Train loss = %f" % (train_loss / (minibatch_index + 1)))
    print_rank0("Train accuracy = %f" % (100. * correct / total))


def test(epoch=-1, net=None, data=None, label=None, batch_size=-1, criterion=None):
    net.eval()
    n_test_batches = data.shape[0] // batch_size
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for minibatch_index in range(n_test_batches):
            minibatch_start = minibatch_index * batch_size
            minibatch_end = (minibatch_index + 1) * batch_size
            inputs = torch.Tensor(data[minibatch_start:minibatch_end])
            targets = torch.Tensor(label[minibatch_start:minibatch_end]).long()
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    print_rank0("Validation loss = %f" % (test_loss / (minibatch_index + 1)))
    print_rank0("Validation accuracy = %f" % (100. * correct / total))


if __name__ == "__main__":
    # argument parser
    global local_rank
    local_rank = 0
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, required=True,
                        help='model to be tested')
    parser.add_argument('--dataset', type=str, required=True,
                        help='dataset to be trained on')
    parser.add_argument('--batch-size', type=int,
                        default=128, help='batch size')
    parser.add_argument('--learning-rate', type=float,
                        default=0.1, help='learning rate')
    parser.add_argument('--opt', type=str, default='sgd',
                        help='optimizer to be used, default sgd; sgd / momentum / nesterov / adagrad / adam')
    parser.add_argument('--num-epochs', type=int,
                        default=20, help='epoch number')
    parser.add_argument('--gpu', type=int, default=0,
                        help='gpu to be used, -1 means cpu')
    parser.add_argument('--validate', action='store_true',
                        help='whether to use validation')
    parser.add_argument('--timing', action='store_true',
                        help='whether to time the training phase')
    parser.add_argument('--distributed', action='store_true',
                        help='whether to use distributed training')
    parser.add_argument('--local_rank', type=int, default=-1)
    args = parser.parse_args()

    if args.distributed:
        # Build the TCP rendezvous address for NCCL from environment variables.
        init_method = 'tcp://'
        master_ip = os.getenv('MASTER_ADDR', 'localhost')
        master_port = os.getenv('MASTER_PORT', '6000')
        init_method += master_ip + ':' + master_port
        rank = int(os.getenv('RANK', '0'))
        world_size = int(os.getenv('WORLD_SIZE', '1'))
        print("***" * 50)
        print(init_method)
        torch.distributed.init_process_group(backend="nccl",
                                             world_size=world_size,
                                             rank=rank,
                                             init_method=init_method)
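    # torch.distributed.launch (and the newer torchrun) export MASTER_ADDR,
    # MASTER_PORT, RANK, and WORLD_SIZE for each worker process, which is what
    # the rendezvous code above reads; any launcher that sets these variables
    # works equally well.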
    if args.gpu == -1:
        device = 'cpu'
    else:
        if args.distributed:
            # Map this process's global rank onto a local GPU of the node.
            local_rank = rank % torch.cuda.device_count()
            torch.cuda.set_device(local_rank)
            device = torch.device('cuda:%d' % local_rank)
            logger.info('Use GPU %d.' % local_rank)
        else:
            device = torch.device('cuda:%d' % args.gpu)
            torch.cuda.set_device(args.gpu)
            print_rank0('Use GPU %d.' % args.gpu)

    assert args.model in ['mlp', 'resnet18', 'resnet34',
                          'vgg16', 'vgg19', 'rnn'], 'Model not supported now.'
    assert args.dataset in ['MNIST', 'CIFAR10', 'CIFAR100', 'ImageNet']
    dataset = args.dataset
    if args.model in ['resnet18', 'resnet34', 'vgg16', 'vgg19'] and args.dataset == 'CIFAR100':
        net = eval(args.model)(100)
    elif args.model == 'rnn':
        net = eval(args.model)(28, 10, 128, 28)
    else:
        net = eval(args.model)()
    net.to(device)
    if args.distributed:
        net = torch.nn.parallel.DistributedDataParallel(
            net, device_ids=[local_rank])

    assert args.opt in ['sgd', 'momentum', 'nesterov',
                        'adagrad', 'adam'], 'Optimizer not supported!'
    if args.opt == 'sgd':
        print_rank0('Use SGD Optimizer.')
        opt = optim.SGD(net.parameters(), lr=args.learning_rate)
    elif args.opt == 'momentum':
        print_rank0('Use Momentum Optimizer.')
        opt = optim.SGD(net.parameters(), lr=args.learning_rate, momentum=0.9)
    elif args.opt == 'nesterov':
        print_rank0('Use Nesterov Momentum Optimizer.')
        opt = optim.SGD(net.parameters(), lr=args.learning_rate,
                        momentum=0.9, nesterov=True)
    elif args.opt == 'adagrad':
        print_rank0('Use AdaGrad Optimizer.')
        opt = optim.Adagrad(net.parameters(), lr=args.learning_rate)
    else:
        print_rank0('Use Adam Optimizer.')
        opt = optim.Adam(net.parameters(), lr=args.learning_rate)
    criterion = nn.CrossEntropyLoss()
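    # nn.CrossEntropyLoss applies log-softmax internally, so the models emit
    # raw logits and the targets are integer class indices; this is why labels
    # are converted with .long() above and loaded with onehot=False below.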

    # data loading
    print_rank0('Loading %s data...' % dataset)
    if dataset == 'MNIST':
        datasets = ht.data.mnist(onehot=False)
        train_set_x, train_set_y = datasets[0]
        valid_set_x, valid_set_y = datasets[1]
        test_set_x, test_set_y = datasets[2]
    elif dataset == 'CIFAR10':
        train_set_x, train_set_y, valid_set_x, valid_set_y = ht.data.normalize_cifar(
            num_class=10, onehot=False)
        if args.model == "mlp":
            # The MLP consumes flat feature vectors, so flatten the images.
            train_set_x = train_set_x.reshape(train_set_x.shape[0], -1)
            valid_set_x = valid_set_x.reshape(valid_set_x.shape[0], -1)
    elif dataset == 'CIFAR100':
        train_set_x, train_set_y, valid_set_x, valid_set_y = ht.data.normalize_cifar(
            num_class=100, onehot=False)
    else:
        # 'ImageNet' passes the assertion above, but no loader is provided here.
        raise NotImplementedError('ImageNet loading is not implemented in this script.')

    running_time = 0
    # training; epoch 0 serves as a warm-up and is excluded from the total time
    print_rank0("Start training loop...")
    for i in range(args.num_epochs + 1):
        if args.timing:
            start = time()
        train(epoch=i, net=net, data=train_set_x, label=train_set_y,
              batch_size=args.batch_size, criterion=criterion, optimizer=opt)
        if args.timing:
            end = time()
            print_rank0("Running time of current epoch = %fs" % (end - start))
            if i != 0:
                running_time += (end - start)
        test(epoch=i, net=net, data=valid_set_x, label=valid_set_y,
             batch_size=args.batch_size, criterion=criterion)
    print_rank0("*" * 50)
    print_rank0("Running time of total %d epochs = %fs" %
                (args.num_epochs, running_time))
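
# Example invocations (hypothetical commands inferred from the argparse flags
# above, not taken from project documentation):
#
#   single GPU:
#     python torch_main.py --model resnet18 --dataset CIFAR10 --timing
#
#   one node, 8 GPUs, assuming a launcher such as torch.distributed.launch,
#   which exports MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE and passes
#   --local_rank to each worker:
#     python -m torch.distributed.launch --nproc_per_node=8 torch_main.py \
#         --model vgg16 --dataset CIFAR100 --distributed --timing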