diff --git a/src/TensorFlowNET.Core/APIs/tf.tensor.cs b/src/TensorFlowNET.Core/APIs/tf.tensor.cs
index 45aebc0c..b03168ab 100644
--- a/src/TensorFlowNET.Core/APIs/tf.tensor.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.tensor.cs
@@ -68,20 +68,27 @@ namespace Tensorflow
/// A name for the operation (optional)
/// if num_or_size_splits is a scalar returns num_or_size_splits Tensor objects;
/// if num_or_size_splits is a 1-D Tensor returns num_or_size_splits.get_shape[0] Tensor objects resulting from splitting value.
-        public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null)
+        public Tensor[] split(Tensor value, int num_split, Axis axis, string name = null)
             => array_ops.split(
                 value: value,
                 num_or_size_splits: num_split,
                 axis: axis,
                 name: name);
-        public Tensor[] split(Tensor value, int num_split, int axis, string name = null)
+        /// <summary>
+        /// Splits <paramref name="value"/> along <paramref name="axis"/> by forwarding
+        /// <paramref name="num_split"/> to <c>array_ops.split</c> as <c>num_or_size_splits</c>.
+        /// </summary>
+        /// <param name="value">The tensor to split.</param>
+        /// <param name="num_split">
+        /// Passed through unchanged as <c>num_or_size_splits</c>; presumably the size of each
+        /// piece along <paramref name="axis"/> (confirm against <c>array_ops.split</c>).
+        /// </param>
+        /// <param name="axis">The axis to split along; passed through unchanged.</param>
+        /// <param name="name">A name for the operation (optional).</param>
+        /// <returns>The tensors returned by <c>array_ops.split</c>.</returns>
+        public Tensor[] split(Tensor value, int[] num_split, Axis axis, string name = null)
             => array_ops.split(
                 value: value,
                 num_or_size_splits: num_split,
-                axis: ops.convert_to_tensor(axis),
+                axis: axis,
                 name: name);
+
public Tensor ensure_shape(Tensor x, Shape shape, string name = null)
{
return gen_ops.ensure_shape(x, shape, name);
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUCellArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUCellArgs.cs
new file mode 100644
index 00000000..75d5d021
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUCellArgs.cs
@@ -0,0 +1,39 @@
+using Newtonsoft.Json;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras.ArgsDefinition.Rnn
+{
+    /// <summary>
+    /// Serializable arguments used to construct a GRU cell.
+    /// </summary>
+    public class GRUCellArgs : AutoSerializeLayerArgs
+    {
+        [JsonProperty("units")]
+        public int Units { get; set; }
+        // TODO(Rinne): lack of initialized value of Activation. Merging keras
+        // into tf.net could resolve it.
+        [JsonProperty("activation")]
+        public Activation Activation { get; set; }
+        [JsonProperty("recurrent_activation")]
+        public Activation RecurrentActivation { get; set; }
+        [JsonProperty("use_bias")]
+        public bool UseBias { get; set; } = true;
+        [JsonProperty("dropout")]
+        public float Dropout { get; set; } = .0f;
+        [JsonProperty("recurrent_dropout")]
+        public float RecurrentDropout { get; set; } = .0f;
+        [JsonProperty("kernel_initializer")]
+        public IInitializer KernelInitializer { get; set; }
+        [JsonProperty("recurrent_initializer")]
+        public IInitializer RecurrentInitializer { get; set; }
+        [JsonProperty("bias_initializer")]
+        public IInitializer BiasInitializer { get; set; }
+        [JsonProperty("reset_after")]
+        public bool ResetAfter { get; set; } = true; // same default as the GRUCell(...) layers API
+        [JsonProperty("implementation")]
+        public int Implementation { get; set; } = 2; // GRUCell switches this to 1 when recurrent dropout is set
+    }
+
+}
diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
index a19508d4..9bc99701 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
@@ -246,6 +246,18 @@ namespace Tensorflow.Keras.Layers
bool time_major = false
);
+ /// <summary>
+ /// Creates a GRU cell.
+ /// </summary>
+ /// <param name="units">Number of units of the cell.</param>
+ /// <param name="activation">Name of the activation; defaults to "tanh".</param>
+ /// <param name="recurrent_activation">Name of the recurrent activation; defaults to "sigmoid".</param>
+ /// <param name="use_bias">Whether the cell uses a bias; defaults to true.</param>
+ /// <param name="kernel_initializer">Name of the kernel initializer; defaults to "glorot_uniform".</param>
+ /// <param name="recurrent_initializer">Name of the recurrent kernel initializer; defaults to "orthogonal".</param>
+ /// <param name="bias_initializer">Name of the bias initializer; defaults to "zeros".</param>
+ /// <param name="dropout">Dropout rate; defaults to 0.</param>
+ /// <param name="recurrent_dropout">Recurrent dropout rate; defaults to 0.</param>
+ /// <param name="reset_after">The GRU reset_after option; defaults to true.</param>
+ /// <returns>The created cell as an <see cref="IRnnCell"/>.</returns>
+ public IRnnCell GRUCell(
+ int units,
+ string activation = "tanh",
+ string recurrent_activation = "sigmoid",
+ bool use_bias = true,
+ string kernel_initializer = "glorot_uniform",
+ string recurrent_initializer = "orthogonal",
+ string bias_initializer = "zeros",
+ float dropout = 0f,
+ float recurrent_dropout = 0f,
+ bool reset_after = true);
+
public ILayer Subtract();
}
}
diff --git a/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs b/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs
index 73ccc87b..59152d9b 100644
--- a/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs
+++ b/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs
@@ -11,7 +11,7 @@ namespace Tensorflow.Keras.Callbacks;
public class EarlyStopping: ICallback
{
int _paitence;
- int _min_delta;
+ float _min_delta;
int _verbose;
int _stopped_epoch;
int _wait;
@@ -26,7 +26,7 @@ public class EarlyStopping: ICallback
CallbackParams _parameters;
public Dictionary>? history { get; set; }
// user need to pass a CallbackParams to EarlyStopping, CallbackParams at least need the model
- public EarlyStopping(CallbackParams parameters,string monitor = "val_loss", int min_delta = 0, int patience = 0,
+ public EarlyStopping(CallbackParams parameters,string monitor = "val_loss", float min_delta = 0f, int patience = 0,
int verbose = 1, string mode = "auto", float baseline = 0f, bool restore_best_weights = false,
int start_from_epoch = 0)
{
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
index 0bdcbc84..d2080337 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
@@ -873,6 +873,45 @@ namespace Tensorflow.Keras.Layers
UnitForgetBias = unit_forget_bias
});
+        /// <summary>
+        /// Cell class for the GRU layer.
+        /// </summary>
+        /// <param name="units">Stored as <c>GRUCellArgs.Units</c>.</param>
+        /// <param name="activation">Activation name, resolved with <c>keras.activations.GetActivationFromName</c>.</param>
+        /// <param name="recurrent_activation">Recurrent activation name, resolved the same way.</param>
+        /// <param name="use_bias">Stored as <c>GRUCellArgs.UseBias</c>.</param>
+        /// <param name="kernel_initializer">Kernel initializer name, resolved with <c>GetInitializerByName</c>.</param>
+        /// <param name="recurrent_initializer">Recurrent kernel initializer name, resolved the same way.</param>
+        /// <param name="bias_initializer">Bias initializer name, resolved the same way.</param>
+        /// <param name="dropout">Stored as <c>GRUCellArgs.Dropout</c>.</param>
+        /// <param name="recurrent_dropout">Stored as <c>GRUCellArgs.RecurrentDropout</c>.</param>
+        /// <param name="reset_after">Stored as <c>GRUCellArgs.ResetAfter</c>.</param>
+        /// <returns>A new <c>GRUCell</c> configured with the arguments above.</returns>
+        public IRnnCell GRUCell(
+            int units,
+            string activation = "tanh",
+            string recurrent_activation = "sigmoid",
+            bool use_bias = true,
+            string kernel_initializer = "glorot_uniform",
+            string recurrent_initializer = "orthogonal",
+            string bias_initializer = "zeros",
+            float dropout = 0f,
+            float recurrent_dropout = 0f,
+            bool reset_after = true)
+        {
+            // Names are resolved to activation / initializer objects while the argument bag is filled.
+            var cell_args = new GRUCellArgs
+            {
+                Units = units,
+                Activation = keras.activations.GetActivationFromName(activation),
+                RecurrentActivation = keras.activations.GetActivationFromName(recurrent_activation),
+                KernelInitializer = GetInitializerByName(kernel_initializer),
+                RecurrentInitializer = GetInitializerByName(recurrent_initializer),
+                BiasInitializer = GetInitializerByName(bias_initializer),
+                UseBias = use_bias,
+                Dropout = dropout,
+                RecurrentDropout = recurrent_dropout,
+                ResetAfter = reset_after
+            };
+            return new GRUCell(cell_args);
+        }
+
///
///
///
@@ -983,5 +1022,9 @@ namespace Tensorflow.Keras.Layers
Variance = variance,
Invert = invert
});
+
+
+
+
}
}
diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/GRUCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/GRUCell.cs
new file mode 100644
index 00000000..02fe54f4
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/GRUCell.cs
@@ -0,0 +1,282 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.ArgsDefinition.Rnn;
+using Tensorflow.Common.Extensions;
+using Tensorflow.Common.Types;
+using Tensorflow.Keras.Saving;
+
+namespace Tensorflow.Keras.Layers.Rnn
+{
+ ///
+ /// Cell class for the GRU layer.
+ ///
+    public class GRUCell : DropoutRNNCellMixin
+    {
+        GRUCellArgs _args;
+        // Input kernel, created in build with shape (input_dim, 3 * Units); columns are [z | r | h].
+        IVariableV1 _kernel;
+        // Recurrent kernel, created in build with shape (Units, 3 * Units); columns are [z | r | h].
+        IVariableV1 _recurrent_kernel;
+        // Bias, only created when UseBias is set: (3 * Units), or (2, 3 * Units) with ResetAfter.
+        IVariableV1 _bias;
+        INestStructure _state_size;
+        INestStructure _output_size;
+        int Units;
+        public override INestStructure StateSize => _state_size;
+
+        public override INestStructure OutputSize => _output_size;
+
+        public override bool SupportOptionalArgs => false;
+
+        /// <summary>
+        /// Creates the cell from <paramref name="args"/>. The dropout rates stored in
+        /// <paramref name="args"/> are clamped to [0, 1], and <c>Implementation</c> is switched to 1
+        /// when a recurrent dropout is requested.
+        /// </summary>
+        /// <param name="args">Cell configuration; it is kept and modified in place.</param>
+        /// <exception cref="ValueError">Thrown when <c>args.Units</c> is not positive.</exception>
+        public GRUCell(GRUCellArgs args) : base(args)
+        {
+            _args = args;
+            if (_args.Units <= 0)
+            {
+                throw new ValueError(
+                    $"units must be a positive integer, got {args.Units}");
+            }
+            _args.Dropout = Math.Min(1f, Math.Max(0f, _args.Dropout));
+            _args.RecurrentDropout = Math.Min(1f, Math.Max(0f, _args.RecurrentDropout));
+            if (_args.RecurrentDropout != 0f && _args.Implementation != 1)
+            {
+                Debug.WriteLine("RNN `implementation=2` is not supported when `recurrent_dropout` is set. " +
+                    "Using `implementation=1`.");
+                _args.Implementation = 1;
+            }
+            Units = _args.Units;
+            _state_size = new NestList(Units);
+            _output_size = new NestNode(Units);
+        }
+
+        /// <summary>
+        /// Creates the kernel, the recurrent kernel and (when <c>UseBias</c> is set) the bias.
+        /// </summary>
+        /// <param name="input_shape">Input shape; its last dimension is used as the input size.</param>
+        public override void build(KerasShapesWrapper input_shape)
+        {
+            //base.build(input_shape);
+
+            var single_shape = input_shape.ToSingleShape();
+            var input_dim = single_shape[-1];
+
+            _kernel = add_weight("kernel", (input_dim, _args.Units * 3),
+                initializer: _args.KernelInitializer
+            );
+
+            _recurrent_kernel = add_weight("recurrent_kernel", (Units, Units * 3),
+                initializer: _args.RecurrentInitializer
+            );
+            if (_args.UseBias)
+            {
+                Shape bias_shape;
+                if (!_args.ResetAfter)
+                {
+                    bias_shape = new Shape(3 * Units);
+                }
+                else
+                {
+                    // One row for the input bias and one for the recurrent bias.
+                    bias_shape = (2, 3 * Units);
+                }
+                // Use the initializer from the arguments so the configured BiasInitializer takes effect.
+                _bias = add_weight("bias", bias_shape,
+                    initializer: _args.BiasInitializer
+                );
+            }
+            built = true;
+        }
+
+        /// <summary>
+        /// Runs a single GRU step.
+        /// </summary>
+        /// <param name="inputs">Input of the current step.</param>
+        /// <param name="states">Previous state; its first (or only) tensor is used as h(t-1).</param>
+        /// <param name="training">Forwarded to the dropout mask helpers; null is treated as false.</param>
+        /// <param name="optional_args">Not read by this cell.</param>
+        /// <returns>
+        /// The new hidden state <c>h = z * h(t-1) + (1 - z) * hh</c>, once as the output and once as
+        /// the next state (wrapped in a list when <paramref name="states"/> is nested).
+        /// </returns>
+        protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null)
+        {
+            var h_tm1 = states.IsNested() ? states[0] : states.Single();
+            // `training` defaults to null; reading `.Value` would throw in that case.
+            var is_training = training ?? false;
+            var dp_mask = get_dropout_mask_for_cell(inputs, is_training, count: 3);
+            var rec_dp_mask = get_recurrent_dropout_mask_for_cell(h_tm1, is_training, count: 3);
+
+            IVariableV1 input_bias = _bias;
+            IVariableV1 recurrent_bias = _bias;
+            if (_args.UseBias)
+            {
+                if (!_args.ResetAfter)
+                {
+                    input_bias = _bias;
+                    recurrent_bias = null;
+                }
+                else
+                {
+                    // Row 0 of the bias belongs to the inputs, row 1 to the recurrent part.
+                    // NOTE(review): wrapping the rows in tf.Variable creates new variables on every
+                    // call; kept because bias_add is called with IVariableV1 here - confirm intent.
+                    var bias_rows = tf.unstack(_bias.AsTensor());
+                    input_bias = tf.Variable(bias_rows[0]);
+                    recurrent_bias = tf.Variable(bias_rows[1]);
+                }
+            }
+
+            Tensor hh;
+            Tensor z;
+            if (_args.Implementation == 1)
+            {
+                // Implementation 1: one matmul per gate so every gate can get its own dropout mask.
+                Tensor inputs_z;
+                Tensor inputs_r;
+                Tensor inputs_h;
+                if (0f < _args.Dropout && _args.Dropout < 1f)
+                {
+                    inputs_z = inputs * dp_mask[0];
+                    inputs_r = inputs * dp_mask[1];
+                    inputs_h = inputs * dp_mask[2];
+                }
+                else
+                {
+                    inputs_z = inputs.Single();
+                    inputs_r = inputs.Single();
+                    inputs_h = inputs.Single();
+                }
+
+                var kernel = _kernel.AsTensor();
+                int kernel_rows = (int)kernel.shape[0];
+                int kernel_cols = (int)kernel.shape[1];
+                var x_z = math_ops.matmul(inputs_z,
+                    tf.slice(kernel, new[] { 0, 0 }, new[] { kernel_rows, Units }));
+                // Every gate slice has to span all kernel rows (the input dimension), not Units rows.
+                var x_r = math_ops.matmul(inputs_r,
+                    tf.slice(kernel, new[] { 0, Units }, new[] { kernel_rows, Units }));
+                var x_h = math_ops.matmul(inputs_h,
+                    tf.slice(kernel, new[] { 0, Units * 2 }, new[] { kernel_rows, kernel_cols - Units * 2 }));
+
+                if (_args.UseBias)
+                {
+                    x_z = tf.nn.bias_add(
+                        x_z, tf.Variable(input_bias.AsTensor()[$":{Units}"]));
+                    x_r = tf.nn.bias_add(
+                        x_r, tf.Variable(input_bias.AsTensor()[$"{Units}:{Units * 2}"]));
+                    x_h = tf.nn.bias_add(
+                        x_h, tf.Variable(input_bias.AsTensor()[$"{Units * 2}:"]));
+                }
+
+                Tensor h_tm1_z;
+                Tensor h_tm1_r;
+                Tensor h_tm1_h;
+                if (0f < _args.RecurrentDropout && _args.RecurrentDropout < 1f)
+                {
+                    h_tm1_z = h_tm1 * rec_dp_mask[0];
+                    h_tm1_r = h_tm1 * rec_dp_mask[1];
+                    h_tm1_h = h_tm1 * rec_dp_mask[2];
+                }
+                else
+                {
+                    h_tm1_z = h_tm1;
+                    h_tm1_r = h_tm1;
+                    h_tm1_h = h_tm1;
+                }
+
+                var recurrent_kernel = _recurrent_kernel.AsTensor();
+                int recurrent_rows = (int)recurrent_kernel.shape[0];
+                int recurrent_cols = (int)recurrent_kernel.shape[1];
+                var recurrent_z = math_ops.matmul(h_tm1_z,
+                    tf.slice(recurrent_kernel, new[] { 0, 0 }, new[] { recurrent_rows, Units }));
+                var recurrent_r = math_ops.matmul(h_tm1_r,
+                    tf.slice(recurrent_kernel, new[] { 0, Units }, new[] { recurrent_rows, Units }));
+                if (_args.ResetAfter && _args.UseBias)
+                {
+                    recurrent_z = tf.nn.bias_add(
+                        recurrent_z, tf.Variable(recurrent_bias.AsTensor()[$":{Units}"]));
+                    recurrent_r = tf.nn.bias_add(
+                        recurrent_r, tf.Variable(recurrent_bias.AsTensor()[$"{Units}:{Units * 2}"]));
+                }
+                z = _args.RecurrentActivation.Apply(x_z + recurrent_z);
+                var r = _args.RecurrentActivation.Apply(x_r + recurrent_r);
+
+                // Candidate-state columns of the recurrent kernel, sized from the recurrent kernel itself.
+                var recurrent_kernel_h = tf.slice(recurrent_kernel,
+                    new[] { 0, Units * 2 }, new[] { recurrent_rows, recurrent_cols - Units * 2 });
+                Tensor recurrent_h;
+                if (_args.ResetAfter)
+                {
+                    // reset_after: the reset gate is applied after the matrix multiplication.
+                    recurrent_h = math_ops.matmul(h_tm1_h, recurrent_kernel_h);
+                    if (_args.UseBias)
+                    {
+                        recurrent_h = tf.nn.bias_add(
+                            recurrent_h, tf.Variable(recurrent_bias.AsTensor()[$"{Units * 2}:"]));
+                    }
+                    recurrent_h *= r;
+                }
+                else
+                {
+                    // Otherwise the reset gate is applied to h(t-1) before the multiplication.
+                    recurrent_h = math_ops.matmul(r * h_tm1_h, recurrent_kernel_h);
+                }
+                hh = _args.Activation.Apply(x_h + recurrent_h);
+            }
+            else
+            {
+                // Implementation 2: one fused matmul for all gates, split afterwards into [z | r | h].
+                if (0f < _args.Dropout && _args.Dropout < 1f)
+                {
+                    inputs = inputs * dp_mask[0];
+                }
+
+                var matrix_x = math_ops.matmul(inputs, _kernel.AsTensor());
+                if (_args.UseBias)
+                {
+                    matrix_x = tf.nn.bias_add(matrix_x, input_bias);
+                }
+                var matrix_x_splitted = tf.split(matrix_x, 3, axis: -1);
+                var x_z = matrix_x_splitted[0];
+                var x_r = matrix_x_splitted[1];
+                var x_h = matrix_x_splitted[2];
+
+                var recurrent_kernel = _recurrent_kernel.AsTensor();
+                int recurrent_rows = (int)recurrent_kernel.shape[0];
+                Tensor matrix_inner;
+                if (_args.ResetAfter)
+                {
+                    matrix_inner = math_ops.matmul(h_tm1, recurrent_kernel);
+                    if (_args.UseBias)
+                    {
+                        matrix_inner = tf.nn.bias_add(
+                            matrix_inner, recurrent_bias);
+                    }
+                }
+                else
+                {
+                    // Only the z and r columns can be computed before r is known.
+                    matrix_inner = math_ops.matmul(h_tm1,
+                        tf.slice(recurrent_kernel, new[] { 0, 0 }, new[] { recurrent_rows, Units * 2 }));
+                }
+
+                // Each gate reads its own piece; the last piece takes whatever columns remain.
+                var matrix_inner_splitted = tf.split(matrix_inner, new int[] { Units, Units, -1 }, axis: -1);
+                var recurrent_z = matrix_inner_splitted[0];
+                var recurrent_r = matrix_inner_splitted[1];
+                var recurrent_h = matrix_inner_splitted[2];
+
+                z = _args.RecurrentActivation.Apply(x_z + recurrent_z);
+                var r = _args.RecurrentActivation.Apply(x_r + recurrent_r);
+
+                if (_args.ResetAfter)
+                {
+                    recurrent_h = r * recurrent_h;
+                }
+                else
+                {
+                    int recurrent_cols = (int)recurrent_kernel.shape[1];
+                    recurrent_h = math_ops.matmul(r * h_tm1,
+                        tf.slice(recurrent_kernel, new[] { 0, 2 * Units }, new[] { recurrent_rows, recurrent_cols - 2 * Units }));
+                }
+                hh = _args.Activation.Apply(x_h + recurrent_h);
+            }
+
+            var h = z * h_tm1 + (1 - z) * hh;
+            if (states.IsNested())
+            {
+                var new_state = new NestList(h);
+                return new Nest(new INestStructure[] { new NestNode(h), new_state }).ToTensors();
+            }
+            else
+            {
+                return new Nest(new INestStructure[] { new NestNode(h), new NestNode(h) }).ToTensors();
+            }
+        }
+    }
+}
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
index 8eeee7a8..becdbcd6 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
@@ -132,5 +132,18 @@ namespace Tensorflow.Keras.UnitTest.Layers
Console.WriteLine($"output: {output}");
Assert.AreEqual((5, 4), output.shape);
}
+
+        /// <summary>
+        /// Wraps a GRU cell in an RNN layer and checks the output shape for the default cell and for
+        /// a cell created with reset_after and use_bias switched off.
+        /// </summary>
+        [TestMethod]
+        public void GRUCell()
+        {
+            var inputs = tf.random.normal((32, 10, 8));
+
+            // Default configuration.
+            var default_cell = tf.keras.layers.GRUCell(4);
+            var default_output = tf.keras.layers.RNN(default_cell).Apply(inputs);
+            Assert.AreEqual((32, 4), default_output.shape);
+
+            // No bias, reset gate applied before the recurrent multiplication.
+            var plain_cell = tf.keras.layers.GRUCell(4, reset_after:false, use_bias:false);
+            var plain_output = tf.keras.layers.RNN(plain_cell).Apply(inputs);
+            Assert.AreEqual((32, 4), plain_output.shape);
+        }
}
}