
neural_network.py

# imports
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

img_h = img_w = 28  # MNIST images are 28x28
img_size_flat = img_h * img_w  # 28x28=784, the total number of pixels
n_classes = 10  # Number of classes, one class per digit


def load_data(mode='train'):
    """
    Function to (download and) load the MNIST data
    :param mode: train or test
    :return: images and the corresponding labels
    """
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    if mode == 'train':
        x_train, y_train, x_valid, y_valid = mnist.train.images, mnist.train.labels, \
            mnist.validation.images, mnist.validation.labels
        return x_train, y_train, x_valid, y_valid
    elif mode == 'test':
        x_test, y_test = mnist.test.images, mnist.test.labels
        return x_test, y_test


def randomize(x, y):
    """Randomizes the order of data samples and their corresponding labels"""
    permutation = np.random.permutation(y.shape[0])
    shuffled_x = x[permutation, :]
    shuffled_y = y[permutation]
    return shuffled_x, shuffled_y


def get_next_batch(x, y, start, end):
    """Returns the slice of samples and labels from index start to end"""
    x_batch = x[start:end]
    y_batch = y[start:end]
    return x_batch, y_batch


# Load MNIST data
x_train, y_train, x_valid, y_valid = load_data(mode='train')
print("Size of:")
print("- Training-set:\t\t{}".format(len(y_train)))
print("- Validation-set:\t{}".format(len(y_valid)))
print('x_train:\t{}'.format(x_train.shape))
print('y_train:\t{}'.format(y_train.shape))
print('x_valid:\t{}'.format(x_valid.shape))
print('y_valid:\t{}'.format(y_valid.shape))
print(y_valid[:5, :])

# Hyper-parameters
epochs = 10  # Total number of training epochs
batch_size = 100  # Training batch size
display_freq = 100  # Frequency of displaying the training results
learning_rate = 0.001  # The optimization initial learning rate
h1 = 200  # Number of nodes in the 1st hidden layer


# weight and bias wrappers
def weight_variable(name, shape):
    """
    Create a weight variable with appropriate initialization
    :param name: weight name
    :param shape: weight shape
    :return: initialized weight variable
    """
    initer = tf.truncated_normal_initializer(stddev=0.01)
    return tf.get_variable('W_' + name,
                           dtype=tf.float32,
                           shape=shape,
                           initializer=initer)


def bias_variable(name, shape):
    """
    Create a bias variable with appropriate initialization
    :param name: bias variable name
    :param shape: bias variable shape
    :return: initialized bias variable
    """
    initial = tf.constant(0., shape=shape, dtype=tf.float32)
    return tf.get_variable('b_' + name,
                           dtype=tf.float32,
                           initializer=initial)


def fc_layer(x, num_units, name, use_relu=True):
    """
    Create a fully-connected layer
    :param x: input from previous layer
    :param num_units: number of hidden units in the fully-connected layer
    :param name: layer name
    :param use_relu: boolean to add ReLU non-linearity (or not)
    :return: The output array
    """
    in_dim = x.get_shape()[1]
    W = weight_variable(name, shape=[in_dim, num_units])
    b = bias_variable(name, [num_units])
    layer = tf.matmul(x, W)
    layer += b
    if use_relu:
        layer = tf.nn.relu(layer)
    return layer
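
# Note: each call to fc_layer needs a distinct name, since weight_variable and
# bias_variable create variables named 'W_<name>' and 'b_<name>' via
# tf.get_variable, which raises an error on duplicate names. For example,
# fc_layer(x, h1, 'FC1') creates W_FC1 of shape [784, h1] and b_FC1 of shape [h1].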

# Create the graph for the network
# Placeholders for inputs (x) and outputs (y)
x = tf.placeholder(tf.float32, shape=[None, img_size_flat], name='X')
y = tf.placeholder(tf.float32, shape=[None, n_classes], name='Y')
# Create a fully-connected layer with h1 nodes as hidden layer
fc1 = fc_layer(x, h1, 'FC1', use_relu=True)
# Create a fully-connected layer with n_classes nodes as output layer
output_logits = fc_layer(fc1, n_classes, 'OUT', use_relu=False)
# Define the loss function, optimizer, and accuracy
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=output_logits)
loss = tf.reduce_mean(cross_entropy, name='loss')
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, name='Adam-op').minimize(loss)
correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name='correct_pred')
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy')
# Network predictions
cls_prediction = tf.argmax(output_logits, axis=1, name='predictions')

# export graph
# tf.train.export_meta_graph(filename='neural_network.meta', graph=tf.get_default_graph(), clear_extraneous_savers=True, as_text=True)

# Create the op for initializing all variables
init = tf.global_variables_initializer()

# Create an interactive session (to keep the session in the other cells)
sess = tf.InteractiveSession()
# Initialize all variables
sess.run(init)

# Number of training iterations in each epoch
num_tr_iter = int(len(y_train) / batch_size)
for epoch in range(epochs):
    print('Training epoch: {}'.format(epoch + 1))
    # Randomly shuffle the training data at the beginning of each epoch
    x_train, y_train = randomize(x_train, y_train)
    for iteration in range(num_tr_iter):
        start = iteration * batch_size
        end = (iteration + 1) * batch_size
        x_batch, y_batch = get_next_batch(x_train, y_train, start, end)
        # Run optimization op (backprop)
        feed_dict_batch = {x: x_batch, y: y_batch}
        sess.run(optimizer, feed_dict=feed_dict_batch)
        if iteration % display_freq == 0:
            # Calculate and display the batch loss and accuracy
            loss_batch, acc_batch = sess.run([loss, accuracy],
                                             feed_dict=feed_dict_batch)
            print("iter {0:3d}:\t Loss={1:.2f},\tTraining Accuracy={2:.01%}".
                  format(iteration, loss_batch, acc_batch))
    # Run validation after every epoch
    feed_dict_valid = {x: x_valid[:1000], y: y_valid[:1000]}
    loss_valid, acc_valid = sess.run([loss, accuracy], feed_dict=feed_dict_valid)
    print('---------------------------------------------------------')
    print("Epoch: {0}, validation loss: {1:.2f}, validation accuracy: {2:.01%}".
          format(epoch + 1, loss_valid, acc_valid))
    print('---------------------------------------------------------')
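
# A minimal evaluation sketch (assumes the session above is still open): score the
# trained network on the MNIST test set via load_data(mode='test'), visualize a few
# predictions with matplotlib, then release the session's resources.
x_test, y_test = load_data(mode='test')
feed_dict_test = {x: x_test, y: y_test}
loss_test, acc_test = sess.run([loss, accuracy], feed_dict=feed_dict_test)
print("Test loss: {0:.2f}, test accuracy: {1:.01%}".format(loss_test, acc_test))

# Plot the first 9 test images with their true and predicted labels
cls_pred = sess.run(cls_prediction, feed_dict=feed_dict_test)
fig, axes = plt.subplots(3, 3)
for i, ax in enumerate(axes.flat):
    ax.imshow(x_test[i].reshape(img_h, img_w), cmap='binary')
    ax.set_xlabel('True: {}, Pred: {}'.format(np.argmax(y_test[i]), cls_pred[i]))
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()

sess.close()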