|
|
@@ -87,8 +87,8 @@ namespace Tensorflow |
|
|
|
// array_ops.split(value: state, num_or_size_splits: 2, axis: one); |
|
|
|
throw new NotImplementedException("BasicLstmCell call"); |
|
|
|
} |
|
|
|
var gate_inputs = math_ops.matmul(array_ops.concat(new[] { inputs, h }, 1), _kernel as RefVariable); |
|
|
|
gate_inputs = nn_ops.bias_add(gate_inputs, _bias as RefVariable); |
|
|
|
var gate_inputs = math_ops.matmul(array_ops.concat(new[] { inputs, h }, 1), _kernel.AsTensor()); |
|
|
|
gate_inputs = nn_ops.bias_add(gate_inputs, _bias.AsTensor()); |
|
|
|
|
|
|
|
// i = input_gate, j = new_input, f = forget_gate, o = output_gate |
|
|
|
var tensors = array_ops.split(value: gate_inputs, num_split: 4, axis: one); |
|
|
|