diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUArgs.cs
new file mode 100644
index 00000000..cdc3097e
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUArgs.cs
@@ -0,0 +1,29 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras.ArgsDefinition
+{
+    public class GRUArgs : AutoSerializeLayerArgs
+    {
+        public int Units { get; set; }
+        public Activation Activation { get; set; }
+        public Activation RecurrentActivation { get; set; }
+        public bool UseBias { get; set; } = true;
+        public float Dropout { get; set; } = .0f;
+        public float RecurrentDropout { get; set; } = .0f;
+        public IInitializer KernelInitializer { get; set; }
+        public IInitializer RecurrentInitializer { get; set; }
+        public IInitializer BiasInitializer { get; set; }
+        public bool ReturnSequences { get; set; }
+        public bool ReturnState { get; set; }
+        public bool GoBackwards { get; set; }
+        public bool Stateful { get; set; }
+        public bool Unroll { get; set; }
+        public bool TimeMajor { get; set; }
+        public bool ResetAfter { get; set; }
+        public int Implementation { get; set; } = 2;
+
+    }
+
+}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUOptionalArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUOptionalArgs.cs
new file mode 100644
index 00000000..d441dc82
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUOptionalArgs.cs
@@ -0,0 +1,14 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Common.Types;
+
+namespace Tensorflow.Keras.ArgsDefinition
+{
+    public class GRUOptionalArgs : IOptionalArgs
+    {
+        public string Identifier => "GRU";
+
+        public Tensor Mask { get; set; } = null;
+    }
+}
diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
index b8aff5fb..5e08eadc 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
@@ -259,6 +259,25 @@ namespace Tensorflow.Keras.Layers
             float recurrent_dropout = 0f,
             bool reset_after = true);
 
+        public ILayer GRU(
+            int units,
+            string activation = "tanh",
+            string recurrent_activation = "sigmoid",
+            bool use_bias = true,
+            string kernel_initializer = "glorot_uniform",
+            string recurrent_initializer = "orthogonal",
+            string bias_initializer = "zeros",
+            float dropout = 0f,
+            float recurrent_dropout = 0f,
+            bool return_sequences = false,
+            bool return_state = false,
+            bool go_backwards = false,
+            bool stateful = false,
+            bool unroll = false,
+            bool time_major = false,
+            bool reset_after = true
+        );
+
         /// <summary>
         /// Bidirectional wrapper for RNNs.
         /// </summary>
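For a sense of what this new `ILayersApi.GRU` member gives callers, here is a minimal usage sketch. It mirrors the unit test added at the end of this diff; `tf` and `keras` come from the standard TensorFlow.NET static imports, and the expected shape follows the Keras contract.

```csharp
using static Tensorflow.Binding;   // provides `tf`
using static Tensorflow.KerasApi;  // provides `keras`

// (batch = 32, timesteps = 10, features = 8) -> last GRU output of shape (32, 4)
var inputs = tf.ones((32, 10, 8));
var gru = keras.layers.GRU(4);
var output = gru.Apply(inputs);
```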
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
index 9155c774..928e7e33 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
@@ -784,7 +784,7 @@ namespace Tensorflow.Keras.Layers
             string recurrent_activation = "sigmoid",
             bool use_bias = true,
             string kernel_initializer = "glorot_uniform",
-            string recurrent_initializer = "orthogonal", // TODO(Wanglongzhi2001),glorot_uniform has not been developed.
+            string recurrent_initializer = "orthogonal",
             string bias_initializer = "zeros",
             bool unit_forget_bias = true,
             float dropout = 0f,
@@ -908,6 +908,64 @@ namespace Tensorflow.Keras.Layers
                 ResetAfter = reset_after
             });
 
+        /// <summary>
+        /// Gated Recurrent Unit - Cho et al. 2014.
+        /// </summary>
+        /// <param name="units">Positive integer, dimensionality of the output space.</param>
+        /// <param name="activation">Activation function to use. If you pass `None`, no activation is applied (i.e. "linear" activation: `a(x) = x`).</param>
+        /// <param name="recurrent_activation">Activation function to use for the recurrent step. If you pass `None`, no activation is applied (i.e. "linear" activation: `a(x) = x`).</param>
+        /// <param name="use_bias">Boolean (default `True`), whether the layer uses a bias vector.</param>
+        /// <param name="kernel_initializer">Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs. Default: `glorot_uniform`.</param>
+        /// <param name="recurrent_initializer">Initializer for the `recurrent_kernel` weights matrix, used for the linear transformation of the recurrent state. Default: `orthogonal`.</param>
+        /// <param name="bias_initializer">Initializer for the bias vector. Default: `zeros`.</param>
+        /// <param name="dropout">Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs. Default: 0.</param>
+        /// <param name="recurrent_dropout">Float between 0 and 1. Fraction of the units to drop for the linear transformation of the recurrent state. Default: 0.</param>
+        /// <param name="return_sequences">Boolean. Whether to return the last output in the output sequence, or the full sequence. Default: `False`.</param>
+        /// <param name="return_state">Boolean. Whether to return the last state in addition to the output. Default: `False`.</param>
+        /// <param name="go_backwards">Boolean (default `False`). If `True`, process the input sequence backwards and return the reversed sequence.</param>
+        /// <param name="stateful">Boolean (default `False`). If `True`, the last state for each sample at index i in a batch will be used as the initial state for the sample of index i in the following batch.</param>
+        /// <param name="unroll">Boolean (default `False`). If `True`, the network will be unrolled, else a symbolic loop will be used. Unrolling can speed up an RNN, although it tends to be more memory-intensive. Unrolling is only suitable for short sequences.</param>
+        /// <param name="time_major">The shape format of the `inputs` and `outputs` tensors.</param>
+        /// <param name="reset_after">GRU convention (whether to apply the reset gate after or before the matrix multiplication). False = "before", True = "after" (default and cuDNN compatible).</param>
+        /// <returns></returns>
+        public ILayer GRU(
+            int units,
+            string activation = "tanh",
+            string recurrent_activation = "sigmoid",
+            bool use_bias = true,
+            string kernel_initializer = "glorot_uniform",
+            string recurrent_initializer = "orthogonal",
+            string bias_initializer = "zeros",
+            float dropout = 0f,
+            float recurrent_dropout = 0f,
+            bool return_sequences = false,
+            bool return_state = false,
+            bool go_backwards = false,
+            bool stateful = false,
+            bool unroll = false,
+            bool time_major = false,
+            bool reset_after = true
+        )
+            => new GRU(new GRUArgs
+            {
+                Units = units,
+                Activation = keras.activations.GetActivationFromName(activation),
+                RecurrentActivation = keras.activations.GetActivationFromName(recurrent_activation),
+                KernelInitializer = GetInitializerByName(kernel_initializer),
+                RecurrentInitializer = GetInitializerByName(recurrent_initializer),
+                BiasInitializer = GetInitializerByName(bias_initializer),
+                UseBias = use_bias,
+                Dropout = dropout,
+                RecurrentDropout = recurrent_dropout,
+                ReturnSequences = return_sequences,
+                ReturnState = return_state,
+                GoBackwards = go_backwards,
+                Stateful = stateful,
+                TimeMajor = time_major,
+                Unroll = unroll,
+                ResetAfter = reset_after
+            });
+
         public ILayer Bidirectional(
             ILayer layer,
             string merge_mode = "concat",
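The factory above only resolves string names to concrete `Activation` and `IInitializer` instances and forwards them to `GRUArgs`. How the two return flags change what `Apply` yields is sketched below; the shapes follow the usual Keras contract, and positional indexing into `Tensors` is the same pattern `GRU.Call` uses for `mask[0]` further down.

```csharp
// return_sequences: keep every timestep instead of only the last one.
var seqGru = keras.layers.GRU(4, return_sequences: true);
var seq = seqGru.Apply(tf.ones((32, 10, 8)));       // expected shape (32, 10, 4)

// return_state: Call wraps output and final state into one Tensors container.
var stateGru = keras.layers.GRU(4, return_state: true);
var result = stateGru.Apply(tf.ones((32, 10, 8)));
var lastOutput = result[0];                         // expected shape (32, 4)
var finalState = result[1];                         // expected shape (32, 4)
```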
diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/GRU.cs b/src/TensorFlowNET.Keras/Layers/Rnn/GRU.cs
new file mode 100644
index 00000000..0919883d
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/GRU.cs
@@ -0,0 +1,168 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Common.Extensions;
+using Tensorflow.Common.Types;
+using Tensorflow.Keras.Saving;
+
+
+namespace Tensorflow.Keras.Layers
+{
+    public class GRU : RNN
+    {
+        GRUArgs _args;
+        private GRUCell _cell;
+
+        bool _return_runtime;
+        public GRUCell Cell { get => _cell; }
+        public int units { get => _args.Units; }
+        public Activation activation { get => _args.Activation; }
+        public Activation recurrent_activation { get => _args.RecurrentActivation; }
+        public bool use_bias { get => _args.UseBias; }
+        public float dropout { get => _args.Dropout; }
+        public float recurrent_dropout { get => _args.RecurrentDropout; }
+        public IInitializer kernel_initializer { get => _args.KernelInitializer; }
+        public IInitializer recurrent_initializer { get => _args.RecurrentInitializer; }
+        public IInitializer bias_initializer { get => _args.BiasInitializer; }
+        public int implementation { get => _args.Implementation; }
+        public bool reset_after { get => _args.ResetAfter; }
+
+        public GRU(GRUArgs args) : base(CreateCell(args), PreConstruct(args))
+        {
+            _args = args;
+
+            if (_args.Implementation == 0)
+            {
+                // Print the warning in red so it stays visible in release builds.
+                Console.ForegroundColor = ConsoleColor.Red;
+                Console.WriteLine("Warning: `implementation=0` has been deprecated, " +
+                    "and now defaults to `implementation=2`. " +
+                    "Please update your layer call.");
+                Console.ResetColor();
+            }
+
+            GRUCell cell = new GRUCell(new GRUCellArgs
+            {
+                Units = _args.Units,
+                Activation = _args.Activation,
+                RecurrentActivation = _args.RecurrentActivation,
+                UseBias = _args.UseBias,
+                Dropout = _args.Dropout,
+                RecurrentDropout = _args.RecurrentDropout,
+                KernelInitializer = _args.KernelInitializer,
+                RecurrentInitializer = _args.RecurrentInitializer,
+                BiasInitializer = _args.BiasInitializer,
+                ResetAfter = _args.ResetAfter,
+                Implementation = _args.Implementation
+            });
+            _cell = cell;
+        }
+
+        protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bool? training = null, IOptionalArgs? optional_args = null)
+        {
+            GRUOptionalArgs? gru_optional_args = optional_args as GRUOptionalArgs;
+            if (optional_args is not null && gru_optional_args is null)
+            {
+                throw new ArgumentException("The type of optional args should be `GRUOptionalArgs`.");
+            }
+            Tensors? mask = gru_optional_args?.Mask;
+
+            // Ragged input is not supported yet.
+            int row_length = 0;
+            bool is_ragged_input = false;
+
+            _validate_args_if_ragged(is_ragged_input, mask);
+
+            // GRU does not support constants; ignore them during processing.
+            (inputs, initial_state, _) = this._process_inputs(inputs, initial_state, null);
+
+            if (mask != null && mask.Length > 1)
+            {
+                mask = mask[0];
+            }
+
+            var input_shape = inputs.shape;
+            var timesteps = _args.TimeMajor ? input_shape[0] : input_shape[1];
+
+
+            // TODO(Wanglongzhi2001), finish _could_use_gpu_kernel part
+            Func<Tensors, Tensors, (Tensors, Tensors)> step = (cell_inputs, cell_states) =>
+            {
+                var res = Cell.Apply(cell_inputs, cell_states, training is null ? true : training.Value);
+                var (output, state) = res;
+                return (output, state);
+            };
+
+            var (last_output, outputs, states) = keras.backend.rnn(
+                step,
+                inputs,
+                initial_state,
+                constants: null,
+                go_backwards: _args.GoBackwards,
+                mask: mask,
+                unroll: _args.Unroll,
+                input_length: ops.convert_to_tensor(timesteps),
+                time_major: _args.TimeMajor,
+                zero_output_for_mask: base.Args.ZeroOutputForMask,
+                return_all_outputs: _args.ReturnSequences
+            );
+
+            Tensors output;
+            if (_args.ReturnSequences)
+            {
+                output = outputs;
+            }
+            else
+            {
+                output = last_output;
+            }
+
+            if (_args.ReturnState)
+            {
+                output = new Tensors { output, states };
+            }
+            return output;
+        }
+
+        private static IRnnCell CreateCell(GRUArgs gruArgs)
+        {
+            return new GRUCell(new GRUCellArgs
+            {
+                Units = gruArgs.Units,
+                Activation = gruArgs.Activation,
+                RecurrentActivation = gruArgs.RecurrentActivation,
+                UseBias = gruArgs.UseBias,
+                Dropout = gruArgs.Dropout,
+                RecurrentDropout = gruArgs.RecurrentDropout,
+                KernelInitializer = gruArgs.KernelInitializer,
+                RecurrentInitializer = gruArgs.RecurrentInitializer,
+                BiasInitializer = gruArgs.BiasInitializer,
+                ResetAfter = gruArgs.ResetAfter,
+                Implementation = gruArgs.Implementation
+            });
+        }
+
+        private static RNNArgs PreConstruct(GRUArgs args)
+        {
+            return new RNNArgs
+            {
+                ReturnSequences = args.ReturnSequences,
+                ReturnState = args.ReturnState,
+                GoBackwards = args.GoBackwards,
+                Stateful = args.Stateful,
+                Unroll = args.Unroll,
+                TimeMajor = args.TimeMajor,
+                Units = args.Units,
+                Activation = args.Activation,
+                RecurrentActivation = args.RecurrentActivation,
+                UseBias = args.UseBias,
+                Dropout = args.Dropout,
+                RecurrentDropout = args.RecurrentDropout,
+                KernelInitializer = args.KernelInitializer,
+                RecurrentInitializer = args.RecurrentInitializer,
+                BiasInitializer = args.BiasInitializer
+            };
+        }
+    }
+}
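`GRU.Call` above is the only consumer of `GRUOptionalArgs`. A sketch of how a padding mask would be supplied is below; whether the layer's `Apply` overloads forward an `IOptionalArgs` instance through to `Call` is an assumption here, since that plumbing is not part of this diff.

```csharp
// Hypothetical sketch: a per-timestep validity mask for padded sequences,
// where 1 marks a real step and 0 marks padding.
var inputs = tf.ones((32, 10, 8));
var mask = tf.ones((32, 10));
var gruOptionalArgs = new GRUOptionalArgs { Mask = mask };
// gruOptionalArgs would then be passed along so that GRU.Call
// receives it as `optional_args`.
```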
diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
index c1922261..fec75559 100644
--- a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
@@ -25,8 +25,8 @@ namespace Tensorflow.Keras.Layers
         private RNNArgs _args;
         private object _input_spec = null; // or NoneValue??
         private object _state_spec = null;
-        private Tensors _states = null;
         private object _constants_spec = null;
+        private Tensors _states = null;
         private int _num_constants;
         protected IVariableV1 _kernel;
         protected IVariableV1 _bias;
@@ -469,7 +469,7 @@ namespace Tensorflow.Keras.Layers
             return (inputs, initial_state, constants);
         }
 
-        private void _validate_args_if_ragged(bool is_ragged_input, Tensors mask)
+        protected void _validate_args_if_ragged(bool is_ragged_input, Tensors mask)
        {
            if (!is_ragged_input)
            {
@@ -528,44 +528,6 @@ namespace Tensorflow.Keras.Layers
             throw new NotImplementedException();
         }
 
-        // 好像不能cell不能传接口类型
-        //public RNN New(IRnnArgCell cell,
-        //    bool return_sequences = false,
-        //    bool return_state = false,
-        //    bool go_backwards = false,
-        //    bool stateful = false,
-        //    bool unroll = false,
-        //    bool time_major = false)
-        //    => new RNN(new RNNArgs
-        //    {
-        //        Cell = cell,
-        //        ReturnSequences = return_sequences,
-        //        ReturnState = return_state,
-        //        GoBackwards = go_backwards,
-        //        Stateful = stateful,
-        //        Unroll = unroll,
-        //        TimeMajor = time_major
-        //    });
-
-        //public RNN New(List<IRnnCell> cell,
-        //    bool return_sequences = false,
-        //    bool return_state = false,
-        //    bool go_backwards = false,
-        //    bool stateful = false,
-        //    bool unroll = false,
-        //    bool time_major = false)
-        //    => new RNN(new RNNArgs
-        //    {
-        //        Cell = cell,
-        //        ReturnSequences = return_sequences,
-        //        ReturnState = return_state,
-        //        GoBackwards = go_backwards,
-        //        Stateful = stateful,
-        //        Unroll = unroll,
-        //        TimeMajor = time_major
-        //    });
-
-
         protected Tensors get_initial_state(Tensors inputs)
         {
             var input = inputs[0];
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
index 03159346..dbf5cae1 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
@@ -146,6 +146,15 @@ namespace Tensorflow.Keras.UnitTest.Layers
 
         }
 
+        [TestMethod]
+        public void GRU()
+        {
+            var inputs = tf.ones((32, 10, 8));
+            var gru = tf.keras.layers.GRU(4);
+            var output = gru.Apply(inputs);
+            Assert.AreEqual((32, 4), output.shape);
+        }
+
         [TestMethod]
         public void Bidirectional()
         {
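A natural follow-up test, sketched here rather than included in the diff, would cover the sequence-returning path as well; the expected shape follows the Keras contract for `return_sequences`.

```csharp
[TestMethod]
public void GRUReturnSequences()
{
    var inputs = tf.ones((32, 10, 8));
    var gru = tf.keras.layers.GRU(4, return_sequences: true);
    var output = gru.Apply(inputs);
    Assert.AreEqual((32, 10, 4), output.shape);
}
```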