You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

minst_lstm.py 2.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. import tensorflow as tf
  2. from tensorflow.contrib import rnn
  3. #import mnist dataset
  4. from tensorflow.examples.tutorials.mnist import input_data
  5. mnist=input_data.read_data_sets("/tmp/data/",one_hot=True)
  6. #define constants
  7. #unrolled through 28 time steps
  8. time_steps=28
  9. #hidden LSTM units
  10. num_units=128
  11. #rows of 28 pixels
  12. n_input=28
  13. #learning rate for adam
  14. learning_rate=0.001
  15. #mnist is meant to be classified in 10 classes(0-9).
  16. n_classes=10
  17. #size of batch
  18. batch_size=128
  19. #weights and biases of appropriate shape to accomplish above task
  20. out_weights=tf.Variable(tf.random_normal([num_units,n_classes]))
  21. out_bias=tf.Variable(tf.random_normal([n_classes]))
  22. #defining placeholders
  23. #input image placeholder
  24. x=tf.placeholder("float",[None,time_steps,n_input])
  25. #input label placeholder
  26. y=tf.placeholder("float",[None,n_classes])
  27. #processing the input tensor from [batch_size,n_steps,n_input] to "time_steps" number of [batch_size,n_input] tensors
  28. input=tf.unstack(x ,time_steps,1)
  29. #defining the network
  30. lstm_layer=rnn.BasicLSTMCell(num_units,forget_bias=1)
  31. outputs,_=rnn.static_rnn(lstm_layer,input,dtype="float32")
  32. #converting last output of dimension [batch_size,num_units] to [batch_size,n_classes] by out_weight multiplication
  33. prediction=tf.matmul(outputs[-1],out_weights)+out_bias
  34. #loss_function
  35. loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction,labels=y))
  36. #optimization
  37. opt=tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)
  38. #model evaluation
  39. correct_prediction=tf.equal(tf.argmax(prediction,1),tf.argmax(y,1))
  40. accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
  41. #initialize variables
  42. init=tf.global_variables_initializer()
  43. with tf.Session() as sess:
  44. sess.run(init)
  45. iter=1
  46. while iter<800:
  47. batch_x,batch_y=mnist.train.next_batch(batch_size=batch_size)
  48. batch_x=batch_x.reshape((batch_size,time_steps,n_input))
  49. sess.run(opt, feed_dict={x: batch_x, y: batch_y})
  50. if iter %10==0:
  51. acc=sess.run(accuracy,feed_dict={x:batch_x,y:batch_y})
  52. los=sess.run(loss,feed_dict={x:batch_x,y:batch_y})
  53. print("For iter ",iter)
  54. print("Accuracy ",acc)
  55. print("Loss ",los)
  56. print("__________________")
  57. iter=iter+1
  58. #calculating test accuracy
  59. test_data = mnist.test.images[:128].reshape((-1, time_steps, n_input))
  60. test_label = mnist.test.labels[:128]
  61. print("Testing Accuracy:", sess.run(accuracy, feed_dict={x: test_data, y: test_label}))