diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs
index 594c99bb..786236e4 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs
@@ -1,7 +1,35 @@
-namespace Tensorflow.Keras.ArgsDefinition.Rnn
+using Newtonsoft.Json;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.Keras.ArgsDefinition.Rnn
 {
     // TODO: complete the implementation
-    public class LSTMCellArgs : LayerArgs
+    public class LSTMCellArgs : AutoSerializeLayerArgs
     {
+        [JsonProperty("units")]
+        public int Units { get; set; }
+        // TODO(Rinne): lack of initialized value of Activation. Merging keras
+        // into tf.net could resolve it.
+        [JsonProperty("activation")]
+        public Activation Activation { get; set; }
+        [JsonProperty("recurrent_activation")]
+        public Activation RecurrentActivation { get; set; }
+        [JsonProperty("use_bias")]
+        public bool UseBias { get; set; } = true;
+        [JsonProperty("dropout")]
+        public float Dropout { get; set; } = .0f;
+        [JsonProperty("recurrent_dropout")]
+        public float RecurrentDropout { get; set; } = .0f;
+        [JsonProperty("kernel_initializer")]
+        public IInitializer KernelInitializer { get; set; }
+        [JsonProperty("recurrent_initializer")]
+        public IInitializer RecurrentInitializer { get; set; }
+        [JsonProperty("bias_initializer")]
+        public IInitializer BiasInitializer { get; set; }
+        [JsonProperty("unit_forget_bias")]
+        public bool UnitForgetBias { get; set; } = true;
+        [JsonProperty("implementation")]
+        public int Implementation { get; set; } = 2;
+
     }
 }
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNCellArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNCellArgs.cs
index 1dfcbe9c..d21d6190 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNCellArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNCellArgs.cs
@@ -1,7 +1,4 @@
 using Newtonsoft.Json;
-using System;
-using System.Collections.Generic;
-using System.Text;
 
 namespace Tensorflow.Keras.ArgsDefinition.Rnn
 {
@@ -25,5 +22,6 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
         public IInitializer RecurrentInitializer { get; set; }
         [JsonProperty("bias_initializer")]
         public IInitializer BiasInitializer { get; set; }
+
     }
 }
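Reviewer note: since `LSTMCellArgs` (first chunk above) now derives from `AutoSerializeLayerArgs`, the `[JsonProperty]` names line up with the Keras config keys. A minimal sketch of populating the args object directly; the values below are illustrative, only the defaults come from the class itself:

```csharp
using Tensorflow.Keras.ArgsDefinition.Rnn;

var args = new LSTMCellArgs
{
    Units = 4,               // must be positive; the LSTMCell ctor throws otherwise
    Dropout = 0.1f,          // clamped to [0, 1] by the LSTMCell constructor
    RecurrentDropout = 0f,
    UnitForgetBias = true,   // forget-gate bias gets +1 at build time
    Implementation = 2       // fused path; forced back to 1 when recurrent dropout is set
};
```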
diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
index 3b223816..a19508d4 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
@@ -160,6 +160,18 @@ namespace Tensorflow.Keras.Layers
         public ILayer Normalization(Shape? input_shape = null, int? axis = -1, float? mean = null, float? variance = null, bool invert = false);
         public ILayer LeakyReLU(float alpha = 0.3f);
+        public IRnnCell LSTMCell(int units,
+            string activation = "tanh",
+            string recurrent_activation = "sigmoid",
+            bool use_bias = true,
+            string kernel_initializer = "glorot_uniform",
+            string recurrent_initializer = "orthogonal",
+            string bias_initializer = "zeros",
+            bool unit_forget_bias = true,
+            float dropout = 0f,
+            float recurrent_dropout = 0f,
+            int implementation = 2);
+
         public ILayer LSTM(int units,
             Activation activation = null,
             Activation recurrent_activation = null,
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs b/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs
index 88673bb5..ae873374 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs
@@ -58,8 +58,7 @@ public class Orthogonal : IInitializer
 
         if (num_rows < num_cols)
         {
-            // q = tf.linalg.matrix_transpose(q);
-            throw new NotImplementedException("");
+            q = array_ops.matrix_transpose(q);
         }
 
         return _gain * tf.reshape(q, shape);
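Reviewer note: the `Orthogonal` fix above replaces the `NotImplementedException` with a real transpose, so initializing a matrix that is wider than it is tall now works. That is exactly the shape the default `recurrent_initializer: "orthogonal"` produces for an LSTM recurrent kernel. A hedged sketch; the `Orthogonal` constructor defaults and its namespace are assumed from the file path, not shown in this diff:

```csharp
using Tensorflow;  // InitializerArgs; Orthogonal namespace assumed from the file path

// Wide matrix: num_rows (4) < num_cols (16), i.e. the branch fixed above.
var init = new Orthogonal();
var q = init.Apply(new InitializerArgs(shape: (4, 16)));
// Internally: QR decomposition of the transposed sample, then array_ops.matrix_transpose(q).
```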
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs
index ca9e5fae..c4ec974b 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.cs
@@ -971,6 +971,49 @@ namespace Tensorflow
             });
         }
 
+        /// <summary>
+        /// Transposes last two dimensions of tensor `a`.
+        /// For example:
+        /// <code>python
+        /// x = tf.constant([[1, 2, 3], [4, 5, 6]])
+        /// tf.matrix_transpose(x) # [[1, 4],
+        ///                        #  [2, 5],
+        ///                        #  [3, 6]]
+        ///
+        /// Matrix with two batch dimensions.
+        /// x.shape is [1, 2, 3, 4]
+        /// tf.linalg.matrix_transpose(x) is shape [1, 2, 4, 3]
+        /// </code>
+        /// </summary>
+        /// <param name="a"></param>
+        /// <param name="name"></param>
+        /// <param name="conjugate"></param>
+        /// <returns></returns>
+        /// <exception cref="ValueError"></exception>
+        public static Tensor matrix_transpose(Tensor a, string name = "matrix_transpose", bool conjugate = false)
+        {
+            return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
+            {
+                var a_shape = a.shape;
+                var ndims = a.shape.ndim;
+                Axis perm;
+                if (ndims != 0) // static rank known (mirrors `ndims is not None` in the TF Python source)
+                {
+                    if (ndims < 2)
+                    {
+                        throw new ValueError("Argument `a` should be a (batch) matrix with rank " +
+                            $">= 2. Received `a` = {a} with shape: {a_shape}");
+                    }
+                    perm = new Axis(Enumerable.Range(0, ndims - 2).Concat(new int[] { ndims - 1, ndims - 2 }).ToArray());
+                }
+                else
+                {
+                    var a_rank = a.rank;
+                    perm = new Axis(Enumerable.Range(0, a_rank - 2).Concat(new int[] { a_rank - 1, a_rank - 2 }).ToArray());
+                }
+                return transpose(a, perm: perm, conjugate: conjugate);
+            });
+        }
+
         public static Tensor[] split(Tensor value, Tensor size_splits, int axis, int num = -1,
             string name = "split")
         {
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
index dd25122d..66c3cdc1 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
@@ -702,6 +702,7 @@ namespace Tensorflow.Keras.Layers
                 UseBias = use_bias,
                 KernelInitializer = GetInitializerByName(kernel_initializer),
                 RecurrentInitializer = GetInitializerByName(recurrent_initializer),
+                BiasInitializer = GetInitializerByName(bias_initializer),
                 Dropout = dropout,
                 RecurrentDropout = recurrent_dropout
             });
@@ -786,6 +787,33 @@ namespace Tensorflow.Keras.Layers
                 TimeMajor = time_major
             });
 
+
+        public IRnnCell LSTMCell(int units,
+            string activation = "tanh",
+            string recurrent_activation = "sigmoid",
+            bool use_bias = true,
+            string kernel_initializer = "glorot_uniform",
+            string recurrent_initializer = "orthogonal", // TODO(Wanglongzhi2001): glorot_uniform has not been developed.
+            string bias_initializer = "zeros",
+            bool unit_forget_bias = true,
+            float dropout = 0f,
+            float recurrent_dropout = 0f,
+            int implementation = 2)
+            => new LSTMCell(new LSTMCellArgs
+            {
+                Units = units,
+                Activation = keras.activations.GetActivationFromName(activation),
+                RecurrentActivation = keras.activations.GetActivationFromName(recurrent_activation),
+                UseBias = use_bias,
+                KernelInitializer = GetInitializerByName(kernel_initializer),
+                RecurrentInitializer = GetInitializerByName(recurrent_initializer),
+                BiasInitializer = GetInitializerByName(bias_initializer),
+                UnitForgetBias = unit_forget_bias,
+                Dropout = dropout,
+                RecurrentDropout = recurrent_dropout,
+                Implementation = implementation
+            });
+
         /// <summary>
         /// Long Short-Term Memory layer - Hochreiter 1997.
         /// </summary>
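Reviewer note: usage sketch for the new `LayersApi.LSTMCell` factory above, mirroring the unit test added at the end of this PR (`using static Tensorflow.Binding;` and `Tensorflow.NumPy` usings assumed, as in the test file):

```csharp
var cell = tf.keras.layers.LSTMCell(4, dropout: 0.5f);
var inputs = tf.ones((2, 100));                                  // (batch, features)
var states = new Tensors { tf.zeros((2, 4)), tf.zeros((2, 4)) }; // h and c
var (output, new_states) = cell.Apply(inputs, states);
// output: (2, 4); new_states holds the updated h and c, (2, 4) each
```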
diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs b/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs
index d2669ccc..1cc36d34 100644
--- a/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs
@@ -41,7 +41,7 @@ namespace Tensorflow.Keras.Layers.Rnn
 
         }
 
-        public Tensors? get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
+        public Tensors? get_dropout_mask_for_cell(Tensors input, bool training, int count = 1)
         {
             if (dropout == 0f)
                 return null;
@@ -53,7 +53,7 @@ namespace Tensorflow.Keras.Layers.Rnn
         }
 
         // Get the recurrent dropout mask for RNN cell.
-        public Tensors? get_recurrent_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
+        public Tensors? get_recurrent_dropout_mask_for_cell(Tensors input, bool training, int count = 1)
         {
             if (dropout == 0f)
                 return null;
diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs
index a622c91a..94d98e13 100644
--- a/src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs
@@ -1,16 +1,240 @@
-using Tensorflow.Keras.ArgsDefinition.Rnn;
+using Serilog.Core;
+using System.Diagnostics;
+using Tensorflow.Common.Types;
+using Tensorflow.Keras.ArgsDefinition.Rnn;
 using Tensorflow.Keras.Engine;
+using Tensorflow.Keras.Saving;
+using Tensorflow.Keras.Utils;
 
 namespace Tensorflow.Keras.Layers.Rnn
 {
-    public class LSTMCell : Layer
+    /// <summary>
+    /// Cell class for the LSTM layer.
+    /// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
+    /// for details about the usage of RNN API.
+    /// This class processes one step within the whole time sequence input, whereas
+    /// `tf.keras.layer.LSTM` processes the whole sequence.
+    /// </summary>
+    public class LSTMCell : DropoutRNNCellMixin
     {
-        LSTMCellArgs args;
+        LSTMCellArgs _args;
+        IVariableV1 _kernel;
+        IVariableV1 _recurrent_kernel;
+        IInitializer _bias_initializer;
+        IVariableV1 _bias;
+        GeneralizedTensorShape _state_size;
+        GeneralizedTensorShape _output_size;
+        public override GeneralizedTensorShape StateSize => _state_size;
+        public override GeneralizedTensorShape OutputSize => _output_size;
+
+        public override bool IsTFRnnCell => true;
+
+        public override bool SupportOptionalArgs => false;
 
         public LSTMCell(LSTMCellArgs args)
             : base(args)
         {
-            this.args = args;
+            _args = args;
+            if (args.Units <= 0)
+            {
+                throw new ValueError(
+                    $"units must be a positive integer, got {args.Units}");
+            }
+            _args.Dropout = Math.Min(1f, Math.Max(0f, this._args.Dropout));
+            _args.RecurrentDropout = Math.Min(1f, Math.Max(0f, this._args.RecurrentDropout));
+            if (_args.RecurrentDropout != 0f && _args.Implementation != 1)
+            {
+                Debug.WriteLine("RNN `implementation=2` is not supported when `recurrent_dropout` is set. " +
+                    "Using `implementation=1`.");
+                _args.Implementation = 1;
+            }
+
+            _state_size = new GeneralizedTensorShape(_args.Units, 2);
+            _output_size = new GeneralizedTensorShape(_args.Units);
+        }
+
+        public override void build(KerasShapesWrapper input_shape)
+        {
+            var single_shape = input_shape.ToSingleShape();
+            var input_dim = single_shape[-1];
+            _kernel = add_weight("kernel", (input_dim, _args.Units * 4),
+                initializer: _args.KernelInitializer);
+
+            _recurrent_kernel = add_weight("recurrent_kernel", (_args.Units, _args.Units * 4),
+                initializer: _args.RecurrentInitializer);
+
+            if (_args.UseBias)
+            {
+                if (_args.UnitForgetBias)
+                {
+                    Tensor bias_initializer()
+                    {
+                        return keras.backend.concatenate(
+                            new Tensors(
+                                _args.BiasInitializer.Apply(new InitializerArgs(shape: (_args.Units))),
+                                tf.ones_initializer.Apply(new InitializerArgs(shape: (_args.Units))),
+                                _args.BiasInitializer.Apply(new InitializerArgs(shape: (_args.Units)))), axis: 0);
+                    }
+                }
+                else
+                {
+                    _bias_initializer = _args.BiasInitializer;
+                }
+                _bias = add_weight("bias", (_args.Units * 4),
+                    initializer: _args.BiasInitializer);
+            }
+            built = true;
+        }
+        protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null)
+        {
+            var h_tm1 = states[0]; // previous memory state
+            var c_tm1 = states[1]; // previous carry state
+
+            var dp_mask = get_dropout_mask_for_cell(inputs, training.Value, count: 4);
+            var rec_dp_mask = get_recurrent_dropout_mask_for_cell(
+                h_tm1, training.Value, count: 4);
+
+            Tensor c;
+            Tensor o;
+            if (_args.Implementation == 1)
+            {
+                Tensor inputs_i;
+                Tensor inputs_f;
+                Tensor inputs_c;
+                Tensor inputs_o;
+                if (0f < _args.Dropout && _args.Dropout < 1f)
+                {
+                    inputs_i = inputs * dp_mask[0];
+                    inputs_f = inputs * dp_mask[1];
+                    inputs_c = inputs * dp_mask[2];
+                    inputs_o = inputs * dp_mask[3];
+                }
+                else
+                {
+                    inputs_i = inputs;
+                    inputs_f = inputs;
+                    inputs_c = inputs;
+                    inputs_o = inputs;
+                }
+                var k = tf.split(_kernel.AsTensor(), num_split: 4, axis: 1);
+                Tensor k_i = k[0], k_f = k[1], k_c = k[2], k_o = k[3];
+                var x_i = math_ops.matmul(inputs_i, k_i);
+                var x_f = math_ops.matmul(inputs_f, k_f);
+                var x_c = math_ops.matmul(inputs_c, k_c);
+                var x_o = math_ops.matmul(inputs_o, k_o);
+                if (_args.UseBias)
+                {
+                    var b = tf.split(_bias.AsTensor(), num_split: 4, axis: 0);
+                    Tensor b_i = b[0], b_f = b[1], b_c = b[2], b_o = b[3];
+                    x_i = gen_nn_ops.bias_add(x_i, b_i);
+                    x_f = gen_nn_ops.bias_add(x_f, b_f);
+                    x_c = gen_nn_ops.bias_add(x_c, b_c);
+                    x_o = gen_nn_ops.bias_add(x_o, b_o);
+                }
+
+                Tensor h_tm1_i;
+                Tensor h_tm1_f;
+                Tensor h_tm1_c;
+                Tensor h_tm1_o;
+                if (0f < _args.RecurrentDropout && _args.RecurrentDropout < 1f)
+                {
+                    h_tm1_i = h_tm1 * rec_dp_mask[0];
+                    h_tm1_f = h_tm1 * rec_dp_mask[1];
+                    h_tm1_c = h_tm1 * rec_dp_mask[2];
+                    h_tm1_o = h_tm1 * rec_dp_mask[3];
+                }
+                else
+                {
+                    h_tm1_i = h_tm1;
+                    h_tm1_f = h_tm1;
+                    h_tm1_c = h_tm1;
+                    h_tm1_o = h_tm1;
+                }
+                var x = new Tensor[] { x_i, x_f, x_c, x_o };
+                var h_tm1_array = new Tensor[] { h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o };
+                (c, o) = _compute_carry_and_output(x, h_tm1_array, c_tm1);
+            }
+            else
+            {
+                if (0f < _args.Dropout && _args.Dropout < 1f)
+                    inputs = inputs * dp_mask[0];
+                var z = math_ops.matmul(inputs, _kernel.AsTensor());
+                z += math_ops.matmul(h_tm1, _recurrent_kernel.AsTensor());
+                if (_args.UseBias)
+                {
+                    z = tf.nn.bias_add(z, _bias);
+                }
+                var z_array = tf.split(z, num_split: 4, axis: 1);
+                (c, o) = _compute_carry_and_output_fused(z_array, c_tm1);
+            }
+            var h = o * _args.Activation.Apply(c);
+            // Tensors packs every element after the first into an array when it is
+            // constructed, hence the (h, h, c) layout here.
+            return new Tensors(h, h, c);
+        }
+
+        /// <summary>
+        /// Computes carry and output using split kernels.
+        /// </summary>
+        /// <param name="x"></param>
+        /// <param name="h_tm1"></param>
+        /// <param name="c_tm1"></param>
+        /// <returns></returns>
+        public Tensors _compute_carry_and_output(Tensor[] x, Tensor[] h_tm1, Tensor c_tm1)
+        {
+            Tensor x_i = x[0], x_f = x[1], x_c = x[2], x_o = x[3];
+            Tensor h_tm1_i = h_tm1[0], h_tm1_f = h_tm1[1], h_tm1_c = h_tm1[2],
+                h_tm1_o = h_tm1[3];
+
+            var _recurrent_kernel_tensor = _recurrent_kernel.AsTensor();
+            var startIndex = _recurrent_kernel_tensor.shape[0];
+            // tf.slice takes (begin, size); each gate reads one (units, units) column block.
+            var _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
+                new[] { 0, 0 }, new[] { startIndex, _args.Units });
+            var i = _args.RecurrentActivation.Apply(
+                x_i + math_ops.matmul(h_tm1_i, _recurrent_kernel_slice));
+            _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
+                new[] { 0, _args.Units }, new[] { startIndex, _args.Units });
+            var f = _args.RecurrentActivation.Apply(
+                x_f + math_ops.matmul(h_tm1_f, _recurrent_kernel_slice));
+            _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
+                new[] { 0, _args.Units * 2 }, new[] { startIndex, _args.Units });
+            var c = f * c_tm1 + i * _args.Activation.Apply(
+                x_c + math_ops.matmul(h_tm1_c, _recurrent_kernel_slice));
+            _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
+                new[] { 0, _args.Units * 3 }, new[] { startIndex, _args.Units });
+            var o = _args.RecurrentActivation.Apply(
+                x_o + math_ops.matmul(h_tm1_o, _recurrent_kernel_slice));
+
+            return new Tensors(c, o);
+        }
+
+        /// <summary>
+        /// Computes carry and output using fused kernels.
+        /// </summary>
+        /// <param name="z"></param>
+        /// <param name="c_tm1"></param>
+        /// <returns></returns>
+        public Tensors _compute_carry_and_output_fused(Tensor[] z, Tensor c_tm1)
+        {
+            Tensor z0 = z[0], z1 = z[1], z2 = z[2], z3 = z[3];
+            var i = _args.RecurrentActivation.Apply(z0);
+            var f = _args.RecurrentActivation.Apply(z1);
+            // The candidate update uses the cell activation (tanh), not the recurrent activation.
+            var c = f * c_tm1 + i * _args.Activation.Apply(z2);
+            var o = _args.RecurrentActivation.Apply(z3);
+            return new Tensors(c, o);
+        }
+
+        public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null)
+        {
+            return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size.Value, dtype.Value);
+        }
     }
+
+
 }
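Reviewer note: in `_compute_carry_and_output` the recurrent kernel has shape `(units, 4 * units)` and each gate reads one `(units, units)` column block. Because `tf.slice` takes `(begin, size)` rather than `(begin, end)`, the size argument is the same for every gate; only the begin offset advances. A small illustrative helper (the name `GateBlock` is hypothetical, not part of this PR):

```csharp
using static Tensorflow.Binding;
using Tensorflow;

// Gate g in {0: i, 1: f, 2: c, 3: o} of a (units, 4 * units) kernel.
Tensor GateBlock(Tensor kernel, int g, int units)
    => tf.slice(kernel, new[] { 0, g * units }, new[] { units, units });
```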
diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
index 3b4b9419..d318dc45 100644
--- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
@@ -74,8 +74,8 @@ namespace Tensorflow.Keras.Layers.Rnn
         {
             // TODO(Rinne): check if it will have multiple tensors when not nested.
             Tensors prev_output = Nest.IsNested(states) ? new Tensors(states[0]) : states;
-            var dp_mask = get_dropout_maskcell_for_cell(inputs, training.Value);
-            var rec_dp_mask = get_recurrent_dropout_maskcell_for_cell(prev_output, training.Value);
+            var dp_mask = get_dropout_mask_for_cell(inputs, training.Value);
+            var rec_dp_mask = get_recurrent_dropout_mask_for_cell(prev_output, training.Value);
 
             Tensor h;
             var ranks = inputs.rank;
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
index fcb9ad1d..54ea1565 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
@@ -21,21 +21,6 @@ namespace Tensorflow.Keras.UnitTest.Layers
         [TestMethod]
         public void SimpleRNNCell()
         {
-            //var cell = tf.keras.layers.SimpleRNNCell(64, dropout: 0.5f, recurrent_dropout: 0.5f);
-            //var h0 = new Tensors { tf.zeros(new Shape(4, 64)) };
-            //var x = tf.random.normal((4, 100));
-            //var (y, h1) = cell.Apply(inputs: x, states: h0);
-            //var h2 = h1;
-            //Assert.AreEqual((4, 64), y.shape);
-            //Assert.AreEqual((4, 64), h2[0].shape);
-
-            //var model = keras.Sequential(new List
-            //{
-            //    keras.layers.InputLayer(input_shape: (4,100)),
-            //    keras.layers.SimpleRNNCell(64)
-            //});
-            //model.summary();
-
             var cell = tf.keras.layers.SimpleRNNCell(64, dropout: 0.5f, recurrent_dropout: 0.5f);
             var h0 = new Tensors { tf.zeros(new Shape(4, 64)) };
             var x = tf.random.normal((4, 100));
@@ -59,6 +44,17 @@ namespace Tensorflow.Keras.UnitTest.Layers
             Assert.AreEqual((32, 4), state[0].shape);
         }
 
+        [TestMethod]
+        public void LSTMCell()
+        {
+            var inputs = tf.ones((2, 100));
+            var states = new Tensors { tf.zeros((2, 4)), tf.zeros((2, 4)) };
+            var rnn = tf.keras.layers.LSTMCell(4);
+            var (output, new_states) = rnn.Apply(inputs, states);
+            Assert.AreEqual((2, 4), output.shape);
+            Assert.AreEqual((2, 4), new_states[0].shape);
+        }
+
         [TestMethod]
         public void SimpleRNN()
         {
@@ -105,15 +101,27 @@ namespace Tensorflow.Keras.UnitTest.Layers
         }
 
         [TestMethod]
-        public void WlzTest()
+        public void RNNForLSTMCell()
         {
-            long[] b = { 1, 2, 3 };
-
-            Shape a = new Shape(Unknown).concatenate(b);
-            Console.WriteLine(a);
-
+            var inputs = tf.ones((5, 10, 8));
+            var rnn = tf.keras.layers.RNN(tf.keras.layers.LSTMCell(4));
+            var output = rnn.Apply(inputs);
+            Console.WriteLine($"output: {output}");
+            Assert.AreEqual((5, 4), output.shape);
        }
 
+        [TestMethod]
+        public void TensorsDeconstruction()
+        {
+            var a = tf.zeros((2, 3));
+            var b = tf.ones_like(a);
+            var c = tf.ones((3, 4));
+
+            var d = new Tensors { a, b, c };
+            var (A, BC) = d;
+            Console.WriteLine($"A:{A}");
+            Console.WriteLine($"BC:{BC}");
+        }
     }
 }
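Reviewer note: beyond the single-step tests above, a hedged sketch of driving the new cell manually over a few timesteps by feeding the returned states forward; eager mode and the test file's usings are assumed, and the input tensor is a stand-in for a real sequence slice:

```csharp
using static Tensorflow.Binding;

var cell = tf.keras.layers.LSTMCell(4);
var states = new Tensors { tf.zeros((2, 4)), tf.zeros((2, 4)) }; // initial h and c
for (int t = 0; t < 3; t++)
{
    var x_t = tf.ones((2, 8));                  // stand-in for the t-th input slice
    var (h, next_states) = cell.Apply(x_t, states);
    states = next_states;                       // carry h and c into the next step
}
```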