using Serilog.Core;
using System.Diagnostics;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;
namespace Tensorflow.Keras.Layers.Rnn
{
///
/// Cell class for the LSTM layer.
/// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
/// for details about the usage of RNN API.
/// This class processes one step within the whole time sequence input, whereas
/// `tf.keras.layer.LSTM` processes the whole sequence.
///
public class LSTMCell : DropoutRNNCellMixin
{
    LSTMCellArgs _args;
    IVariableV1 _kernel;            // input weights, shape (input_dim, units * 4): [i | f | c | o]
    IVariableV1 _recurrent_kernel;  // recurrent weights, shape (units, units * 4): [i | f | c | o]
    IInitializer _bias_initializer;
    IVariableV1 _bias;              // bias, shape (units * 4,)
    GeneralizedTensorShape _state_size;
    GeneralizedTensorShape _output_size;

    /// <summary>
    /// Size of the cell state: two tensors of `units` elements (h and c).
    /// </summary>
    public override GeneralizedTensorShape StateSize => _state_size;

    /// <summary>
    /// Size of the cell output: a single tensor of `units` elements.
    /// </summary>
    public override GeneralizedTensorShape OutputSize => _output_size;

    public override bool IsTFRnnCell => true;

    public override bool SupportOptionalArgs => false;

    /// <summary>
    /// Creates an LSTM cell.
    /// </summary>
    /// <param name="args">Configuration (units, activations, initializers, dropout, implementation).</param>
    /// <exception cref="ValueError">Thrown when <c>args.Units</c> is not a positive integer.</exception>
    public LSTMCell(LSTMCellArgs args)
        : base(args)
    {
        _args = args;
        if (args.Units <= 0)
        {
            throw new ValueError(
                $"units must be a positive integer, got {args.Units}");
        }
        // Clamp dropout rates into the valid [0, 1] range.
        _args.Dropout = Math.Min(1f, Math.Max(0f, _args.Dropout));
        _args.RecurrentDropout = Math.Min(1f, Math.Max(0f, _args.RecurrentDropout));
        // Implementation 2 (fused) does not support recurrent dropout; fall back to 1.
        if (_args.RecurrentDropout != 0f && _args.Implementation != 1)
        {
            Debug.WriteLine("RNN `implementation=2` is not supported when `recurrent_dropout` is set." +
                "Using `implementation=1`.");
            _args.Implementation = 1;
        }
        // LSTM state is the pair (h, c), each of size `units`.
        _state_size = new GeneralizedTensorShape(_args.Units, 2);
        _output_size = new GeneralizedTensorShape(_args.Units);
    }

    /// <summary>
    /// Creates the cell's weights. The input and recurrent kernels pack the
    /// four gates (input, forget, candidate, output) side by side along the
    /// last axis, each chunk being `units` columns wide.
    /// </summary>
    /// <param name="input_shape">Shape of the per-step input; the last dimension is the feature size.</param>
    public override void build(KerasShapesWrapper input_shape)
    {
        var single_shape = input_shape.ToSingleShape();
        var input_dim = single_shape[-1];
        _kernel = add_weight("kernel", (input_dim, _args.Units * 4),
            initializer: _args.KernelInitializer
        );
        _recurrent_kernel = add_weight("recurrent_kernel", (_args.Units, _args.Units * 4),
            initializer: _args.RecurrentInitializer
        );
        if (_args.UseBias)
        {
            if (_args.UnitForgetBias)
            {
                // Keras' `unit_forget_bias`: initialize the forget-gate chunk to
                // ones and the remaining 3 * units entries with the configured
                // bias initializer (units for the input gate, units * 2 for the
                // candidate and output gates), totalling units * 4.
                // NOTE(review): this local initializer is never passed to
                // add_weight below, so `unit_forget_bias` currently has no
                // effect — TODO wire it in once add_weight can accept a
                // tensor-producing initializer.
                Tensor bias_initializer()
                {
                    return keras.backend.concatenate(
                        new Tensors(
                            _args.BiasInitializer.Apply(new InitializerArgs(shape: (_args.Units))),
                            tf.ones_initializer.Apply(new InitializerArgs(shape: (_args.Units))),
                            _args.BiasInitializer.Apply(new InitializerArgs(shape: (_args.Units * 2)))), axis: 0);
                }
            }
            else
            {
                _bias_initializer = _args.BiasInitializer;
            }
            _bias = add_weight("bias", (_args.Units * 4),
                initializer: _args.BiasInitializer);
        }
        built = true;
    }

    /// <summary>
    /// Runs one LSTM step.
    /// </summary>
    /// <param name="inputs">Input tensor for the current time step.</param>
    /// <param name="states">Previous states: [0] = hidden state h, [1] = carry state c.</param>
    /// <param name="training">Whether dropout masks should be applied.</param>
    /// <param name="optional_args">Unused; this cell does not support optional args.</param>
    /// <returns>A Tensors of (output h, new states [h, c]).</returns>
    protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null)
    {
        var h_tm1 = states[0]; // previous memory state
        var c_tm1 = states[1]; // previous carry state
        var dp_mask = get_dropout_mask_for_cell(inputs, training.Value, count: 4);
        var rec_dp_mask = get_recurrent_dropout_mask_for_cell(
            h_tm1, training.Value, count: 4);

        Tensor c;
        Tensor o;
        if (_args.Implementation == 1)
        {
            // Implementation 1: separate matmuls per gate, with per-gate
            // input and recurrent dropout masks.
            Tensor inputs_i;
            Tensor inputs_f;
            Tensor inputs_c;
            Tensor inputs_o;
            if (0f < _args.Dropout && _args.Dropout < 1f)
            {
                inputs_i = inputs * dp_mask[0];
                inputs_f = inputs * dp_mask[1];
                inputs_c = inputs * dp_mask[2];
                inputs_o = inputs * dp_mask[3];
            }
            else
            {
                inputs_i = inputs;
                inputs_f = inputs;
                inputs_c = inputs;
                inputs_o = inputs;
            }
            var k = tf.split(_kernel.AsTensor(), num_split: 4, axis: 1);
            Tensor k_i = k[0], k_f = k[1], k_c = k[2], k_o = k[3];
            var x_i = math_ops.matmul(inputs_i, k_i);
            var x_f = math_ops.matmul(inputs_f, k_f);
            var x_c = math_ops.matmul(inputs_c, k_c);
            var x_o = math_ops.matmul(inputs_o, k_o);
            if (_args.UseBias)
            {
                var b = tf.split(_bias.AsTensor(), num_split: 4, axis: 0);
                Tensor b_i = b[0], b_f = b[1], b_c = b[2], b_o = b[3];
                x_i = gen_nn_ops.bias_add(x_i, b_i);
                x_f = gen_nn_ops.bias_add(x_f, b_f);
                x_c = gen_nn_ops.bias_add(x_c, b_c);
                x_o = gen_nn_ops.bias_add(x_o, b_o);
            }

            Tensor h_tm1_i;
            Tensor h_tm1_f;
            Tensor h_tm1_c;
            Tensor h_tm1_o;
            if (0f < _args.RecurrentDropout && _args.RecurrentDropout < 1f)
            {
                h_tm1_i = h_tm1 * rec_dp_mask[0];
                h_tm1_f = h_tm1 * rec_dp_mask[1];
                h_tm1_c = h_tm1 * rec_dp_mask[2];
                h_tm1_o = h_tm1 * rec_dp_mask[3];
            }
            else
            {
                h_tm1_i = h_tm1;
                h_tm1_f = h_tm1;
                h_tm1_c = h_tm1;
                h_tm1_o = h_tm1;
            }
            var x = new Tensor[] { x_i, x_f, x_c, x_o };
            var h_tm1_array = new Tensor[] { h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o };
            (c, o) = _compute_carry_and_output(x, h_tm1_array, c_tm1);
        }
        else
        {
            // Implementation 2: one fused matmul over the packed kernel,
            // then split into the four gate pre-activations.
            if (0f < _args.Dropout && _args.Dropout < 1f)
                inputs = inputs * dp_mask[0];
            var z = math_ops.matmul(inputs, _kernel.AsTensor());
            z += math_ops.matmul(h_tm1, _recurrent_kernel.AsTensor());
            if (_args.UseBias)
            {
                z = tf.nn.bias_add(z, _bias);
            }
            var z_array = tf.split(z, num_split: 4, axis: 1);
            (c, o) = _compute_carry_and_output_fused(z_array, c_tm1);
        }
        var h = o * _args.Activation.Apply(c);
        // The Tensors constructor packs every element after the first into an
        // array, so this yields (output h, states [h, c]).
        return new Tensors(h, h, c);
    }

    /// <summary>
    /// Computes carry and output using split kernels (implementation 1).
    /// </summary>
    /// <param name="x">Per-gate input projections [x_i, x_f, x_c, x_o].</param>
    /// <param name="h_tm1">Per-gate (possibly dropout-masked) previous hidden states.</param>
    /// <param name="c_tm1">Previous carry state.</param>
    /// <returns>A Tensors of (new carry state c, output gate o).</returns>
    public Tensors _compute_carry_and_output(Tensor[] x, Tensor[] h_tm1, Tensor c_tm1)
    {
        Tensor x_i = x[0], x_f = x[1], x_c = x[2], x_o = x[3];
        Tensor h_tm1_i = h_tm1[0], h_tm1_f = h_tm1[1], h_tm1_c = h_tm1[2],
            h_tm1_o = h_tm1[3];
        var _recurrent_kernel_tensor = _recurrent_kernel.AsTensor();
        var numRows = _recurrent_kernel_tensor.shape[0];
        // tf.slice takes (begin, size) — NOT (begin, end). Each gate occupies
        // a `Units`-wide column chunk of the packed recurrent kernel, so every
        // slice below has size {numRows, Units}; only the begin offset moves.
        // (The original code passed Units*2 / Units*3 / Units*4 as sizes,
        // over-reading past the kernel's Units*4 columns.)
        var _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
            new[] { 0, 0 }, new[] { numRows, _args.Units });
        var i = _args.RecurrentActivation.Apply(
            x_i + math_ops.matmul(h_tm1_i, _recurrent_kernel_slice));
        _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
            new[] { 0, _args.Units }, new[] { numRows, _args.Units });
        var f = _args.RecurrentActivation.Apply(
            x_f + math_ops.matmul(h_tm1_f, _recurrent_kernel_slice));
        _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
            new[] { 0, _args.Units * 2 }, new[] { numRows, _args.Units });
        var c = f * c_tm1 + i * _args.Activation.Apply(
            x_c + math_ops.matmul(h_tm1_c, _recurrent_kernel_slice));
        _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
            new[] { 0, _args.Units * 3 }, new[] { numRows, _args.Units });
        var o = _args.RecurrentActivation.Apply(
            x_o + math_ops.matmul(h_tm1_o, _recurrent_kernel_slice));
        return new Tensors(c, o);
    }

    /// <summary>
    /// Computes carry and output using the fused kernel (implementation 2).
    /// </summary>
    /// <param name="z">Gate pre-activations [z_i, z_f, z_c, z_o] from the fused matmul.</param>
    /// <param name="c_tm1">Previous carry state.</param>
    /// <returns>A Tensors of (new carry state c, output gate o).</returns>
    public Tensors _compute_carry_and_output_fused(Tensor[] z, Tensor c_tm1)
    {
        Tensor z0 = z[0], z1 = z[1], z2 = z[2], z3 = z[3];
        var i = _args.RecurrentActivation.Apply(z0);
        var f = _args.RecurrentActivation.Apply(z1);
        // The candidate uses `activation` (typically tanh), not
        // `recurrent_activation` — matches the split-kernel path above and
        // the Keras reference implementation. (Original applied
        // RecurrentActivation to z2.)
        var c = f * c_tm1 + i * _args.Activation.Apply(z2);
        var o = _args.RecurrentActivation.Apply(z3);
        return new Tensors(c, o);
    }

    /// <summary>
    /// Returns a zero-filled initial state (h, c) for this cell.
    /// </summary>
    /// <param name="inputs">Optional inputs used to infer batch size and dtype.</param>
    /// <param name="batch_size">Batch size; must be non-null (dereferenced below).</param>
    /// <param name="dtype">State dtype; must be non-null (dereferenced below).</param>
    public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null)
    {
        return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size.Value, dtype.Value);
    }
}
}