You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

LSTM.py 3.9 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. import hetu as ht
  2. from hetu import init
  3. import numpy as np
  4. def lstm(x, y_):
  5. '''
  6. LSTM model, for MNIST dataset.
  7. Parameters:
  8. x: Variable(hetu.gpu_ops.Node.Node), shape (N, dims)
  9. y_: Variable(hetu.gpu_ops.Node.Node), shape (N, num_classes)
  10. Return:
  11. loss: Variable(hetu.gpu_ops.Node.Node), shape (1,)
  12. y: Variable(hetu.gpu_ops.Node.Node), shape (N, num_classes)
  13. '''
  14. diminput = 28
  15. dimhidden = 128
  16. dimoutput = 10
  17. nsteps = 28
  18. forget_gate_w = init.random_normal(
  19. shape=(diminput, dimhidden), stddev=0.1, name="lstm_forget_gate_w")
  20. forget_gate_u = init.random_normal(
  21. shape=(dimhidden, dimhidden), stddev=0.1, name="lstm_forget_gate_u")
  22. forget_gate_b = init.random_normal(
  23. shape=(dimhidden,), stddev=0.1, name="lstm_forget_gate_b")
  24. input_gate_w = init.random_normal(
  25. shape=(diminput, dimhidden), stddev=0.1, name="lstm_input_gate_w")
  26. input_gate_u = init.random_normal(
  27. shape=(dimhidden, dimhidden), stddev=0.1, name="lstm_input_gate_u")
  28. input_gate_b = init.random_normal(
  29. shape=(dimhidden,), stddev=0.1, name="lstm_input_gate_b")
  30. output_gate_w = init.random_normal(
  31. shape=(diminput, dimhidden), stddev=0.1, name="lstm_output_gate_w")
  32. output_gate_u = init.random_normal(
  33. shape=(dimhidden, dimhidden), stddev=0.1, name="lstm_output_gate_u")
  34. output_gate_b = init.random_normal(
  35. shape=(dimhidden,), stddev=0.1, name="lstm_output_gate_b")
  36. tanh_w = init.random_normal(
  37. shape=(diminput, dimhidden), stddev=0.1, name="lstm_tanh_w")
  38. tanh_u = init.random_normal(
  39. shape=(dimhidden, dimhidden), stddev=0.1, name="lstm_tanh_u")
  40. tanh_b = init.random_normal(
  41. shape=(dimhidden,), stddev=0.1, name="lstm_tanh_b")
  42. out_weights = init.random_normal(
  43. shape=(dimhidden, dimoutput), stddev=0.1, name="lstm_out_weight")
  44. out_bias = init.random_normal(
  45. shape=(dimoutput,), stddev=0.1, name="lstm_out_bias")
  46. initial_state = ht.Variable(value=np.zeros((1,)).astype(
  47. np.float32), name='initial_state', trainable=False)
  48. for i in range(nsteps):
  49. cur_x = ht.slice_op(x, (0, i * diminput), (-1, diminput))
  50. # forget gate
  51. if i == 0:
  52. temp = ht.matmul_op(cur_x, forget_gate_w)
  53. last_c_state = ht.broadcastto_op(initial_state, temp)
  54. last_h_state = ht.broadcastto_op(initial_state, temp)
  55. cur_forget = ht.matmul_op(last_h_state, forget_gate_u) + temp
  56. else:
  57. cur_forget = ht.matmul_op(
  58. last_h_state, forget_gate_u) + ht.matmul_op(cur_x, forget_gate_w)
  59. cur_forget = cur_forget + ht.broadcastto_op(forget_gate_b, cur_forget)
  60. cur_forget = ht.sigmoid_op(cur_forget)
  61. # input gate
  62. cur_input = ht.matmul_op(
  63. last_h_state, input_gate_u) + ht.matmul_op(cur_x, input_gate_w)
  64. cur_input = cur_input + ht.broadcastto_op(input_gate_b, cur_input)
  65. cur_input = ht.sigmoid_op(cur_input)
  66. # output gate
  67. cur_output = ht.matmul_op(
  68. last_h_state, output_gate_u) + ht.matmul_op(cur_x, output_gate_w)
  69. cur_output = cur_output + ht.broadcastto_op(output_gate_b, cur_output)
  70. cur_output = ht.sigmoid_op(cur_output)
  71. # tanh
  72. cur_tanh = ht.matmul_op(last_h_state, tanh_u) + \
  73. ht.matmul_op(cur_x, tanh_w)
  74. cur_tanh = cur_tanh + ht.broadcastto_op(tanh_b, cur_tanh)
  75. cur_tanh = ht.tanh_op(cur_tanh)
  76. last_c_state = ht.mul_op(last_c_state, cur_forget) + \
  77. ht.mul_op(cur_input, cur_tanh)
  78. last_h_state = ht.tanh_op(last_c_state) * cur_output
  79. x = ht.matmul_op(last_h_state, out_weights)
  80. y = x + ht.broadcastto_op(out_bias, x)
  81. loss = ht.softmaxcrossentropy_op(y, y_)
  82. loss = ht.reduce_mean_op(loss, [0])
  83. return loss, y

分布式深度学习系统

Contributors (1)