using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Common.Types;
using Tensorflow.Common.Extensions;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Layers.Rnn
{
    /// <summary>
    /// Cell class for SimpleRNN.
    /// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
    /// for details about the usage of RNN API.
    /// This class processes one step within the whole time sequence input, whereas
    /// `tf.keras.layers.SimpleRNN` processes the whole sequence.
    /// </summary>
    public class SimpleRNNCell : DropoutRNNCellMixin
    {
        private readonly SimpleRNNCellArgs _args;
        private IVariableV1 _kernel;
        private IVariableV1 _recurrent_kernel;
        private IVariableV1 _bias;
        private readonly GeneralizedTensorShape _state_size;
        private readonly GeneralizedTensorShape _output_size;

        /// <summary>Shape of the recurrent state; built from <c>args.Units</c>.</summary>
        public override GeneralizedTensorShape StateSize => _state_size;

        /// <summary>Shape of the per-step output; built from <c>args.Units</c>.</summary>
        public override GeneralizedTensorShape OutputSize => _output_size;

        public override bool IsTFRnnCell => true;

        public override bool SupportOptionalArgs => false;

        /// <summary>
        /// Creates the cell from its arguments.
        /// </summary>
        /// <param name="args">
        /// Cell configuration. Note that <c>Dropout</c> and <c>RecurrentDropout</c> are
        /// clamped to the range [0, 1] and written back to this same object.
        /// </param>
        /// <exception cref="ValueError">Thrown when <c>args.Units</c> is not positive.</exception>
        public SimpleRNNCell(SimpleRNNCellArgs args) : base(args)
        {
            this._args = args;
            if (args.Units <= 0)
            {
                throw new ValueError(
                    $"units must be a positive integer, got {args.Units}");
            }

            // Keep both dropout rates inside [0, 1].
            this._args.Dropout = Math.Min(1f, Math.Max(0f, this._args.Dropout));
            this._args.RecurrentDropout = Math.Min(1f, Math.Max(0f, this._args.RecurrentDropout));

            _state_size = new GeneralizedTensorShape(args.Units);
            _output_size = new GeneralizedTensorShape(args.Units);
        }

        /// <summary>
        /// Creates the cell weights: a kernel of shape (last input dimension, Units),
        /// a recurrent kernel of shape (Units, Units) and, when <c>UseBias</c> is set,
        /// a bias of shape (Units). Marks the layer as built afterwards.
        /// </summary>
        /// <param name="input_shape">Input shape wrapper; it is reduced to a single shape.</param>
        public override void build(KerasShapesWrapper input_shape)
        {
            // TODO(Rinne): add the cache.
            var single_shape = input_shape.ToSingleShape();
            var input_dim = single_shape[-1];

            _kernel = add_weight("kernel", (input_dim, _args.Units),
                initializer: _args.KernelInitializer
            );

            _recurrent_kernel = add_weight("recurrent_kernel", (_args.Units, _args.Units),
                initializer: _args.RecurrentInitializer
            );

            if (_args.UseBias)
            {
                _bias = add_weight("bias", (_args.Units),
                    initializer: _args.BiasInitializer
                );
            }

            built = true;
        }

        /// <summary>
        /// Runs one time step:
        /// <c>output = activation(inputs * kernel [+ bias] + prev_output * recurrent_kernel)</c>,
        /// multiplying the inputs and the previous output by their dropout masks
        /// whenever the corresponding mask is not null.
        /// </summary>
        /// <param name="inputs">Input of the current step.</param>
        /// <param name="states">
        /// State from the previous step. When it is nested, its first element is used
        /// as the previous output.
        /// </param>
        /// <param name="training">
        /// Whether the cell runs in training mode; forwarded to the dropout-mask helpers.
        /// A null value is treated as <c>false</c> (inference).
        /// </param>
        /// <param name="optional_args">Not read by this cell.</param>
        /// <returns>
        /// The step output followed by the new state (the same tensor). The result is
        /// built from a nested structure when <paramref name="states"/> is nested.
        /// </returns>
        // TODO(Rinne): revise the training param (with refactoring of the framework)
        protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null)
        {
            // `training` defaults to null, so resolve it once instead of dereferencing
            // `.Value`, which would throw InvalidOperationException for the default.
            bool is_training = training ?? false;
            bool nested_states = Nest.IsNested(states);

            // TODO(Rinne): check if it will have multiple tensors when not nested.
            Tensors prev_output = nested_states ? new Tensors(states[0]) : states;
            var dp_mask = get_dropout_mask_for_cell(inputs, is_training);
            var rec_dp_mask = get_recurrent_dropout_mask_for_cell(prev_output, is_training);

            Tensor h;
            if (dp_mask != null)
            {
                h = math_ops.matmul(math_ops.multiply(inputs.Single, dp_mask.Single), _kernel.AsTensor());
            }
            else
            {
                h = math_ops.matmul(inputs, _kernel.AsTensor());
            }

            // _bias is only created by build() when UseBias is set.
            if (_bias != null)
            {
                h = tf.nn.bias_add(h, _bias);
            }

            if (rec_dp_mask != null)
            {
                prev_output = math_ops.multiply(prev_output, rec_dp_mask);
            }

            Tensor output = h + math_ops.matmul(prev_output, _recurrent_kernel.AsTensor());
            if (_args.Activation != null)
            {
                output = _args.Activation.Apply(output);
            }

            if (nested_states)
            {
                return new Nest<Tensor>(new List<Nest<Tensor>> {
                    new Nest<Tensor>(new List<Nest<Tensor>> { new Nest<Tensor>(output) }),
                    new Nest<Tensor>(output) })
                    .ToTensors();
            }
            else
            {
                return new Tensors(output, output);
            }
        }

        /// <summary>
        /// Builds the initial state by delegating to
        /// <c>RnnUtils.generate_zero_filled_state_for_cell</c>.
        /// </summary>
        /// <param name="inputs">Forwarded unchanged to the helper.</param>
        /// <param name="batch_size">Batch size; must not be null.</param>
        /// <param name="dtype">Data type of the state; must not be null.</param>
        /// <exception cref="InvalidOperationException">
        /// Thrown when <paramref name="batch_size"/> or <paramref name="dtype"/> is null,
        /// because both are dereferenced through <c>.Value</c>.
        /// </exception>
        public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null)
        {
            // NOTE(review): the null defaults cannot actually be used here; whether the
            // helper could derive these values from `inputs` is not visible from this file.
            return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size.Value, dtype.Value);
        }
    }
}