|
- import hetu as ht
- import models
- import os
- import numpy as np
- import argparse
- import json
- import logging
- from time import time
# Configure root logging once at import time; print_rank0 routes all
# rank-0 training output through this module-level logger.
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
-
-
def print_rank0(msg):
    """Log *msg* only on device 0 so multi-worker runs print once.

    Reads the module-global ``device_id`` set in the ``__main__`` block.
    """
    if device_id != 0:
        return
    logger.info(msg)
-
-
if __name__ == "__main__":
    # ----------------------------------------------------------- CLI args --
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, required=True,
                        help='model to be tested')
    parser.add_argument('--dataset', type=str, required=True,
                        help='dataset to be trained on')
    parser.add_argument('--batch-size', type=int,
                        default=128, help='batch size')
    parser.add_argument('--learning-rate', type=float,
                        default=0.1, help='learning rate')
    parser.add_argument('--opt', type=str, default='sgd',
                        help='optimizer to be used, default sgd; sgd / momentum / adagrad / adam')
    parser.add_argument('--num-epochs', type=int,
                        default=10, help='epoch number')
    parser.add_argument('--gpu', type=int, default=0,
                        help='gpu to be used, -1 means cpu')
    parser.add_argument('--validate', action='store_true',
                        help='whether to use validation')
    parser.add_argument('--timing', action='store_true',
                        help='whether to time the training phase')
    parser.add_argument('--comm-mode', default=None, help='communication mode')
    args = parser.parse_args()

    # device_id is read by print_rank0(). Assignment here is already at
    # module scope, so the original `global device_id` statement was a
    # no-op and has been dropped.
    device_id = 0
    print_rank0("Training {} on HETU".format(args.model))
    if args.comm_mode in ('AllReduce', 'Hybrid'):
        # Collective modes: each process picks its device from the NCCL comm.
        comm = ht.wrapped_mpi_nccl_init()
        device_id = comm.dev_id
        rank = comm.rank
        executor_ctx = ht.gpu(device_id % 8) if args.gpu >= 0 else ht.cpu(0)
    else:
        if args.gpu == -1:
            executor_ctx = ht.cpu(0)
            print_rank0('Use CPU.')
        else:
            executor_ctx = ht.gpu(args.gpu)
            print_rank0('Use GPU %d.' % args.gpu)
    if args.comm_mode in ('PS', 'Hybrid'):
        # Parameter-server modes: export per-worker settings as environment
        # variables. `with` closes the handle (the original leaked it).
        settings_path = os.path.join(os.path.abspath(
            os.path.dirname(__file__)), 'worker_conf%d.json' % args.gpu)
        with open(settings_path) as settings_file:
            settings = json.load(settings_file)
        for key in settings:
            # os.environ values must be str; str() is a no-op for str inputs,
            # so this covers both branches of the original type check.
            os.environ[key] = str(settings[key])

    assert args.model in ['alexnet', 'cnn_3_layers', 'lenet', 'logreg', 'lstm', 'mlp', 'resnet18', 'resnet34', 'rnn', 'vgg16', 'vgg19'], \
        'Model not supported!'
    # getattr() replaces eval('models.' + args.model): same attribute lookup
    # without the arbitrary-code-execution surface of eval on a CLI string.
    model = getattr(models, args.model)

    assert args.dataset in ['MNIST', 'CIFAR10', 'CIFAR100', 'ImageNet']
    dataset = args.dataset
    assert args.opt in ['sgd', 'momentum', 'nesterov',
                        'adagrad', 'adam'], 'Optimizer not supported!'

    # ----------------------------------------------------------- optimizer --
    if args.opt == 'sgd':
        print_rank0('Use SGD Optimizer.')
        opt = ht.optim.SGDOptimizer(learning_rate=args.learning_rate)
    elif args.opt == 'momentum':
        print_rank0('Use Momentum Optimizer.')
        opt = ht.optim.MomentumOptimizer(learning_rate=args.learning_rate)
    elif args.opt == 'nesterov':
        print_rank0('Use Nesterov Momentum Optimizer.')
        opt = ht.optim.MomentumOptimizer(
            learning_rate=args.learning_rate, nesterov=True)
    elif args.opt == 'adagrad':
        print_rank0('Use AdaGrad Optimizer.')
        opt = ht.optim.AdaGradOptimizer(
            learning_rate=args.learning_rate, initial_accumulator_value=0.1)
    else:
        print_rank0('Use Adam Optimizer.')
        opt = ht.optim.AdamOptimizer(learning_rate=args.learning_rate)

    # -------------------------------------------------------- data loading --
    print_rank0('Loading %s data...' % dataset)
    if dataset == 'MNIST':
        datasets = ht.data.mnist()
        train_set_x, train_set_y = datasets[0]
        valid_set_x, valid_set_y = datasets[1]
        test_set_x, test_set_y = datasets[2]
        # train_set_x: (50000, 784), train_set_y: (50000, 10)
        # valid_set_x: (10000, 784), valid_set_y: (10000, 10)
    elif dataset == 'CIFAR10':
        train_set_x, train_set_y, valid_set_x, valid_set_y = ht.data.normalize_cifar(
            num_class=10)
        if args.model == "mlp":
            # MLP expects flat feature vectors, not (C, H, W) images.
            train_set_x = train_set_x.reshape(train_set_x.shape[0], -1)
            valid_set_x = valid_set_x.reshape(valid_set_x.shape[0], -1)
        # train_set_x: (50000, 3, 32, 32), train_set_y: (50000, 10)
        # valid_set_x: (10000, 3, 32, 32), valid_set_y: (10000, 10)
    elif dataset == 'CIFAR100':
        train_set_x, train_set_y, valid_set_x, valid_set_y = ht.data.normalize_cifar(
            num_class=100)
        # train_set_x: (50000, 3, 32, 32), train_set_y: (50000, 100)
        # valid_set_x: (10000, 3, 32, 32), valid_set_y: (10000, 100)
    else:
        # 'ImageNet' passes the assert above but has no loader wired up here.
        raise NotImplementedError

    # ---------------------------------------------------- model definition --
    print_rank0('Building model {}'.format(args.model))
    x = ht.dataloader_op([
        ht.Dataloader(train_set_x, args.batch_size, 'train'),
        ht.Dataloader(valid_set_x, args.batch_size, 'validate'),
    ])
    y_ = ht.dataloader_op([
        ht.Dataloader(train_set_y, args.batch_size, 'train'),
        ht.Dataloader(valid_set_y, args.batch_size, 'validate'),
    ])
    if args.model in ['resnet18', 'resnet34', 'vgg16', 'vgg19'] and args.dataset == 'CIFAR100':
        # These models take an explicit class count for the 100-class head.
        loss, y = model(x, y_, 100)
    else:
        loss, y = model(x, y_)

    train_op = opt.minimize(loss)

    eval_nodes = {'train': [loss, y, y_, train_op], 'validate': [loss, y, y_]}
    executor = ht.Executor(eval_nodes, ctx=executor_ctx,
                           comm_mode=args.comm_mode)
    n_train_batches = executor.get_batch_num('train')
    n_valid_batches = executor.get_batch_num('validate')

    # ------------------------------------------------------------ training --
    print_rank0("Start training loop...")
    running_time = 0
    # num_epochs + 1 iterations: epoch 0 acts as a warm-up and is excluded
    # from the accumulated running time below.
    for i in range(args.num_epochs + 1):
        print_rank0("Epoch %d" % i)
        loss_all = 0
        batch_num = 0
        if args.timing:
            start = time()
        correct_predictions = []
        for minibatch_index in range(n_train_batches):
            loss_val, predict_y, y_val, _ = executor.run(
                'train', eval_node_list=[loss, y, y_, train_op])
            predict_y = predict_y.asnumpy()
            y_val = y_val.asnumpy()
            loss_all += loss_val.asnumpy()
            batch_num += 1
            # Per-sample hit/miss of the argmax class for this minibatch.
            correct_prediction = np.equal(
                np.argmax(y_val, 1),
                np.argmax(predict_y, 1)).astype(np.float32)
            correct_predictions.extend(correct_prediction)

        loss_all /= batch_num
        accuracy = np.mean(correct_predictions)
        print_rank0("Train loss = %f" % loss_all)
        print_rank0("Train accuracy = %f" % accuracy)

        if args.timing:
            end = time()
            during_time = end - start
            print_rank0("Running time of current epoch = %fs" % (during_time))
            if i != 0:
                # Skip the warm-up epoch when accumulating total time.
                running_time += during_time
        if args.validate:
            val_loss_all = 0
            batch_num = 0
            correct_predictions = []
            for minibatch_index in range(n_valid_batches):
                loss_val, valid_y_predicted, y_val = executor.run(
                    'validate', eval_node_list=[loss, y, y_], convert_to_numpy_ret_vals=True)
                val_loss_all += loss_val
                batch_num += 1
                correct_prediction = np.equal(
                    np.argmax(y_val, 1),
                    np.argmax(valid_y_predicted, 1)).astype(np.float32)
                correct_predictions.extend(correct_prediction)

            val_loss_all /= batch_num
            accuracy = np.mean(correct_predictions)
            print_rank0("Validation loss = %f" % val_loss_all)
            print_rank0("Validation accuracy = %f" % accuracy)
    print_rank0("*"*50)
    print_rank0("Running time of total %d epoch = %fs" %
                (args.num_epochs, running_time))
|