import hetu as ht
from hetu import init
import numpy as np


def lstm(x, y_):
    '''
    LSTM model, for MNIST dataset.

    Parameters:
        x: Variable(hetu.gpu_ops.Node.Node), shape (N, dims)
        y_: Variable(hetu.gpu_ops.Node.Node), shape (N, num_classes)
    Return:
        loss: Variable(hetu.gpu_ops.Node.Node), shape (1,)
        y: Variable(hetu.gpu_ops.Node.Node), shape (N, num_classes)
    '''
    diminput = 28
    dimhidden = 128
    dimoutput = 10
    nsteps = 28

    def _gate_params(tag):
        # One (W, U, b) parameter triple for a single LSTM gate:
        # W projects the step input, U projects the previous hidden state,
        # b is the gate bias.
        w = init.random_normal(
            shape=(diminput, dimhidden), stddev=0.1, name='lstm_%s_w' % tag)
        u = init.random_normal(
            shape=(dimhidden, dimhidden), stddev=0.1, name='lstm_%s_u' % tag)
        b = init.random_normal(
            shape=(dimhidden,), stddev=0.1, name='lstm_%s_b' % tag)
        return w, u, b

    # Parameters are created in the same order as before so graph/node
    # construction is unchanged.
    forget_w, forget_u, forget_b = _gate_params('forget_gate')
    input_w, input_u, input_b = _gate_params('input_gate')
    output_w, output_u, output_b = _gate_params('output_gate')
    cand_w, cand_u, cand_b = _gate_params('tanh')
    out_weights = init.random_normal(
        shape=(dimhidden, dimoutput), stddev=0.1, name='lstm_out_weight')
    out_bias = init.random_normal(
        shape=(dimoutput,), stddev=0.1, name='lstm_out_bias')
    # Zero scalar that gets broadcast to the batch shape for the initial
    # cell/hidden states at the first timestep.
    initial_state = ht.Variable(value=np.zeros((1,)).astype(
        np.float32), name='initial_state', trainable=False)

    def _gate(h_prev, x_step, w, u, b, act):
        # act(h_prev @ U + x_step @ W + b), with b broadcast over the batch.
        z = ht.matmul_op(h_prev, u) + ht.matmul_op(x_step, w)
        z = z + ht.broadcastto_op(b, z)
        return act(z)

    for step in range(nsteps):
        # Column slice: step `step` of the flattened (N, nsteps * diminput)
        # input.
        cur_x = ht.slice_op(x, (0, step * diminput), (-1, diminput))

        # Forget gate. The first step also bootstraps the zero cell/hidden
        # states, using the input projection only to fix their broadcast
        # shape.
        if step == 0:
            x_proj = ht.matmul_op(cur_x, forget_w)
            last_c_state = ht.broadcastto_op(initial_state, x_proj)
            last_h_state = ht.broadcastto_op(initial_state, x_proj)
            pre_forget = ht.matmul_op(last_h_state, forget_u) + x_proj
        else:
            pre_forget = ht.matmul_op(
                last_h_state, forget_u) + ht.matmul_op(cur_x, forget_w)
        pre_forget = pre_forget + ht.broadcastto_op(forget_b, pre_forget)
        cur_forget = ht.sigmoid_op(pre_forget)

        # Input gate, output gate, and tanh candidate cell state.
        cur_input = _gate(
            last_h_state, cur_x, input_w, input_u, input_b, ht.sigmoid_op)
        cur_output = _gate(
            last_h_state, cur_x, output_w, output_u, output_b, ht.sigmoid_op)
        cur_cand = _gate(
            last_h_state, cur_x, cand_w, cand_u, cand_b, ht.tanh_op)

        # c_t = f_t * c_{t-1} + i_t * c~_t ;  h_t = o_t * tanh(c_t)
        last_c_state = ht.mul_op(last_c_state, cur_forget) + \
            ht.mul_op(cur_input, cur_cand)
        last_h_state = ht.tanh_op(last_c_state) * cur_output

    # Output projection on the final hidden state, then mean cross-entropy.
    logits = ht.matmul_op(last_h_state, out_weights)
    y = logits + ht.broadcastto_op(out_bias, logits)
    loss = ht.softmaxcrossentropy_op(y, y_)
    loss = ht.reduce_mean_op(loss, [0])
    return loss, y