From 7b077eac7e6a9e60d9d34be9782e222317fbe353 Mon Sep 17 00:00:00 2001
From: Wanglongzhi2001 <583087864@qq.com>
Date: Mon, 4 Sep 2023 00:05:22 +0800
Subject: [PATCH] feat: implement GRU layer
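
Add GRUArgs and GRUOptionalArgs, implement the GRU layer on top of
GRUCell driven through keras.backend.rnn, expose it via
ILayersApi.GRU/LayersApi.GRU, and cover the output shape with a unit
test.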
---
.../Keras/ArgsDefinition/Rnn/GRUArgs.cs | 29 +++
.../ArgsDefinition/Rnn/GRUOptionalArgs.cs | 13 ++
.../Keras/Layers/ILayersApi.cs | 19 ++
src/TensorFlowNET.Keras/Layers/LayersApi.cs | 61 ++++++-
src/TensorFlowNET.Keras/Layers/Rnn/GRU.cs | 168 ++++++++++++++++++
src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs | 42 +----
.../Layers/Rnn.Test.cs | 9 +
7 files changed, 300 insertions(+), 41 deletions(-)
create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUArgs.cs
create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUOptionalArgs.cs
create mode 100644 src/TensorFlowNET.Keras/Layers/Rnn/GRU.cs
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUArgs.cs
new file mode 100644
index 00000000..cdc3097e
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUArgs.cs
@@ -0,0 +1,29 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras.ArgsDefinition
+{
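+ /// <summary>
+ /// Constructor arguments for the GRU layer.
+ /// </summary>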
+ public class GRUArgs : AutoSerializeLayerArgs
+ {
+ public int Units { get; set; }
+ public Activation Activation { get; set; }
+ public Activation RecurrentActivation { get; set; }
+ public bool UseBias { get; set; } = true;
+ public float Dropout { get; set; } = .0f;
+ public float RecurrentDropout { get; set; } = .0f;
+ public IInitializer KernelInitializer { get; set; }
+ public IInitializer RecurrentInitializer { get; set; }
+ public IInitializer BiasInitializer { get; set; }
+ public bool ReturnSequences { get; set; }
+ public bool ReturnState { get; set; }
+ public bool GoBackwards { get; set; }
+ public bool Stateful { get; set; }
+ public bool Unroll { get; set; }
+ public bool TimeMajor { get; set; }
+ public bool ResetAfter { get; set; }
+ public int Implementation { get; set; } = 2;
+
+ }
+
+}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUOptionalArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUOptionalArgs.cs
new file mode 100644
index 00000000..d441dc82
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUOptionalArgs.cs
@@ -0,0 +1,13 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras.ArgsDefinition
+{
+ public class GRUOptionalArgs
+ {
+ public string Identifier => "GRU";
+
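+ /// <summary>
+ /// Binary tensor indicating, for each timestep, whether that timestep
+ /// should be masked (standard Keras masking semantics).
+ /// </summary>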
+ public Tensor Mask { get; set; } = null;
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
index b8aff5fb..5e08eadc 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
@@ -259,6 +259,25 @@ namespace Tensorflow.Keras.Layers
float recurrent_dropout = 0f,
bool reset_after = true);
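+
+ /// <summary>
+ /// Gated Recurrent Unit - Cho et al. 2014.
+ /// </summary>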
+ public ILayer GRU(
+ int units,
+ string activation = "tanh",
+ string recurrent_activation = "sigmoid",
+ bool use_bias = true,
+ string kernel_initializer = "glorot_uniform",
+ string recurrent_initializer = "orthogonal",
+ string bias_initializer = "zeros",
+ float dropout = 0f,
+ float recurrent_dropout = 0f,
+ bool return_sequences = false,
+ bool return_state = false,
+ bool go_backwards = false,
+ bool stateful = false,
+ bool unroll = false,
+ bool time_major = false,
+ bool reset_after = true
+ );
+
/// <summary>
/// Bidirectional wrapper for RNNs.
/// </summary>
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
index 9155c774..928e7e33 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
@@ -784,7 +784,7 @@ namespace Tensorflow.Keras.Layers
string recurrent_activation = "sigmoid",
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
- string recurrent_initializer = "orthogonal", // TODO(Wanglongzhi2001),glorot_uniform has not been developed.
+ string recurrent_initializer = "orthogonal",
string bias_initializer = "zeros",
bool unit_forget_bias = true,
float dropout = 0f,
@@ -908,6 +908,65 @@ namespace Tensorflow.Keras.Layers
ResetAfter = reset_after
});
+ /// <summary>
+ /// Gated Recurrent Unit - Cho et al. 2014.
+ /// </summary>
+ /// <param name="units">Positive integer, dimensionality of the output space.</param>
+ /// <param name="activation">Activation function to use. If you pass `None`, no activation is applied (i.e. "linear" activation: `a(x) = x`).</param>
+ /// <param name="recurrent_activation">Activation function to use for the recurrent step. If you pass `None`, no activation is applied (i.e. "linear" activation: `a(x) = x`).</param>
+ /// <param name="use_bias">Boolean (default `True`), whether the layer uses a bias vector.</param>
+ /// <param name="kernel_initializer">Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs. Default: `glorot_uniform`.</param>
+ /// <param name="recurrent_initializer">Initializer for the `recurrent_kernel` weights matrix, used for the linear transformation of the recurrent state. Default: `orthogonal`.</param>
+ /// <param name="bias_initializer">Initializer for the bias vector. Default: `zeros`.</param>
+ /// <param name="dropout">Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs. Default: 0.</param>
+ /// <param name="recurrent_dropout">Float between 0 and 1. Fraction of the units to drop for the linear transformation of the recurrent state. Default: 0.</param>
+ /// <param name="return_sequences">Boolean. Whether to return the last output in the output sequence, or the full sequence. Default: `False`.</param>
+ /// <param name="return_state">Boolean. Whether to return the last state in addition to the output. Default: `False`.</param>
+ /// <param name="go_backwards">Boolean (default `False`). If `True`, process the input sequence backwards and return the reversed sequence.</param>
+ /// <param name="stateful">Boolean (default `False`). If `True`, the last state for each sample at index i in a batch will be used as the initial state for the sample of index i in the following batch.</param>
+ /// <param name="unroll">Boolean (default `False`). If `True`, the network will be unrolled, else a symbolic loop will be used. Unrolling can speed up an RNN, although it tends to be more memory-intensive; it is only suitable for short sequences.</param>
+ /// <param name="time_major">The shape format of the `inputs` and `outputs` tensors. If `True`, inputs and outputs are shaped `(timesteps, batch, ...)`; otherwise `(batch, timesteps, ...)`.</param>
+ /// <param name="reset_after">GRU convention (whether to apply the reset gate after or before the matrix multiplication). `False` = "before", `True` = "after" (default and cuDNN compatible).</param>
+ /// <returns></returns>
+ public ILayer GRU(
+ int units,
+ string activation = "tanh",
+ string recurrent_activation = "sigmoid",
+ bool use_bias = true,
+ string kernel_initializer = "glorot_uniform",
+ string recurrent_initializer = "orthogonal",
+ string bias_initializer = "zeros",
+ float dropout = 0f,
+ float recurrent_dropout = 0f,
+ bool return_sequences = false,
+ bool return_state = false,
+ bool go_backwards = false,
+ bool stateful = false,
+ bool unroll = false,
+ bool time_major = false,
+ bool reset_after = true
+ )
+ => new GRU(new GRUArgs
+ {
+ Units = units,
+ Activation = keras.activations.GetActivationFromName(activation),
+ RecurrentActivation = keras.activations.GetActivationFromName(recurrent_activation),
+ KernelInitializer = GetInitializerByName(kernel_initializer),
+ RecurrentInitializer = GetInitializerByName(recurrent_initializer),
+ BiasInitializer = GetInitializerByName(bias_initializer),
+ UseBias = use_bias,
+ Dropout = dropout,
+ RecurrentDropout = recurrent_dropout,
+ ReturnSequences = return_sequences,
+ ReturnState = return_state,
+ GoBackwards = go_backwards,
+ Stateful = stateful,
+ TimeMajor = time_major,
+ Unroll = unroll,
+ ResetAfter = reset_after
+ });
+
public ILayer Bidirectional(
ILayer layer,
string merge_mode = "concat",
diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/GRU.cs b/src/TensorFlowNET.Keras/Layers/Rnn/GRU.cs
new file mode 100644
index 00000000..0919883d
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/GRU.cs
@@ -0,0 +1,168 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Common.Extensions;
+using Tensorflow.Common.Types;
+using Tensorflow.Keras.Saving;
+
+
+namespace Tensorflow.Keras.Layers
+{
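+ /// <summary>
+ /// Gated Recurrent Unit - Cho et al. 2014. Wraps a GRUCell and iterates it
+ /// over the time dimension of the input via keras.backend.rnn.
+ /// </summary>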
+ public class GRU : RNN
+ {
+ GRUArgs _args;
+ private GRUCell _cell;
+
+ bool _return_runtime;
+ public GRUCell Cell { get => _cell; }
+ public int units { get => _args.Units; }
+ public Activation activation { get => _args.Activation; }
+ public Activation recurrent_activation { get => _args.RecurrentActivation; }
+ public bool use_bias { get => _args.UseBias; }
+ public float dropout { get => _args.Dropout; }
+ public float recurrent_dropout { get => _args.RecurrentDropout; }
+ public IInitializer kernel_initializer { get => _args.KernelInitializer; }
+ public IInitializer recurrent_initializer { get => _args.RecurrentInitializer; }
+ public IInitializer bias_initializer { get => _args.BiasInitializer; }
+ public int implementation { get => _args.Implementation; }
+ public bool reset_after { get => _args.ResetAfter; }
+
+ public GRU(GRUArgs args) : base(CreateCell(args), PreConstruct(args))
+ {
+ _args = args;
+
+ if (_args.Implementation == 0)
+ {
+ // Print the warning in red so it remains visible in release builds.
+ Console.ForegroundColor = ConsoleColor.Red;
+ Console.WriteLine("Warning: `implementation=0` has been deprecated, " +
+ "and now defaults to `implementation=2`. " +
+ "Please update your layer call.");
+ Console.ResetColor();
+ }
+
+ // Reuse CreateCell so the cell construction logic lives in one place.
+ _cell = (GRUCell)CreateCell(args);
+ }
+
+ protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bool? training = null, IOptionalArgs? optional_args = null)
+ {
+ GRUOptionalArgs? gru_optional_args = optional_args as GRUOptionalArgs;
+ if (optional_args is not null && gru_optional_args is null)
+ {
+ throw new ArgumentException("The type of optional args should be `GRUOptionalArgs`.");
+ }
+ Tensors? mask = gru_optional_args?.Mask;
+
+ // Ragged input is not supported yet.
+ int row_length = 0;
+ bool is_ragged_input = false;
+
+ _validate_args_if_ragged(is_ragged_input, mask);
+
+ // GRU does not support constants; ignore them during processing.
+ (inputs, initial_state, _) = this._process_inputs(inputs, initial_state, null);
+
+ if (mask is not null && mask.Length > 1)
+ {
+ mask = mask[0];
+ }
+
+ var input_shape = inputs.shape;
+ var timesteps = _args.TimeMajor ? input_shape[0] : input_shape[1];
+
+
+ // TODO(Wanglongzhi2001), finish _could_use_gpu_kernel part
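+ // Step function run once per timestep: apply the GRU cell to this step's
+ // inputs and the states carried over from the previous step.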
+ Func<Tensors, Tensors, (Tensors, Tensors)> step = (cell_inputs, cell_states) =>
+ {
+ var res = Cell.Apply(cell_inputs, cell_states, training is null ? true : training.Value);
+ var (output, state) = res;
+ return (output, state);
+ };
+
+ var (last_output, outputs, states) = keras.backend.rnn(
+ step,
+ inputs,
+ initial_state,
+ constants: null,
+ go_backwards: _args.GoBackwards,
+ mask: mask,
+ unroll: _args.Unroll,
+ input_length: ops.convert_to_tensor(timesteps),
+ time_major: _args.TimeMajor,
+ zero_output_for_mask: base.Args.ZeroOutputForMask,
+ return_all_outputs: _args.ReturnSequences
+ );
+
+ Tensors output;
+ if (_args.ReturnSequences)
+ {
+ output = outputs;
+ }
+ else
+ {
+ output = last_output;
+ }
+
+ if (_args.ReturnState)
+ {
+ output = new Tensors { output, states };
+ }
+ return output;
+ }
+
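+ /// <summary>
+ /// Builds the GRUCell handed to the RNN base-class constructor; runs before
+ /// the GRU constructor body.
+ /// </summary>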
+ private static IRnnCell CreateCell(GRUArgs gruArgs)
+ {
+ return new GRUCell(new GRUCellArgs
+ {
+ Units = gruArgs.Units,
+ Activation = gruArgs.Activation,
+ RecurrentActivation = gruArgs.RecurrentActivation,
+ UseBias = gruArgs.UseBias,
+ Dropout = gruArgs.Dropout,
+ RecurrentDropout = gruArgs.RecurrentDropout,
+ KernelInitializer = gruArgs.KernelInitializer,
+ RecurrentInitializer = gruArgs.RecurrentInitializer,
+ BiasInitializer = gruArgs.BiasInitializer,
+ ResetAfter = gruArgs.ResetAfter,
+ Implementation = gruArgs.Implementation
+ });
+ }
+
+ private static RNNArgs PreConstruct(GRUArgs args)
+ {
+ return new RNNArgs
+ {
+ ReturnSequences = args.ReturnSequences,
+ ReturnState = args.ReturnState,
+ GoBackwards = args.GoBackwards,
+ Stateful = args.Stateful,
+ Unroll = args.Unroll,
+ TimeMajor = args.TimeMajor,
+ Units = args.Units,
+ Activation = args.Activation,
+ RecurrentActivation = args.RecurrentActivation,
+ UseBias = args.UseBias,
+ Dropout = args.Dropout,
+ RecurrentDropout = args.RecurrentDropout,
+ KernelInitializer = args.KernelInitializer,
+ RecurrentInitializer = args.RecurrentInitializer,
+ BiasInitializer = args.BiasInitializer
+ };
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
index c1922261..fec75559 100644
--- a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
@@ -25,8 +25,8 @@ namespace Tensorflow.Keras.Layers
private RNNArgs _args;
private object _input_spec = null; // or NoneValue??
private object _state_spec = null;
- private Tensors _states = null;
private object _constants_spec = null;
+ private Tensors _states = null;
private int _num_constants;
protected IVariableV1 _kernel;
protected IVariableV1 _bias;
@@ -469,7 +469,7 @@ namespace Tensorflow.Keras.Layers
return (inputs, initial_state, constants);
}
- private void _validate_args_if_ragged(bool is_ragged_input, Tensors mask)
+ protected void _validate_args_if_ragged(bool is_ragged_input, Tensors mask)
{
if (!is_ragged_input)
{
@@ -528,44 +528,6 @@ namespace Tensorflow.Keras.Layers
throw new NotImplementedException();
}
- // It seems the cell cannot be passed as an interface type
- //public RNN New(IRnnArgCell cell,
- // bool return_sequences = false,
- // bool return_state = false,
- // bool go_backwards = false,
- // bool stateful = false,
- // bool unroll = false,
- // bool time_major = false)
- // => new RNN(new RNNArgs
- // {
- // Cell = cell,
- // ReturnSequences = return_sequences,
- // ReturnState = return_state,
- // GoBackwards = go_backwards,
- // Stateful = stateful,
- // Unroll = unroll,
- // TimeMajor = time_major
- // });
-
- //public RNN New(List cell,
- // bool return_sequences = false,
- // bool return_state = false,
- // bool go_backwards = false,
- // bool stateful = false,
- // bool unroll = false,
- // bool time_major = false)
- // => new RNN(new RNNArgs
- // {
- // Cell = cell,
- // ReturnSequences = return_sequences,
- // ReturnState = return_state,
- // GoBackwards = go_backwards,
- // Stateful = stateful,
- // Unroll = unroll,
- // TimeMajor = time_major
- // });
-
-
protected Tensors get_initial_state(Tensors inputs)
{
var input = inputs[0];
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
index 03159346..dbf5cae1 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
@@ -146,6 +146,15 @@ namespace Tensorflow.Keras.UnitTest.Layers
}
+ [TestMethod]
+ public void GRU()
+ {
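+ // Input: batch of 32 sequences with 10 timesteps of 8 features.
+ // With return_sequences=false (the default), GRU emits only the last
+ // output, so the result has shape (batch, units) = (32, 4).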
+ var inputs = tf.ones((32, 10, 8));
+ var gru = tf.keras.layers.GRU(4);
+ var output = gru.Apply(inputs);
+ Assert.AreEqual((32, 4), output.shape);
+ }
+
[TestMethod]
public void Bidirectional()
{