diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs index e60ba6fc..7f596500 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs @@ -206,7 +206,9 @@ namespace Tensorflow.Keras.Layers bool use_bias = true, string kernel_initializer = "glorot_uniform", string recurrent_initializer = "orthogonal", - string bias_initializer = "zeros"); + string bias_initializer = "zeros", + float dropout = 0f, + float recurrent_dropout = 0f); public ILayer Subtract(); } diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 3456d9b3..68d561ae 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -4633,8 +4633,9 @@ public static class gen_math_ops var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MatMul", name) { args = new object[] { a, b }, attrs = new Dictionary() { ["transpose_a"] = transpose_a, ["transpose_b"] = transpose_b } }); return _fast_path_result[0]; } - catch (Exception) + catch (ArgumentException) { + throw new ArgumentException("In[0] and In[1] have different ndims!"); } try { diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs index 02e9d995..35410337 100644 --- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs @@ -715,7 +715,9 @@ namespace Tensorflow.Keras.Layers bool use_bias = true, string kernel_initializer = "glorot_uniform", string recurrent_initializer = "orthogonal", - string bias_initializer = "zeros") + string bias_initializer = "zeros", + float dropout = 0f, + float recurrent_dropout = 0f) => new SimpleRNNCell(new SimpleRNNArgs { Units = units, @@ -723,6 +725,8 @@ namespace Tensorflow.Keras.Layers UseBias = use_bias, KernelInitializer = 
GetInitializerByName(kernel_initializer), RecurrentInitializer = GetInitializerByName(recurrent_initializer), + Dropout = dropout, + RecurrentDropout = recurrent_dropout } ); diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/DropOutRNNCellMixin.cs b/src/TensorFlowNET.Keras/Layers/Rnn/DropOutRNNCellMixin.cs index fcf9b596..b9a6fbc3 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/DropOutRNNCellMixin.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/DropOutRNNCellMixin.cs @@ -13,9 +13,10 @@ namespace Tensorflow.Keras.Layers.Rnn public float dropout; public float recurrent_dropout; // Get the dropout mask for RNN cell's input. - public Tensors get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1) + public Tensors? get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1) { - + if (dropout == 0f) + return null; return _generate_dropout_mask( tf.ones_like(input), dropout, @@ -24,8 +25,10 @@ namespace Tensorflow.Keras.Layers.Rnn } // Get the recurrent dropout mask for RNN cell. - public Tensors get_recurrent_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1) + public Tensors? get_recurrent_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1) { + if (recurrent_dropout == 0f) + return null; return _generate_dropout_mask( tf.ones_like(input), recurrent_dropout, diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs index 0bca437b..ad2e9484 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs @@ -58,10 +58,7 @@ namespace Tensorflow.Keras.Layers.Rnn protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? 
training = null, Tensors initial_state = null, Tensors constants = null) { - Console.WriteLine($"shape of input: {inputs.shape}"); Tensor states = initial_state[0]; - Console.WriteLine($"shape of initial_state: {states.shape}"); - var prev_output = nest.is_nested(states) ? states[0] : states; var dp_mask = DRCMixin.get_dropout_maskcell_for_cell(inputs, training.Value); var rec_dp_mask = DRCMixin.get_recurrent_dropout_maskcell_for_cell(prev_output, training.Value); @@ -72,11 +69,12 @@ namespace Tensorflow.Keras.Layers.Rnn { if (ranks > 2) { - h = tf.linalg.tensordot(tf.multiply(inputs, dp_mask), kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } }); + // multiply() implicitly adds a leading dimension, so index [0] to drop it + h = tf.linalg.tensordot(math_ops.multiply(inputs, dp_mask)[0], kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } }); } else { - h = math_ops.matmul(tf.multiply(inputs, dp_mask), kernel.AsTensor()); + h = math_ops.matmul(math_ops.multiply(inputs, dp_mask)[0], kernel.AsTensor()); } } else @@ -98,22 +96,18 @@ namespace Tensorflow.Keras.Layers.Rnn if (rec_dp_mask != null) { - prev_output = tf.multiply(prev_output, rec_dp_mask); + prev_output = math_ops.multiply(prev_output, rec_dp_mask)[0]; } ranks = prev_output.rank; - Console.WriteLine($"shape of h: {h.shape}"); - Tensor output; if (ranks > 2) { - var tmp = tf.linalg.tensordot(prev_output, recurrent_kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } }); - output = h + tf.linalg.tensordot(prev_output, recurrent_kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } })[0]; + output = h + tf.linalg.tensordot(prev_output[0], recurrent_kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } }); } else { - output = h + math_ops.matmul(prev_output, recurrent_kernel.AsTensor())[0]; - + output = h + math_ops.matmul(prev_output, recurrent_kernel.AsTensor()); } Console.WriteLine($"shape of output: {output.shape}"); diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs index b3d45729..c4888a39 
100644 --- a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs @@ -147,13 +147,15 @@ namespace Tensorflow.Keras.UnitTest.Layers [TestMethod] public void SimpleRNNCell() { + var cell = keras.layers.SimpleRNNCell(64, dropout:0.5f, recurrent_dropout:0.5f); var h0 = new Tensors { tf.zeros(new Shape(4, 64)) }; - var x = tf.random.normal(new Shape(4, 100)); - var cell = keras.layers.SimpleRNNCell(64); - var (y, h1) = cell.Apply(inputs:x, state:h0); + var x = tf.random.normal((4, 100)); + var (y, h1) = cell.Apply(inputs: x, state: h0); + // TODO(Wanglongzhi2001): SimpleRNNCell needs to return a Tensor plus a Tensors; a single Tensors + // cannot hold both, so explicitly cast the returned state to Tensors out here + var h2 = (Tensors)h1; Assert.AreEqual((4, 64), y.shape); - // this test now cannot pass, need to deal with SimpleRNNCell's Call method - //Assert.AreEqual((4, 64), h1[0].shape); + Assert.AreEqual((4, 64), h2[0].shape); } [TestMethod, Ignore("WIP")]