You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tf_LSTM.py 3.4 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. import numpy as np
  2. import tensorflow as tf
  3. def tf_lstm(x, y_):
  4. '''
  5. LSTM model in TensorFlow, for MNIST dataset.
  6. Parameters:
  7. x: Variable(tensorflow.python.framework.ops.Tensor), shape (N, dims)
  8. y_: Variable(tensorflow.python.framework.ops.Tensor), shape (N, num_classes)
  9. Return:
  10. loss: Variable(tensorflow.python.framework.ops.Tensor), shape (1,)
  11. y: Variable(tensorflow.python.framework.ops.Tensor), shape (N, num_classes)
  12. '''
  13. print("Building LSTM model in tensorflow...")
  14. diminput = 28
  15. dimhidden = 128
  16. dimoutput = 10
  17. nsteps = 28
  18. forget_gate_w = tf.Variable(np.random.normal(
  19. scale=0.1, size=(diminput, dimhidden)).astype(np.float32))
  20. forget_gate_u = tf.Variable(np.random.normal(
  21. scale=0.1, size=(dimhidden, dimhidden)).astype(np.float32))
  22. forget_gate_b = tf.Variable(np.random.normal(
  23. scale=0.1, size=(dimhidden,)).astype(np.float32))
  24. input_gate_w = tf.Variable(np.random.normal(
  25. scale=0.1, size=(diminput, dimhidden)).astype(np.float32))
  26. input_gate_u = tf.Variable(np.random.normal(
  27. scale=0.1, size=(dimhidden, dimhidden)).astype(np.float32))
  28. input_gate_b = tf.Variable(np.random.normal(
  29. scale=0.1, size=(dimhidden,)).astype(np.float32))
  30. output_gate_w = tf.Variable(np.random.normal(
  31. scale=0.1, size=(diminput, dimhidden)).astype(np.float32))
  32. output_gate_u = tf.Variable(np.random.normal(
  33. scale=0.1, size=(dimhidden, dimhidden)).astype(np.float32))
  34. output_gate_b = tf.Variable(np.random.normal(
  35. scale=0.1, size=(dimhidden,)).astype(np.float32))
  36. tanh_w = tf.Variable(np.random.normal(
  37. scale=0.1, size=(diminput, dimhidden)).astype(np.float32))
  38. tanh_u = tf.Variable(np.random.normal(
  39. scale=0.1, size=(dimhidden, dimhidden)).astype(np.float32))
  40. tanh_b = tf.Variable(np.random.normal(
  41. scale=0.1, size=(dimhidden,)).astype(np.float32))
  42. out_weights = tf.Variable(np.random.normal(
  43. scale=0.1, size=(dimhidden, dimoutput)).astype(np.float32))
  44. out_bias = tf.Variable(np.random.normal(
  45. scale=0.1, size=(dimoutput,)).astype(np.float32))
  46. initial_state = tf.zeros((tf.shape(x)[0], dimhidden), dtype=tf.float32)
  47. last_c_state = initial_state
  48. last_h_state = initial_state
  49. for i in range(nsteps):
  50. cur_x = tf.slice(x, (0, i * diminput), (-1, diminput))
  51. # forget gate
  52. cur_forget = tf.matmul(last_h_state, forget_gate_u) + \
  53. tf.matmul(cur_x, forget_gate_w) + forget_gate_b
  54. cur_forget = tf.sigmoid(cur_forget)
  55. # input gate
  56. cur_input = tf.matmul(last_h_state, input_gate_u) + \
  57. tf.matmul(cur_x, input_gate_w) + input_gate_b
  58. cur_input = tf.sigmoid(cur_input)
  59. # output gate
  60. cur_output = tf.matmul(last_h_state, output_gate_u) + \
  61. tf.matmul(cur_x, output_gate_w) + output_gate_b
  62. cur_output = tf.sigmoid(cur_output)
  63. # tanh
  64. cur_tanh = tf.matmul(last_h_state, tanh_u) + \
  65. tf.matmul(cur_x, tanh_w) + tanh_b
  66. cur_tanh = tf.tanh(cur_tanh)
  67. last_c_state = last_c_state * cur_forget + cur_input * cur_tanh
  68. last_h_state = tf.tanh(last_c_state) * cur_output
  69. y = tf.matmul(last_h_state, out_weights) + out_bias
  70. loss = tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_)
  71. loss = tf.reduce_mean(loss)
  72. return loss, y