You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tf_RNN.py 1.8 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. import numpy as np
  2. import tensorflow as tf
  3. def tf_rnn(x, y_):
  4. '''
  5. RNN model in TensorFlow, for MNIST dataset.
  6. Parameters:
  7. x: Variable(tensorflow.python.framework.ops.Tensor), shape (N, dims)
  8. y_: Variable(tensorflow.python.framework.ops.Tensor), shape (N, num_classes)
  9. Return:
  10. loss: Variable(tensorflow.python.framework.ops.Tensor), shape (1,)
  11. y: Variable(tensorflow.python.framework.ops.Tensor), shape (N, num_classes)
  12. '''
  13. print("Building RNN model in tensorflow...")
  14. diminput = 28
  15. dimhidden = 128
  16. dimoutput = 10
  17. nsteps = 28
  18. weight1 = tf.Variable(np.random.normal(
  19. scale=0.1, size=(diminput, dimhidden)).astype(np.float32))
  20. bias1 = tf.Variable(np.random.normal(
  21. scale=0.1, size=(dimhidden, )).astype(np.float32))
  22. weight2 = tf.Variable(np.random.normal(scale=0.1, size=(
  23. dimhidden + dimhidden, dimhidden)).astype(np.float32))
  24. bias2 = tf.Variable(np.random.normal(
  25. scale=0.1, size=(dimhidden, )).astype(np.float32))
  26. weight3 = tf.Variable(np.random.normal(
  27. scale=0.1, size=(dimhidden, dimoutput)).astype(np.float32))
  28. bias3 = tf.Variable(np.random.normal(
  29. scale=0.1, size=(dimoutput, )).astype(np.float32))
  30. last_state = tf.zeros((128, dimhidden), dtype=tf.float32)
  31. for i in range(nsteps):
  32. cur_x = tf.slice(x, (0, i * diminput), (-1, diminput))
  33. h = tf.matmul(cur_x, weight1) + bias1
  34. s = tf.concat([h, last_state], axis=1)
  35. s = tf.matmul(s, weight2) + bias2
  36. last_state = tf.nn.relu(s)
  37. final_state = last_state
  38. y = tf.matmul(final_state, weight3) + bias3
  39. loss = tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_)
  40. loss = tf.reduce_mean(loss)
  41. return loss, y

分布式深度学习系统

Contributors (1)