import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from pytorch_models import *
import hetu as ht
import numpy as np
import argparse
from time import time
import os
import logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
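
# Example invocations (the file name below is a placeholder; substitute the actual
# name of this script). Single-GPU training:
#     python pytorch_main.py --model resnet18 --dataset CIFAR10 --gpu 0 --validate --timing
# Multi-GPU training with the legacy PyTorch launcher, which exports RANK,
# WORLD_SIZE, MASTER_ADDR and MASTER_PORT and passes --local_rank to each process:
#     python -m torch.distributed.launch --nproc_per_node=8 pytorch_main.py \
#         --model resnet18 --dataset CIFAR10 --distributed --validate --timing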


def print_rank0(msg):
    # Only log from the first process on each node (presumably up to 8 GPUs per node).
    if local_rank % 8 == 0:
        logger.info(msg)


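# Standard PyTorch training loop over one epoch: minibatches are sliced out of
# in-memory NumPy arrays, wrapped as tensors, moved to the target device, and run
# through the usual zero_grad -> forward -> loss -> backward -> step sequence.
# Epoch-average loss and accuracy are reported through print_rank0.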
def train(epoch=-1, net=None, data=None, label=None, batch_size=-1, criterion=None, optimizer=None):
    print_rank0('Epoch: %d' % epoch)
    n_train_batches = data.shape[0] // batch_size

    net.train()

    train_loss = 0
    correct = 0
    total = 0

    for minibatch_index in range(n_train_batches):
        minibatch_start = minibatch_index * batch_size
        minibatch_end = (minibatch_index + 1) * batch_size
        inputs = torch.Tensor(data[minibatch_start:minibatch_end])
        targets = torch.Tensor(label[minibatch_start:minibatch_end]).long()

        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    print_rank0("Train loss = %f" % (train_loss / n_train_batches))
    print_rank0("Train accuracy = %f" % (100. * correct / total))


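# Evaluation pass: the model is switched to eval mode so dropout/batch-norm layers
# use their inference behaviour, and gradients are disabled with torch.no_grad();
# otherwise the loop mirrors train() without the optimizer update.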
def test(epoch=-1, net=None, data=None, label=None, batch_size=-1, criterion=None):
    net.eval()
    n_test_batches = data.shape[0] // batch_size
    test_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for minibatch_index in range(n_test_batches):
            minibatch_start = minibatch_index * batch_size
            minibatch_end = (minibatch_index + 1) * batch_size
            inputs = torch.Tensor(data[minibatch_start:minibatch_end])
            targets = torch.Tensor(label[minibatch_start:minibatch_end]).long()

            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    print_rank0("Validation loss = %f" % (test_loss / n_test_batches))
    print_rank0("Validation accuracy = %f" % (100. * correct / total))


if __name__ == "__main__":
    # argument parser
    global local_rank
    local_rank = 0
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, required=True,
                        help='model to be tested')
    parser.add_argument('--dataset', type=str, required=True,
                        help='dataset to be trained on')
    parser.add_argument('--batch-size', type=int,
                        default=128, help='batch size')
    parser.add_argument('--learning-rate', type=float,
                        default=0.1, help='learning rate')
    parser.add_argument('--opt', type=str, default='sgd',
                        help='optimizer to be used, default sgd; sgd / momentum / nesterov / adagrad / adam')
    parser.add_argument('--num-epochs', type=int,
                        default=20, help='epoch number')
    parser.add_argument('--gpu', type=int, default=0,
                        help='gpu to be used, -1 means cpu')
    parser.add_argument('--validate', action='store_true',
                        help='whether to use validation')
    parser.add_argument('--timing', action='store_true',
                        help='whether to time the training phase')
    parser.add_argument('--distributed', action='store_true',
                        help='whether to use distributed training')
    parser.add_argument('--local_rank', type=int, default=-1)
    args = parser.parse_args()

    if args.distributed:
        init_method = 'tcp://'
        master_ip = os.getenv('MASTER_ADDR', 'localhost')
        master_port = os.getenv('MASTER_PORT', '6000')
        init_method += master_ip + ':' + master_port
        rank = int(os.getenv('RANK', '0'))
        world_size = int(os.getenv('WORLD_SIZE', '1'))
        print("***" * 50)
        print(init_method)
        torch.distributed.init_process_group(backend="nccl",
                                             world_size=world_size,
                                             rank=rank,
                                             init_method=init_method)
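        # The NCCL process group rendezvous is built from MASTER_ADDR / MASTER_PORT;
        # RANK and WORLD_SIZE must be set in every process's environment (the
        # torch.distributed launcher takes care of this), otherwise the defaults
        # above are used.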

    if args.gpu == -1:
        device = torch.device('cpu')
    else:
        if args.distributed:
            local_rank = rank % torch.cuda.device_count()
            torch.cuda.set_device(local_rank)
            device = torch.device('cuda:%d' % local_rank)
            logger.info('Use GPU %d.' % local_rank)
        else:
            device = torch.device('cuda:%d' % args.gpu)
            torch.cuda.set_device(args.gpu)
            print_rank0('Use GPU %d.' % args.gpu)

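    # Model constructors are resolved by name via eval() from the classes imported
    # with `from pytorch_models import *`; CIFAR-100 variants are built with 100
    # output classes, and the RNN with its MNIST-style input/hidden sizes.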
    assert args.model in ['mlp', 'resnet18', 'resnet34',
                          'vgg16', 'vgg19', 'rnn'], 'Model not supported.'

    assert args.dataset in ['MNIST', 'CIFAR10', 'CIFAR100', 'ImageNet']
    dataset = args.dataset

    if args.model in ['resnet18', 'resnet34', 'vgg16', 'vgg19'] and args.dataset == 'CIFAR100':
        net = eval(args.model)(100)
    elif args.model == 'rnn':
        net = eval(args.model)(28, 10, 128, 28)
    else:
        net = eval(args.model)()

    net.to(device)
    if args.distributed:
        net = torch.nn.parallel.DistributedDataParallel(
            net, device_ids=[local_rank])
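    # DistributedDataParallel averages gradients across processes during backward().
    # Note that the data pipeline below is not sharded by rank (there is no
    # DistributedSampler equivalent here), so every process iterates over the same
    # minibatches.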

    assert args.opt in ['sgd', 'momentum', 'nesterov',
                        'adagrad', 'adam'], 'Optimizer not supported!'
    if args.opt == 'sgd':
        print_rank0('Use SGD Optimizer.')
        opt = optim.SGD(net.parameters(), lr=args.learning_rate)
    elif args.opt == 'momentum':
        print_rank0('Use Momentum Optimizer.')
        opt = optim.SGD(net.parameters(), lr=args.learning_rate, momentum=0.9)
    elif args.opt == 'nesterov':
        print_rank0('Use Nesterov Momentum Optimizer.')
        opt = optim.SGD(net.parameters(), lr=args.learning_rate,
                        momentum=0.9, nesterov=True)
    elif args.opt == 'adagrad':
        print_rank0('Use AdaGrad Optimizer.')
        opt = optim.Adagrad(net.parameters(), lr=args.learning_rate)
    else:
        print_rank0('Use Adam Optimizer.')
        opt = optim.Adam(net.parameters(), lr=args.learning_rate)

    criterion = nn.CrossEntropyLoss()
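    # CrossEntropyLoss applies log-softmax internally, so the models are expected to
    # return raw logits, and targets are integer class indices (hence .long() in the
    # train/test loops).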

    # data loading
    print_rank0('Loading %s data...' % dataset)
    if dataset == 'MNIST':
        datasets = ht.data.mnist(onehot=False)
        train_set_x, train_set_y = datasets[0]
        valid_set_x, valid_set_y = datasets[1]
        test_set_x, test_set_y = datasets[2]
    elif dataset == 'CIFAR10':
        train_set_x, train_set_y, valid_set_x, valid_set_y = ht.data.normalize_cifar(
            num_class=10, onehot=False)
        if args.model == "mlp":
            train_set_x = train_set_x.reshape(train_set_x.shape[0], -1)
            valid_set_x = valid_set_x.reshape(valid_set_x.shape[0], -1)
    elif dataset == 'CIFAR100':
        train_set_x, train_set_y, valid_set_x, valid_set_y = ht.data.normalize_cifar(
            num_class=100, onehot=False)
    else:
        raise NotImplementedError('%s data loading is not implemented in this script.' % dataset)

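    # The ht.data helpers are assumed to return NumPy arrays; this matches how the
    # train/test loops slice them and wrap each minibatch in torch tensors.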
    running_time = 0
    # training
    print_rank0("Start training loop...")
    for i in range(args.num_epochs + 1):
        if args.timing:
            start = time()
        train(epoch=i, net=net, data=train_set_x, label=train_set_y,
              batch_size=args.batch_size, criterion=criterion, optimizer=opt)
        if args.timing:
            end = time()
            print_rank0("Running time of current epoch = %fs" % (end - start))
            if i != 0:
                running_time += (end - start)
        if args.validate:
            test(epoch=i, net=net, data=valid_set_x, label=valid_set_y,
                 batch_size=args.batch_size, criterion=criterion)

    print_rank0("*" * 50)
    if args.timing:
        print_rank0("Total running time of %d epochs = %fs" %
                    (args.num_epochs, running_time))