|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194 |
- import tensorflow as tf
- import tf_models
- import hetu as ht
- import numpy as np
- import argparse
- from time import time
- import logging
# Configure root logging once at import time; every message carries a
# timestamp, logger name, and level so multi-process logs stay readable.
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
-
-
def print_rank0(msg):
    """Log *msg* at INFO level (name kept for parity with distributed scripts)."""
    logger.info("%s", msg)
-
-
if __name__ == "__main__":
    # ---- command-line arguments -------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, required=True,
                        help='model to be tested')
    parser.add_argument('--dataset', type=str, required=True,
                        help='dataset to be trained on')
    parser.add_argument('--batch-size', type=int,
                        default=128, help='batch size')
    parser.add_argument('--learning-rate', type=float,
                        default=0.1, help='learning rate')
    parser.add_argument('--opt', type=str, default='sgd',
                        help='optimizer to be used, default sgd; sgd / momentum / adagrad / adam')
    parser.add_argument('--num-epochs', type=int,
                        default=20, help='epoch number')
    parser.add_argument('--gpu', type=int, default=0,
                        help='gpu to be used, -1 means cpu')
    parser.add_argument('--validate', action='store_true',
                        help='whether to use validation')
    parser.add_argument('--timing', action='store_true',
                        help='whether to time the training phase')
    args = parser.parse_args()

    # ---- device selection: -1 means CPU, otherwise the GPU index ----------
    if args.gpu == -1:
        device = '/cpu:0'
        print_rank0('Use CPU.')
    else:
        device = '/gpu:%d' % args.gpu
        print_rank0('Use GPU %d.' % args.gpu)

    # ---- model lookup ------------------------------------------------------
    print_rank0("Training {} on TensorFlow".format(args.model))
    assert args.model in ['tf_cnn_3_layers', 'tf_lenet', 'tf_logreg', 'tf_lstm', 'tf_mlp', 'tf_resnet18', 'tf_resnet34', 'tf_rnn', 'tf_vgg16', 'tf_vgg19'], \
        'Model not supported now.'
    # getattr replaces the original eval('tf_models.' + args.model):
    # same result for every whitelisted name, without executing a string.
    model = getattr(tf_models, args.model)

    assert args.dataset in ['MNIST', 'CIFAR10', 'CIFAR100', 'ImageNet']
    dataset = args.dataset

    # ---- optimizer ---------------------------------------------------------
    assert args.opt in ['sgd', 'momentum', 'nesterov',
                        'adagrad', 'adam'], 'Optimizer not supported!'
    if args.opt == 'sgd':
        print_rank0('Use SGD Optimizer.')
        opt = tf.train.GradientDescentOptimizer(
            learning_rate=args.learning_rate)
    elif args.opt == 'momentum':
        print_rank0('Use Momentum Optimizer.')
        opt = tf.train.MomentumOptimizer(
            learning_rate=args.learning_rate, momentum=0.9)
    elif args.opt == 'nesterov':
        print_rank0('Use Nesterov Momentum Optimizer.')
        opt = tf.train.MomentumOptimizer(
            learning_rate=args.learning_rate, momentum=0.9, use_nesterov=True)
    elif args.opt == 'adagrad':
        print_rank0('Use AdaGrad Optimizer.')
        opt = tf.train.AdagradOptimizer(learning_rate=args.learning_rate)
    else:
        print_rank0('Use Adam Optimizer.')
        opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)

    # ---- model definition --------------------------------------------------
    # Placeholder shapes depend on the dataset; tf_mlp takes flattened input.
    print_rank0('Building model...')
    with tf.device(device):
        if dataset == 'MNIST':
            x = tf.placeholder(dtype=tf.float32, shape=(None, 784), name='x')
            y_ = tf.placeholder(dtype=tf.float32, shape=(None, 10), name='y_')
            loss, y = model(x, y_)
        elif dataset == 'CIFAR10':
            if args.model == "tf_mlp":
                x = tf.placeholder(
                    dtype=tf.float32, shape=(None, 3072), name='x')
                y_ = tf.placeholder(
                    dtype=tf.float32, shape=(None, 10), name='y_')
            else:
                x = tf.placeholder(dtype=tf.float32, shape=(
                    None, 32, 32, 3), name='x')
                y_ = tf.placeholder(
                    dtype=tf.float32, shape=(None, 10), name='y_')
            loss, y = model(x, y_, 10)
        elif dataset == 'CIFAR100':
            x = tf.placeholder(dtype=tf.float32, shape=(
                None, 32, 32, 3), name='x')
            y_ = tf.placeholder(dtype=tf.float32, shape=(None, 100), name='y_')
            loss, y = model(x, y_, 100)

        train_op = opt.minimize(loss)

    # ---- data loading ------------------------------------------------------
    print_rank0('Loading %s data...' % dataset)
    if dataset == 'MNIST':
        datasets = ht.data.mnist()
        train_set_x, train_set_y = datasets[0]
        valid_set_x, valid_set_y = datasets[1]
        test_set_x, test_set_y = datasets[2]
        n_train_batches = train_set_x.shape[0] // args.batch_size
        n_valid_batches = valid_set_x.shape[0] // args.batch_size
        # train_set_x: (50000, 784), train_set_y: (50000,)
        # valid_set_x: (10000, 784), valid_set_y: (10000,)
    elif dataset == 'CIFAR10':
        train_set_x, train_set_y, valid_set_x, valid_set_y = ht.data.tf_normalize_cifar(
            num_class=10)
        n_train_batches = train_set_x.shape[0] // args.batch_size
        n_valid_batches = valid_set_x.shape[0] // args.batch_size
        if args.model == "tf_mlp":
            # MLP consumes flat vectors, so collapse the HWC dimensions.
            train_set_x = train_set_x.reshape(train_set_x.shape[0], -1)
            valid_set_x = valid_set_x.reshape(valid_set_x.shape[0], -1)
        # train_set_x: (50000, 32, 32, 3), train_set_y: (50000,)
        # valid_set_x: (10000, 32, 32, 3), valid_set_y: (10000,)
    elif dataset == 'CIFAR100':
        train_set_x, train_set_y, valid_set_x, valid_set_y = ht.data.tf_normalize_cifar(
            num_class=100)
        n_train_batches = train_set_x.shape[0] // args.batch_size
        n_valid_batches = valid_set_x.shape[0] // args.batch_size
        # train_set_x: (50000, 32, 32, 3), train_set_y: (50000,)
        # valid_set_x: (10000, 32, 32, 3), valid_set_y: (10000,)
    else:
        # 'ImageNet' passes the assert above but has no loader here yet.
        raise NotImplementedError

    # ---- training loop -----------------------------------------------------
    # num_epochs + 1 iterations: epoch 0 acts as a warm-up and is excluded
    # from the accumulated running time below.
    print_rank0("Start training loop...")
    running_time = 0
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(args.num_epochs + 1):
            print_rank0("Epoch %d" % i)
            loss_all = 0
            batch_num = 0
            if args.timing:
                start = time()
            correct_predictions = []
            for minibatch_index in range(n_train_batches):
                minibatch_start = minibatch_index * args.batch_size
                minibatch_end = (minibatch_index + 1) * args.batch_size
                x_val = train_set_x[minibatch_start:minibatch_end]
                y_val = train_set_y[minibatch_start:minibatch_end]
                loss_val, predict_y, _ = sess.run([loss, y, train_op],
                                                  feed_dict={x: x_val, y_: y_val})
                # Per-sample 0/1 correctness against one-hot labels.
                correct_prediction = np.equal(
                    np.argmax(y_val, 1),
                    np.argmax(predict_y, 1)).astype(np.float32)
                correct_predictions.extend(correct_prediction)
                batch_num += 1
                loss_all += loss_val
            loss_all /= batch_num
            accuracy = np.mean(correct_predictions)
            print_rank0("Train loss = %f" % loss_all)
            print_rank0("Train accuracy = %f" % accuracy)

            if args.timing:
                end = time()
                print_rank0("Running time of current epoch = %fs" %
                            (end - start))
                if i != 0:
                    # Skip the warm-up epoch when totalling the runtime.
                    running_time += (end - start)

            if args.validate:
                val_loss_all = 0
                batch_num = 0
                correct_predictions = []
                for minibatch_index in range(n_valid_batches):
                    minibatch_start = minibatch_index * args.batch_size
                    minibatch_end = (minibatch_index + 1) * args.batch_size
                    valid_x_val = valid_set_x[minibatch_start:minibatch_end]
                    valid_y_val = valid_set_y[minibatch_start:minibatch_end]
                    loss_val, valid_y_predicted = sess.run([loss, y],
                                                           feed_dict={x: valid_x_val, y_: valid_y_val})
                    correct_prediction = np.equal(
                        np.argmax(valid_y_val, 1),
                        np.argmax(valid_y_predicted, 1)).astype(np.float32)
                    correct_predictions.extend(correct_prediction)
                    # BUG FIX: the original accumulated loss_all (the epoch's
                    # mean TRAIN loss) here, so "Validation loss" just echoed
                    # the train loss. Accumulate the validation batch loss.
                    val_loss_all += loss_val
                    batch_num += 1
                val_loss_all /= batch_num
                accuracy = np.mean(correct_predictions)
                print_rank0("Validation loss = %f" % val_loss_all)
                print_rank0("Validation accuracy = %f" % accuracy)
        print_rank0("*"*50)
        print_rank0("Running time of total %d epoch = %fs" %
                    (args.num_epochs, running_time))
|