Browse Source

Fix bug: fix generate_zero_filled_state in graph mode

pull/1106/head
Wanglongzhi2001 2 years ago
parent
commit
b14ca69162
4 changed files with 14 additions and 17 deletions
  1. +1
    -1
      src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
  2. +1
    -1
      src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
  3. +12
    -5
      src/TensorFlowNET.Keras/Utils/RnnUtils.cs
  4. +0
    -10
      test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs

+ 1
- 1
src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs View File

@@ -578,7 +578,7 @@ namespace Tensorflow.Keras.Layers.Rnn
//}
else
{
init_state = RnnUtils.generate_zero_filled_state(batch_size, _cell.StateSize, dtype);
init_state = RnnUtils.generate_zero_filled_state(tf.convert_to_tensor(batch_size), _cell.StateSize, dtype);
}

return init_state;


+ 1
- 1
src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs View File

@@ -98,7 +98,7 @@ namespace Tensorflow.Keras.Layers.Rnn
{
prev_output = math_ops.multiply(prev_output, rec_dp_mask);
}
var tmp = _recurrent_kernel.AsTensor();
Tensor output = h + math_ops.matmul(prev_output, _recurrent_kernel.AsTensor());

if (_args.Activation != null)


+ 12
- 5
src/TensorFlowNET.Keras/Utils/RnnUtils.cs View File

@@ -5,19 +5,25 @@ using System.Text;
using Tensorflow.Common.Types;
using Tensorflow.Keras.Layers.Rnn;
using Tensorflow.Common.Extensions;
using System.Linq;

namespace Tensorflow.Keras.Utils
{
internal static class RnnUtils
{
internal static Tensors generate_zero_filled_state(long batch_size_tensor, GeneralizedTensorShape state_size, TF_DataType dtype)
internal static Tensors generate_zero_filled_state(Tensor batch_size_tensor, GeneralizedTensorShape state_size, TF_DataType dtype)
{
// create_zeros builds the zero-filled tensor for a single (unnested) state size.
// The target shape is the batch size followed by the dims of the state's single shape.
Func<GeneralizedTensorShape, Tensor> create_zeros;
create_zeros = (GeneralizedTensorShape unnested_state_size) =>
{
var flat_dims = unnested_state_size.ToSingleShape().dims;
// NOTE(review): this diff view lists removed and added lines together without +/- markers.
// The next two lines look like the pre-change version (batch size used as a plain long to
// build a Shape); the lines after them look like the replacement — confirm against the repository.
var init_state_size = new long[] { batch_size_tensor }.Concat(flat_dims).ToArray();
return array_ops.zeros(new Shape(init_state_size), dtype: dtype);
// Collect the batch-size tensor and every state dim into one list, then convert that list
// into a single tensor that is handed to zeros as the shape.
var init_state_size = new List<object> { batch_size_tensor};
foreach(var dim in flat_dims)
{
init_state_size.add(dim);
}
var init_state_size_tensor = ops.convert_to_tensor(init_state_size.ToArray());
// NOTE(review): unlike the zeros call above, this call does not pass dtype, so the dtype
// parameter is not used here — presumably zeros falls back to its default dtype; confirm
// that this is intended.
return array_ops.zeros(init_state_size_tensor);
};

// TODO(Rinne): map structure with nested tensors.
@@ -34,12 +40,13 @@ namespace Tensorflow.Keras.Utils

/// <summary>
/// Produces the zero-filled initial state for <paramref name="cell"/> by calling
/// generate_zero_filled_state with the cell's StateSize.
/// </summary>
/// <param name="cell">Cell whose StateSize is forwarded as the state size.</param>
/// <param name="inputs">Optional inputs; when not null, the batch size and dtype are taken from them.</param>
/// <param name="batch_size">Batch size used when <paramref name="inputs"/> is null.</param>
/// <param name="dtype">Data type forwarded when <paramref name="inputs"/> is null; otherwise replaced by inputs.dtype.</param>
/// <returns>The value returned by generate_zero_filled_state.</returns>
internal static Tensors generate_zero_filled_state_for_cell(IRnnCell cell, Tensors inputs, long batch_size, TF_DataType dtype)
{
// Default: wrap the caller-supplied batch size in a tensor.
Tensor batch_size_tensor = tf.convert_to_tensor(batch_size);
if (inputs != null)
{
// NOTE(review): this diff view lists removed and added lines together without +/- markers.
// The assignment to the long batch_size looks like the pre-change line and the
// batch_size_tensor assignment like its replacement — confirm against the repository.
batch_size = inputs.shape[0];
// Take the batch dimension from tf.shape(inputs) as a tensor instead of from inputs.shape.
batch_size_tensor = tf.shape(inputs)[0];
dtype = inputs.dtype;
}
// NOTE(review): two return lines appear for the same reason as above; presumably the first
// (passing the long batch_size) is the removed line and the second (passing the tensor) is
// the added one.
return generate_zero_filled_state(batch_size, cell.StateSize, dtype);
return generate_zero_filled_state(batch_size_tensor, cell.StateSize, dtype);
}

/// <summary>


+ 0
- 10
test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs View File

@@ -99,16 +99,6 @@ namespace Tensorflow.Keras.UnitTest.Layers
Assert.AreEqual((32, 5), output.shape);
}

// Scratch test: builds a Shape from Unknown, concatenates { 1, 2, 3 } onto it and prints
// the result to the console. It makes no assertions.
// NOTE(review): the header for this file in the diff reports +0 / -10, so this method is
// presumably being deleted by the commit — confirm against the repository.
[TestMethod]
public void WlzTest()
{
long[] b = { 1, 2, 3 };
Shape a = new Shape(Unknown).concatenate(b);
Console.WriteLine(a);

}


}
}

Loading…
Cancel
Save