You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

RNN.py 1.9 kB

4 years ago
(line numbers 1–56)
  1. import hetu as ht
  2. from hetu import init
  3. import numpy as np
  def rnn(x, y_, *, diminput=28, dimhidden=128, dimoutput=10, nsteps=28):
      '''
      RNN model, for MNIST dataset.

      Each input row is split into ``nsteps`` consecutive chunks of
      ``diminput`` columns (presumably one 28-pixel image row per step with
      the defaults). The chunks are fed through a single-layer ReLU
      recurrent cell:

          h_t     = x_t @ W1 + b1
          state_t = relu(concat(h_t, state_{t-1}) @ W2 + b2)

      with state_0 = 0. The final state is projected to ``dimoutput`` logits.

      Parameters:
          x: Variable(hetu.gpu_ops.Node.Node), shape (N, dims); columns
              [0, nsteps * diminput) are consumed (784 with the defaults).
          y_: Variable(hetu.gpu_ops.Node.Node), shape (N, num_classes)
          diminput: columns fed per time step (keyword-only, default 28).
          dimhidden: size of the recurrent state (keyword-only, default 128).
          dimoutput: number of output logits (keyword-only, default 10);
              should equal num_classes of y_.
          nsteps: number of unrolled time steps (keyword-only, default 28);
              must be at least 1.
      Return:
          loss: Variable(hetu.gpu_ops.Node.Node), shape (1,)
          y: Variable(hetu.gpu_ops.Node.Node), shape (N, num_classes)
      Raises:
          ValueError: if nsteps < 1. In that case the (1,)-shaped initial
              state would never be broadcast to (N, dimhidden) before the
              output projection.
      '''
      if nsteps < 1:
          raise ValueError("nsteps must be at least 1, got %r" % (nsteps,))
      print("Building RNN model...")
      weight1 = init.random_normal(
          shape=(diminput, dimhidden), stddev=0.1, name='rnn_weight1')
      bias1 = init.random_normal(
          shape=(dimhidden, ), stddev=0.1, name='rnn_bias1')
      # weight2 acts on concat([input projection, previous state]) along axis 1.
      weight2 = init.random_normal(
          shape=(dimhidden + dimhidden, dimhidden), stddev=0.1,
          name='rnn_weight2')
      bias2 = init.random_normal(
          shape=(dimhidden, ), stddev=0.1, name='rnn_bias2')
      weight3 = init.random_normal(
          shape=(dimhidden, dimoutput), stddev=0.1, name='rnn_weight3')
      bias3 = init.random_normal(
          shape=(dimoutput, ), stddev=0.1, name='rnn_bias3')
      # A single zero that is broadcast to h's (N, dimhidden) shape on the
      # first step. Presumably this is done because N is not fixed when the
      # graph is built.
      state = ht.Variable(value=np.zeros((1,)).astype(
          np.float32), name='initial_state', trainable=False)
      for step in range(nsteps):
          # Columns [step*diminput, (step+1)*diminput) of x; the size of -1
          # along axis 0 presumably selects every row.
          cur_x = ht.slice_op(x, (0, step * diminput), (-1, diminput))
          h = ht.matmul_op(cur_x, weight1)
          h = h + ht.broadcastto_op(bias1, h)
          if step == 0:
              state = ht.broadcastto_op(state, h)
          s = ht.concat_op(h, state, axis=1)
          s = ht.matmul_op(s, weight2)
          s = s + ht.broadcastto_op(bias2, s)
          state = ht.relu_op(s)
      # Use a distinct name rather than rebinding the input parameter x.
      logits = ht.matmul_op(state, weight3)
      y = logits + ht.broadcastto_op(bias3, logits)
      loss = ht.softmaxcrossentropy_op(y, y_)
      loss = ht.reduce_mean_op(loss, [0])
      return loss, y

分布式深度学习系统 (Distributed Deep Learning System)

Contributors (1)