
tf_main.py
import tensorflow as tf
import tf_models
import hetu as ht
import numpy as np
import argparse
from time import time
import logging

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def print_rank0(msg):
    # logging helper; presumably named after the rank-0-only printing
    # convention of the distributed variants of this script
    logger.info(msg)


if __name__ == "__main__":
    # argument parser
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, required=True,
                        help='model to be tested')
    parser.add_argument('--dataset', type=str, required=True,
                        help='dataset to be trained on')
    parser.add_argument('--batch-size', type=int,
                        default=128, help='batch size')
    parser.add_argument('--learning-rate', type=float,
                        default=0.1, help='learning rate')
    parser.add_argument('--opt', type=str, default='sgd',
                        help='optimizer to be used, default sgd; sgd / momentum / nesterov / adagrad / adam')
    parser.add_argument('--num-epochs', type=int,
                        default=20, help='epoch number')
    parser.add_argument('--gpu', type=int, default=0,
                        help='gpu to be used, -1 means cpu')
    parser.add_argument('--validate', action='store_true',
                        help='whether to use validation')
    parser.add_argument('--timing', action='store_true',
                        help='whether to time the training phase')
    args = parser.parse_args()

    if args.gpu == -1:
        device = '/cpu:0'
        print_rank0('Use CPU.')
    else:
        device = '/gpu:%d' % args.gpu
        print_rank0('Use GPU %d.' % args.gpu)

    print_rank0("Training {} on TensorFlow".format(args.model))
    assert args.model in ['tf_cnn_3_layers', 'tf_lenet', 'tf_logreg', 'tf_lstm', 'tf_mlp',
                          'tf_resnet18', 'tf_resnet34', 'tf_rnn', 'tf_vgg16', 'tf_vgg19'], \
        'Model not supported now.'
    # look up the model-builder function by name (safer than eval)
    model = getattr(tf_models, args.model)

    # note: 'ImageNet' passes this check but is not implemented below
    assert args.dataset in ['MNIST', 'CIFAR10', 'CIFAR100', 'ImageNet']
    dataset = args.dataset

    assert args.opt in ['sgd', 'momentum', 'nesterov',
                        'adagrad', 'adam'], 'Optimizer not supported!'
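    # Map --opt onto the TF1 tf.train optimizers. 'momentum' and 'nesterov'
    # share tf.train.MomentumOptimizer (momentum fixed at 0.9); only the
    # learning rate comes from the command line.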
    if args.opt == 'sgd':
        print_rank0('Use SGD Optimizer.')
        opt = tf.train.GradientDescentOptimizer(
            learning_rate=args.learning_rate)
    elif args.opt == 'momentum':
        print_rank0('Use Momentum Optimizer.')
        opt = tf.train.MomentumOptimizer(
            learning_rate=args.learning_rate, momentum=0.9)
    elif args.opt == 'nesterov':
        print_rank0('Use Nesterov Momentum Optimizer.')
        opt = tf.train.MomentumOptimizer(
            learning_rate=args.learning_rate, momentum=0.9, use_nesterov=True)
    elif args.opt == 'adagrad':
        print_rank0('Use AdaGrad Optimizer.')
        opt = tf.train.AdagradOptimizer(learning_rate=args.learning_rate)
    else:
        print_rank0('Use Adam Optimizer.')
        opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)

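    # TF1 graph mode: placeholders declare the network inputs; the actual
    # numpy batches are supplied per step through feed_dict in sess.run below.
    # Images are fed flattened to MLP-style models and as NHWC tensors otherwise.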
    # model definition
    print_rank0('Building model...')
    with tf.device(device):
        if dataset == 'MNIST':
            x = tf.placeholder(dtype=tf.float32, shape=(None, 784), name='x')
            y_ = tf.placeholder(dtype=tf.float32, shape=(None, 10), name='y_')
            loss, y = model(x, y_)
        elif dataset == 'CIFAR10':
            if args.model == "tf_mlp":
                x = tf.placeholder(
                    dtype=tf.float32, shape=(None, 3072), name='x')
                y_ = tf.placeholder(
                    dtype=tf.float32, shape=(None, 10), name='y_')
            else:
                x = tf.placeholder(dtype=tf.float32, shape=(
                    None, 32, 32, 3), name='x')
                y_ = tf.placeholder(
                    dtype=tf.float32, shape=(None, 10), name='y_')
            loss, y = model(x, y_, 10)
        elif dataset == 'CIFAR100':
            x = tf.placeholder(dtype=tf.float32, shape=(
                None, 32, 32, 3), name='x')
            y_ = tf.placeholder(dtype=tf.float32, shape=(None, 100), name='y_')
            loss, y = model(x, y_, 100)
        train_op = opt.minimize(loss)

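    # opt.minimize (above) adds the gradient computation and the parameter
    # update to the graph; each sess.run of train_op performs one optimizer step.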
    # data loading
    print_rank0('Loading %s data...' % dataset)
    if dataset == 'MNIST':
        datasets = ht.data.mnist()
        train_set_x, train_set_y = datasets[0]
        valid_set_x, valid_set_y = datasets[1]
        test_set_x, test_set_y = datasets[2]
        n_train_batches = train_set_x.shape[0] // args.batch_size
        n_valid_batches = valid_set_x.shape[0] // args.batch_size
        # train_set_x: (50000, 784), train_set_y: (50000,)
        # valid_set_x: (10000, 784), valid_set_y: (10000,)
    elif dataset == 'CIFAR10':
        train_set_x, train_set_y, valid_set_x, valid_set_y = ht.data.tf_normalize_cifar(
            num_class=10)
        n_train_batches = train_set_x.shape[0] // args.batch_size
        n_valid_batches = valid_set_x.shape[0] // args.batch_size
        if args.model == "tf_mlp":
            train_set_x = train_set_x.reshape(train_set_x.shape[0], -1)
            valid_set_x = valid_set_x.reshape(valid_set_x.shape[0], -1)
        # train_set_x: (50000, 32, 32, 3), train_set_y: (50000,)
        # valid_set_x: (10000, 32, 32, 3), valid_set_y: (10000,)
    elif dataset == 'CIFAR100':
        train_set_x, train_set_y, valid_set_x, valid_set_y = ht.data.tf_normalize_cifar(
            num_class=100)
        n_train_batches = train_set_x.shape[0] // args.batch_size
        n_valid_batches = valid_set_x.shape[0] // args.batch_size
        # train_set_x: (50000, 32, 32, 3), train_set_y: (50000,)
        # valid_set_x: (10000, 32, 32, 3), valid_set_y: (10000,)
    else:
        raise NotImplementedError

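    # The loop below runs num_epochs + 1 epochs: epoch 0 acts as a warm-up and
    # is excluded from the accumulated running time (see the `if i != 0` guard).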
    # training
    print_rank0("Start training loop...")
    running_time = 0
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(args.num_epochs + 1):
            print_rank0("Epoch %d" % i)
            loss_all = 0
            batch_num = 0
            if args.timing:
                start = time()
            correct_predictions = []
            for minibatch_index in range(n_train_batches):
                minibatch_start = minibatch_index * args.batch_size
                minibatch_end = (minibatch_index + 1) * args.batch_size
                x_val = train_set_x[minibatch_start:minibatch_end]
                y_val = train_set_y[minibatch_start:minibatch_end]
                loss_val, predict_y, _ = sess.run([loss, y, train_op],
                                                  feed_dict={x: x_val, y_: y_val})
                correct_prediction = np.equal(
                    np.argmax(y_val, 1),
                    np.argmax(predict_y, 1)).astype(np.float32)
                correct_predictions.extend(correct_prediction)
                batch_num += 1
                loss_all += loss_val
            loss_all /= batch_num
            accuracy = np.mean(correct_predictions)
            print_rank0("Train loss = %f" % loss_all)
            print_rank0("Train accuracy = %f" % accuracy)
            if args.timing:
                end = time()
                print_rank0("Running time of current epoch = %fs" %
                            (end - start))
                if i != 0:
                    running_time += (end - start)
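            # Validation reuses the same graph but runs only [loss, y]
            # (no train_op), so no parameters are updated.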
            if args.validate:
                val_loss_all = 0
                batch_num = 0
                correct_predictions = []
                for minibatch_index in range(n_valid_batches):
                    minibatch_start = minibatch_index * args.batch_size
                    minibatch_end = (minibatch_index + 1) * args.batch_size
                    valid_x_val = valid_set_x[minibatch_start:minibatch_end]
                    valid_y_val = valid_set_y[minibatch_start:minibatch_end]
                    loss_val, valid_y_predicted = sess.run([loss, y],
                                                           feed_dict={x: valid_x_val, y_: valid_y_val})
                    correct_prediction = np.equal(
                        np.argmax(valid_y_val, 1),
                        np.argmax(valid_y_predicted, 1)).astype(np.float32)
                    correct_predictions.extend(correct_prediction)
                    # accumulate the validation batch loss
                    # (was `val_loss_all += loss_all`, which added the
                    # training epoch average instead of the batch loss)
                    val_loss_all += loss_val
                    batch_num += 1
                val_loss_all /= batch_num
                accuracy = np.mean(correct_predictions)
                print_rank0("Validation loss = %f" % val_loss_all)
                print_rank0("Validation accuracy = %f" % accuracy)
        print_rank0("*" * 50)
        print_rank0("Running time of total %d epochs = %fs" %
                    (args.num_epochs, running_time))
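
# Example invocation (assumes the tf_models package and the dataset files used
# by hetu's data loaders are available locally):
#   python tf_main.py --model tf_mlp --dataset MNIST --batch-size 128 \
#       --learning-rate 0.1 --opt sgd --num-epochs 20 --gpu 0 --validate --timing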