using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Common.Types;
using Tensorflow.Common.Extensions;
using Tensorflow.Keras.Utils;
namespace Tensorflow.Keras.Layers.Rnn
{
/// <summary>
/// Cell class for SimpleRNN.
/// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
/// for details about the usage of RNN API.
/// This class processes one step within the whole time sequence input, whereas
/// <c>tf.keras.layer.SimpleRNN</c> processes the whole sequence.
/// </summary>
public class SimpleRNNCell : DropoutRNNCellMixin
{
    SimpleRNNCellArgs _args;
    IVariableV1 _kernel;            // input -> hidden weights, shape (input_dim, units)
    IVariableV1 _recurrent_kernel;  // hidden -> hidden weights, shape (units, units)
    IVariableV1 _bias;              // optional bias, shape (units,); null when UseBias is false
    GeneralizedTensorShape _state_size;
    GeneralizedTensorShape _output_size;

    public override GeneralizedTensorShape StateSize => _state_size;
    public override GeneralizedTensorShape OutputSize => _output_size;
    public override bool IsTFRnnCell => true;
    public override bool SupportOptionalArgs => false;

    /// <summary>
    /// Creates the cell and validates/normalizes its hyper-parameters.
    /// </summary>
    /// <param name="args">Cell configuration (units, initializers, dropout rates, ...).</param>
    /// <exception cref="ValueError">Thrown when <c>args.Units</c> is not a positive integer.</exception>
    public SimpleRNNCell(SimpleRNNCellArgs args) : base(args)
    {
        _args = args;
        if (args.Units <= 0)
        {
            throw new ValueError(
                $"units must be a positive integer, got {args.Units}");
        }
        // Clamp dropout probabilities into the valid range [0, 1] instead of failing later.
        _args.Dropout = Math.Min(1f, Math.Max(0f, _args.Dropout));
        _args.RecurrentDropout = Math.Min(1f, Math.Max(0f, _args.RecurrentDropout));
        // For a simple RNN cell, state and output are both a single vector of `units`.
        _state_size = new GeneralizedTensorShape(args.Units);
        _output_size = new GeneralizedTensorShape(args.Units);
    }

    /// <summary>
    /// Creates the cell's trainable weights from the (last dimension of the) input shape.
    /// </summary>
    /// <param name="input_shape">Shape of the step input; only its last dimension is used.</param>
    public override void build(KerasShapesWrapper input_shape)
    {
        // TODO(Rinne): add the cache.
        var single_shape = input_shape.ToSingleShape();
        var input_dim = single_shape[-1];
        _kernel = add_weight("kernel", (input_dim, _args.Units),
            initializer: _args.KernelInitializer
        );
        _recurrent_kernel = add_weight("recurrent_kernel", (_args.Units, _args.Units),
            initializer: _args.RecurrentInitializer
        );
        if (_args.UseBias)
        {
            _bias = add_weight("bias", (_args.Units),
                initializer: _args.BiasInitializer
            );
        }
        built = true;
    }

    /// <summary>
    /// Runs one RNN step: output = activation(inputs @ kernel + bias + prev_output @ recurrent_kernel).
    /// Returns (output, new_state); for a SimpleRNN cell the new state equals the output.
    /// </summary>
    /// <param name="inputs">Step input tensor(s).</param>
    /// <param name="states">Previous state; may be nested, in which case only the first entry is used.</param>
    /// <param name="training">Whether dropout masks apply; null is treated as inference (false).</param>
    /// <param name="optional_args">Unused; this cell does not support optional args.</param>
    // TODO(Rinne): revise the trining param (with refactoring of the framework)
    protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null)
    {
        // `training` defaults to null; the original `training.Value` threw
        // InvalidOperationException in that case. Null means "not training" here.
        bool is_training = training ?? false;

        // TODO(Rinne): check if it will have multiple tensors when not nested.
        Tensors prev_output = Nest.IsNested(states) ? new Tensors(states[0]) : states;
        var dp_mask = get_dropout_mask_for_cell(inputs, is_training);
        var rec_dp_mask = get_recurrent_dropout_mask_for_cell(prev_output, is_training);

        Tensor h;
        if (dp_mask != null)
        {
            h = math_ops.matmul(math_ops.multiply(inputs.Single, dp_mask.Single), _kernel.AsTensor());
        }
        else
        {
            h = math_ops.matmul(inputs, _kernel.AsTensor());
        }

        if (_bias != null)
        {
            h = tf.nn.bias_add(h, _bias);
        }

        if (rec_dp_mask != null)
        {
            prev_output = math_ops.multiply(prev_output, rec_dp_mask);
        }

        Tensor output = h + math_ops.matmul(prev_output, _recurrent_kernel.AsTensor());
        if (_args.Activation != null)
        {
            output = _args.Activation.Apply(output);
        }

        if (Nest.IsNested(states))
        {
            // NOTE(review): the generic arguments on this nested return were garbled in the
            // original source (`new Nest(new List> {...}`); reconstructed as the nested
            // structure ((output), output) — confirm against the upstream implementation.
            return new Nest<Tensor>(new List<INestStructure<Tensor>> {
                    new Nest<Tensor>(new List<INestStructure<Tensor>> { new Nest<Tensor>(output) }),
                    new Nest<Tensor>(output) })
                .ToTensors();
        }
        else
        {
            // Output doubles as the new state.
            return new Tensors(output, output);
        }
    }

    /// <summary>
    /// Builds a zero-filled initial state for this cell.
    /// </summary>
    /// <param name="inputs">Optional inputs used by the helper to infer batch/dtype.</param>
    /// <param name="batch_size">Batch size of the state.</param>
    /// <param name="dtype">Dtype of the state.</param>
    public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null)
    {
        // NOTE(review): `.Value` throws when batch_size/dtype are omitted and inputs is null;
        // callers appear required to supply them — confirm whether defaults should be derived.
        return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size.Value, dtype.Value);
    }
}
}