From b14ca691622e138aa305908d42810d98dfebf3c0 Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Wed, 14 Jun 2023 12:58:41 +0800 Subject: [PATCH] Fix bug: fix the bug of generate_zero_filled_state when in graph mode --- src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs | 2 +- .../Layers/Rnn/SimpleRNNCell.cs | 2 +- src/TensorFlowNET.Keras/Utils/RnnUtils.cs | 17 ++++++++++++----- .../Layers/Rnn.Test.cs | 10 ---------- 4 files changed, 14 insertions(+), 17 deletions(-) diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs index 0ebd7362..40b4d875 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs @@ -578,7 +578,7 @@ namespace Tensorflow.Keras.Layers.Rnn //} else { - init_state = RnnUtils.generate_zero_filled_state(batch_size, _cell.StateSize, dtype); + init_state = RnnUtils.generate_zero_filled_state(tf.convert_to_tensor(batch_size), _cell.StateSize, dtype); } return init_state; diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs index 39610ff5..165a439a 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs @@ -98,7 +98,7 @@ namespace Tensorflow.Keras.Layers.Rnn { prev_output = math_ops.multiply(prev_output, rec_dp_mask); } - var tmp = _recurrent_kernel.AsTensor(); + Tensor output = h + math_ops.matmul(prev_output, _recurrent_kernel.AsTensor()); if (_args.Activation != null) diff --git a/src/TensorFlowNET.Keras/Utils/RnnUtils.cs b/src/TensorFlowNET.Keras/Utils/RnnUtils.cs index 3109eb77..132ceba9 100644 --- a/src/TensorFlowNET.Keras/Utils/RnnUtils.cs +++ b/src/TensorFlowNET.Keras/Utils/RnnUtils.cs @@ -5,19 +5,25 @@ using System.Text; using Tensorflow.Common.Types; using Tensorflow.Keras.Layers.Rnn; using Tensorflow.Common.Extensions; +using System.Linq; namespace Tensorflow.Keras.Utils { internal static class RnnUtils { - 
internal static Tensors generate_zero_filled_state(long batch_size_tensor, GeneralizedTensorShape state_size, TF_DataType dtype) + internal static Tensors generate_zero_filled_state(Tensor batch_size_tensor, GeneralizedTensorShape state_size, TF_DataType dtype) { Func create_zeros; create_zeros = (GeneralizedTensorShape unnested_state_size) => { var flat_dims = unnested_state_size.ToSingleShape().dims; - var init_state_size = new long[] { batch_size_tensor }.Concat(flat_dims).ToArray(); - return array_ops.zeros(new Shape(init_state_size), dtype: dtype); + var init_state_size = new List { batch_size_tensor}; + foreach(var dim in flat_dims) + { + init_state_size.add(dim); + } + var init_state_size_tensor = ops.convert_to_tensor(init_state_size.ToArray()); + return array_ops.zeros(init_state_size_tensor, dtype: dtype); }; // TODO(Rinne): map structure with nested tensors. @@ -34,12 +40,13 @@ namespace Tensorflow.Keras.Utils internal static Tensors generate_zero_filled_state_for_cell(IRnnCell cell, Tensors inputs, long batch_size, TF_DataType dtype) { + Tensor batch_size_tensor = tf.convert_to_tensor(batch_size); if (inputs != null) { - batch_size = inputs.shape[0]; + batch_size_tensor = tf.shape(inputs)[0]; dtype = inputs.dtype; } - return generate_zero_filled_state(batch_size, cell.StateSize, dtype); + return generate_zero_filled_state(batch_size_tensor, cell.StateSize, dtype); } /// diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs index 28a16ad4..59819d21 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs @@ -99,16 +99,6 @@ namespace Tensorflow.Keras.UnitTest.Layers Assert.AreEqual((32, 5), output.shape); } - [TestMethod] - public void WlzTest() - { - long[] b = { 1, 2, 3 }; - - Shape a = new Shape(Unknown).concatenate(b); - Console.WriteLine(a); - - } - } }