|
@@ -0,0 +1,282 @@ |
|
|
|
|
|
using System; |
|
|
|
|
|
using System.Collections.Generic; |
|
|
|
|
|
using System.Diagnostics; |
|
|
|
|
|
using System.Text; |
|
|
|
|
|
using Tensorflow.Keras.ArgsDefinition; |
|
|
|
|
|
using Tensorflow.Keras.ArgsDefinition.Rnn; |
|
|
|
|
|
using Tensorflow.Common.Extensions; |
|
|
|
|
|
using Tensorflow.Common.Types; |
|
|
|
|
|
using Tensorflow.Keras.Saving; |
|
|
|
|
|
|
|
|
|
|
|
namespace Tensorflow.Keras.Layers.Rnn |
|
|
|
|
|
{ |
|
|
|
|
|
/// <summary> |
|
|
|
|
|
/// Cell class for the GRU layer. |
|
|
|
|
|
/// </summary> |
|
|
|
|
|
/// <summary>
/// Cell class for the GRU layer.
/// Processes one timestep: given the input at time t and the previous hidden
/// state h_{t-1}, produces the new hidden state via update (z), reset (r) and
/// candidate (hh) gates. Supports both Keras "implementation" modes (1 = sliced
/// per-gate matmuls compatible with recurrent dropout, 2 = fused matmuls) and
/// the v1/v3 bias layouts selected by <c>ResetAfter</c>.
/// </summary>
public class GRUCell : DropoutRNNCellMixin
{
    GRUCellArgs _args;
    // Input projection weights, shape (input_dim, 3 * units): [z | r | h] gates.
    IVariableV1 _kernel;
    // Recurrent projection weights, shape (units, 3 * units): [z | r | h] gates.
    IVariableV1 _recurrent_kernel;
    // NOTE(review): never assigned anywhere in this class, so the bias is built
    // with a null initializer (add_weight default applies) — confirm intended.
    IInitializer _bias_initializer;
    // Bias: shape (3 * units) when !ResetAfter, (2, 3 * units) when ResetAfter
    // (row 0 = input bias, row 1 = recurrent bias).
    IVariableV1 _bias;
    INestStructure<long> _state_size;
    INestStructure<long> _output_size;
    int Units;

    public override INestStructure<long> StateSize => _state_size;

    public override INestStructure<long> OutputSize => _output_size;

    public override bool SupportOptionalArgs => false;

    /// <summary>
    /// Validates the arguments, clamps dropout rates into [0, 1] and forces
    /// implementation 1 when recurrent dropout is used (implementation 2 does
    /// not support per-gate recurrent dropout masks).
    /// </summary>
    /// <exception cref="ValueError">Thrown when <c>args.Units</c> is not positive.</exception>
    public GRUCell(GRUCellArgs args) : base(args)
    {
        _args = args;
        if (_args.Units <= 0)
        {
            throw new ValueError(
                $"units must be a positive integer, got {args.Units}");
        }
        // Clamp dropout fractions to the valid [0, 1] range.
        _args.Dropout = Math.Min(1f, Math.Max(0f, _args.Dropout));
        _args.RecurrentDropout = Math.Min(1f, Math.Max(0f, this._args.RecurrentDropout));
        if (_args.RecurrentDropout != 0f && _args.Implementation != 1)
        {
            Debug.WriteLine("RNN `implementation=2` is not supported when `recurrent_dropout` is set." +
                "Using `implementation=1`.");
            _args.Implementation = 1;
        }
        Units = _args.Units;
        _state_size = new NestList<long>(Units);
        _output_size = new NestNode<long>(Units);
    }

    /// <summary>
    /// Creates the cell's weights: the input kernel (input_dim, 3 * units),
    /// the recurrent kernel (units, 3 * units) and, when <c>UseBias</c>,
    /// a bias whose shape depends on <c>ResetAfter</c> (see <c>_bias</c>).
    /// </summary>
    public override void build(KerasShapesWrapper input_shape)
    {
        //base.build(input_shape);

        var single_shape = input_shape.ToSingleShape();
        var input_dim = single_shape[-1];

        _kernel = add_weight("kernel", (input_dim, _args.Units * 3),
            initializer: _args.KernelInitializer);

        _recurrent_kernel = add_weight("recurrent_kernel", (Units, Units * 3),
            initializer: _args.RecurrentInitializer);
        if (_args.UseBias)
        {
            Shape bias_shape;
            if (!_args.ResetAfter)
            {
                bias_shape = new Shape(3 * Units);
            }
            else
            {
                // Separate input and recurrent biases (CuDNN-compatible layout).
                bias_shape = (2, 3 * Units);
            }
            _bias = add_weight("bias", bias_shape,
                initializer: _bias_initializer);
        }
        built = true;
    }

    /// <summary>
    /// Runs one GRU step. Returns (output, new_state) where both equal the new
    /// hidden state h = z * h_{t-1} + (1 - z) * hh.
    /// </summary>
    /// <param name="inputs">Input tensor for this timestep.</param>
    /// <param name="states">Previous hidden state (possibly nested).</param>
    /// <param name="training">Whether dropout masks are active.</param>
    protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null)
    {
        var h_tm1 = states.IsNested() ? states[0] : states.Single();
        var dp_mask = get_dropout_mask_for_cell(inputs, training.Value, count: 3);
        var rec_dp_mask = get_recurrent_dropout_mask_for_cell(h_tm1, training.Value, count: 3);

        IVariableV1 input_bias = _bias;
        IVariableV1 recurrent_bias = _bias;
        if (_args.UseBias)
        {
            if (!_args.ResetAfter)
            {
                // v1 layout: a single bias applied to the input projection only.
                input_bias = _bias;
                recurrent_bias = null;
            }
            else
            {
                // ResetAfter layout: row 0 is the input bias, row 1 the recurrent bias.
                input_bias = tf.Variable(tf.unstack(_bias.AsTensor())[0]);
                recurrent_bias = tf.Variable(tf.unstack(_bias.AsTensor())[1]);
            }
        }

        Tensor hh;
        Tensor z;
        if (_args.Implementation == 1)
        {
            // Implementation 1: separate matmul per gate, each with its own
            // (recurrent) dropout mask.
            Tensor inputs_z;
            Tensor inputs_r;
            Tensor inputs_h;
            if (0f < _args.Dropout && _args.Dropout < 1f)
            {
                inputs_z = inputs * dp_mask[0];
                inputs_r = inputs * dp_mask[1];
                inputs_h = inputs * dp_mask[2];
            }
            else
            {
                inputs_z = inputs.Single();
                inputs_r = inputs.Single();
                inputs_h = inputs.Single();
            }

            // Row count of the input kernel == input feature dimension.
            int inputRows = (int)_kernel.AsTensor().shape[0];
            int kernelCols = (int)_kernel.AsTensor().shape[1];
            var _kernel_slice = tf.slice(_kernel.AsTensor(),
                new[] { 0, 0 }, new[] { inputRows, Units });
            var x_z = math_ops.matmul(inputs_z, _kernel_slice);
            // FIX: the slice size's first dimension must be the input row count,
            // not Units — the previous code broke whenever input_dim != Units.
            _kernel_slice = tf.slice(_kernel.AsTensor(),
                new[] { 0, Units }, new[] { inputRows, Units });
            var x_r = math_ops.matmul(
                inputs_r, _kernel_slice);
            _kernel_slice = tf.slice(_kernel.AsTensor(),
                new[] { 0, Units * 2 }, new[] { inputRows, kernelCols - Units * 2 });
            var x_h = math_ops.matmul(inputs_h, _kernel_slice);

            if (_args.UseBias)
            {
                x_z = tf.nn.bias_add(
                    x_z, tf.Variable(input_bias.AsTensor()[$":{Units}"]));
                x_r = tf.nn.bias_add(
                    x_r, tf.Variable(input_bias.AsTensor()[$"{Units}:{Units * 2}"]));
                x_h = tf.nn.bias_add(
                    x_h, tf.Variable(input_bias.AsTensor()[$"{Units * 2}:"]));
            }

            Tensor h_tm1_z;
            Tensor h_tm1_r;
            Tensor h_tm1_h;
            if (0f < _args.RecurrentDropout && _args.RecurrentDropout < 1f)
            {
                h_tm1_z = h_tm1 * rec_dp_mask[0];
                h_tm1_r = h_tm1 * rec_dp_mask[1];
                h_tm1_h = h_tm1 * rec_dp_mask[2];
            }
            else
            {
                h_tm1_z = h_tm1;
                h_tm1_r = h_tm1;
                h_tm1_h = h_tm1;
            }

            int recRows = (int)_recurrent_kernel.AsTensor().shape[0];
            int recCols = (int)_recurrent_kernel.AsTensor().shape[1];
            var _recurrent_kernel_slice = tf.slice(_recurrent_kernel.AsTensor(),
                new[] { 0, 0 }, new[] { recRows, Units });
            var recurrent_z = math_ops.matmul(
                h_tm1_z, _recurrent_kernel_slice);
            _recurrent_kernel_slice = tf.slice(_recurrent_kernel.AsTensor(),
                new[] { 0, Units }, new[] { recRows, Units });
            var recurrent_r = math_ops.matmul(
                h_tm1_r, _recurrent_kernel_slice);
            if (_args.ResetAfter && _args.UseBias)
            {
                recurrent_z = tf.nn.bias_add(
                    recurrent_z, tf.Variable(recurrent_bias.AsTensor()[$":{Units}"]));
                // FIX: removed the stray space in the slice spec so it matches
                // the "{start}:{stop}" format used by every other slice here.
                recurrent_r = tf.nn.bias_add(
                    recurrent_r, tf.Variable(recurrent_bias.AsTensor()[$"{Units}:{Units * 2}"]));
            }
            z = _args.RecurrentActivation.Apply(x_z + recurrent_z);
            var r = _args.RecurrentActivation.Apply(x_r + recurrent_r);

            Tensor recurrent_h;
            // FIX: the candidate-gate slice bounds now come from the recurrent
            // kernel's own shape; the old code reused a stale endIndex read
            // from the input kernel.
            _recurrent_kernel_slice = tf.slice(_recurrent_kernel.AsTensor(),
                new[] { 0, Units * 2 }, new[] { recRows, recCols - Units * 2 });
            if (_args.ResetAfter)
            {
                // Reset gate applied after the matmul (CuDNN-compatible).
                recurrent_h = math_ops.matmul(
                    h_tm1_h, _recurrent_kernel_slice);
                if (_args.UseBias)
                {
                    recurrent_h = tf.nn.bias_add(
                        recurrent_h, tf.Variable(recurrent_bias.AsTensor()[$"{Units * 2}:"]));
                }
                recurrent_h *= r;
            }
            else
            {
                // Reset gate applied before the matmul (classic GRU v1).
                recurrent_h = math_ops.matmul(
                    r * h_tm1_h, _recurrent_kernel_slice);
            }
            hh = _args.Activation.Apply(x_h + recurrent_h);
        }
        else
        {
            // Implementation 2: one fused matmul for all three gates, split afterwards.
            if (0f < _args.Dropout && _args.Dropout < 1f)
            {
                inputs = inputs * dp_mask[0];
            }

            var matrix_x = math_ops.matmul(inputs, _kernel.AsTensor());
            if (_args.UseBias)
            {
                matrix_x = tf.nn.bias_add(matrix_x, input_bias);
            }
            var matrix_x_split = tf.split(matrix_x, 3, axis: -1);
            var x_z = matrix_x_split[0];
            var x_r = matrix_x_split[1];
            var x_h = matrix_x_split[2];

            Tensor matrix_inner;
            if (_args.ResetAfter)
            {
                matrix_inner = math_ops.matmul(h_tm1, _recurrent_kernel.AsTensor());
                if (_args.UseBias)
                {
                    matrix_inner = tf.nn.bias_add(
                        matrix_inner, recurrent_bias);
                }
            }
            else
            {
                // Only the z and r columns are needed up front; the candidate
                // column is multiplied after applying the reset gate below.
                var recRows = (int)_recurrent_kernel.AsTensor().shape[0];
                var _recurrent_kernel_slice = tf.slice(_recurrent_kernel.AsTensor(),
                    new[] { 0, 0 }, new[] { recRows, Units * 2 });
                matrix_inner = math_ops.matmul(
                    h_tm1, _recurrent_kernel_slice);
            }

            var matrix_inner_split = tf.split(matrix_inner, new int[] { Units, Units, -1 }, axis: -1);
            var recurrent_z = matrix_inner_split[0];
            // FIX: the reset and candidate gates must use split indices 1 and 2;
            // the previous code read index 0 for all three gates, so r and hh
            // were computed from the update-gate slice.
            var recurrent_r = matrix_inner_split[1];
            var recurrent_h = matrix_inner_split[2];

            z = _args.RecurrentActivation.Apply(x_z + recurrent_z);
            var r = _args.RecurrentActivation.Apply(x_r + recurrent_r);

            if (_args.ResetAfter)
            {
                recurrent_h = r * recurrent_h;
            }
            else
            {
                var recRows = (int)_recurrent_kernel.AsTensor().shape[0];
                var recCols = (int)_recurrent_kernel.AsTensor().shape[1];
                var _recurrent_kernel_slice = tf.slice(_recurrent_kernel.AsTensor(),
                    new[] { 0, 2 * Units }, new[] { recRows, recCols - 2 * Units });
                recurrent_h = math_ops.matmul(
                    r * h_tm1, _recurrent_kernel_slice);
            }
            hh = _args.Activation.Apply(x_h + recurrent_h);
        }
        // New hidden state: convex combination of previous state and candidate.
        var h = z * h_tm1 + (1 - z) * hh;
        if (states.IsNested())
        {
            var new_state = new NestList<Tensor>(h);
            return new Nest<Tensor>(new INestStructure<Tensor>[] { new NestNode<Tensor>(h), new_state }).ToTensors();
        }
        else
        {
            return new Nest<Tensor>(new INestStructure<Tensor>[] { new NestNode<Tensor>(h), new NestNode<Tensor>(h) }).ToTensors();
        }
    }
}
|
|
|
|
|
} |