Browse Source

Finish SimpleRNNCell and add test

pull/1090/head
Wanglongzhi2001 2 years ago
parent
commit
08b4b89f77
6 changed files with 29 additions and 23 deletions
  1. +3
    -1
      src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
  2. +2
    -1
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  3. +5
    -1
      src/TensorFlowNET.Keras/Layers/LayersApi.cs
  4. +6
    -3
      src/TensorFlowNET.Keras/Layers/Rnn/DropOutRNNCellMixin.cs
  5. +6
    -12
      src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
  6. +7
    -5
      test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs

+ 3
- 1
src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs View File

@@ -206,7 +206,9 @@ namespace Tensorflow.Keras.Layers
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string recurrent_initializer = "orthogonal",
string bias_initializer = "zeros");
string bias_initializer = "zeros",
float dropout = 0f,
float recurrent_dropout = 0f);

public ILayer Subtract();
}


+ 2
- 1
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -4633,8 +4633,9 @@ public static class gen_math_ops
var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MatMul", name) { args = new object[] { a, b }, attrs = new Dictionary<string, object>() { ["transpose_a"] = transpose_a, ["transpose_b"] = transpose_b } });
return _fast_path_result[0];
}
catch (Exception)
catch (ArgumentException)
{
throw new ArgumentException("In[0] and In[1] have different ndims!");
}
try
{


+ 5
- 1
src/TensorFlowNET.Keras/Layers/LayersApi.cs View File

@@ -715,7 +715,9 @@ namespace Tensorflow.Keras.Layers
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string recurrent_initializer = "orthogonal",
string bias_initializer = "zeros")
string bias_initializer = "zeros",
float dropout = 0f,
float recurrent_dropout = 0f)
=> new SimpleRNNCell(new SimpleRNNArgs
{
Units = units,
@@ -723,6 +725,8 @@ namespace Tensorflow.Keras.Layers
UseBias = use_bias,
KernelInitializer = GetInitializerByName(kernel_initializer),
RecurrentInitializer = GetInitializerByName(recurrent_initializer),
Dropout = dropout,
RecurrentDropout = recurrent_dropout
}
);



+ 6
- 3
src/TensorFlowNET.Keras/Layers/Rnn/DropOutRNNCellMixin.cs View File

@@ -13,9 +13,10 @@ namespace Tensorflow.Keras.Layers.Rnn
public float dropout;
public float recurrent_dropout;
// Get the dropout mask for RNN cell's input.
public Tensors get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
public Tensors? get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
{

if (dropout == 0f)
return null;
return _generate_dropout_mask(
tf.ones_like(input),
dropout,
@@ -24,8 +25,10 @@ namespace Tensorflow.Keras.Layers.Rnn
}

// Get the recurrent dropout mask for RNN cell.
public Tensors get_recurrent_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
public Tensors? get_recurrent_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
{
if (recurrent_dropout == 0f)
return null;
return _generate_dropout_mask(
tf.ones_like(input),
recurrent_dropout,


+ 6
- 12
src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs View File

@@ -58,10 +58,7 @@ namespace Tensorflow.Keras.Layers.Rnn

protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
Console.WriteLine($"shape of input: {inputs.shape}");
Tensor states = initial_state[0];
Console.WriteLine($"shape of initial_state: {states.shape}");

var prev_output = nest.is_nested(states) ? states[0] : states;
var dp_mask = DRCMixin.get_dropout_maskcell_for_cell(inputs, training.Value);
var rec_dp_mask = DRCMixin.get_recurrent_dropout_maskcell_for_cell(prev_output, training.Value);
@@ -72,11 +69,12 @@ namespace Tensorflow.Keras.Layers.Rnn
{
if (ranks > 2)
{
h = tf.linalg.tensordot(tf.multiply(inputs, dp_mask), kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } });
// The multiply function automatically adds a leading dimension, so index with [0] to drop it.
h = tf.linalg.tensordot(math_ops.multiply(inputs, dp_mask)[0], kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } });
}
else
{
h = math_ops.matmul(tf.multiply(inputs, dp_mask), kernel.AsTensor());
h = math_ops.matmul(math_ops.multiply(inputs, dp_mask)[0], kernel.AsTensor());
}
}
else
@@ -98,22 +96,18 @@ namespace Tensorflow.Keras.Layers.Rnn

if (rec_dp_mask != null)
{
prev_output = tf.multiply(prev_output, rec_dp_mask);
prev_output = math_ops.multiply(prev_output, rec_dp_mask)[0];
}

ranks = prev_output.rank;
Console.WriteLine($"shape of h: {h.shape}");

Tensor output;
if (ranks > 2)
{
var tmp = tf.linalg.tensordot(prev_output, recurrent_kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } });
output = h + tf.linalg.tensordot(prev_output, recurrent_kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } })[0];
output = h + tf.linalg.tensordot(prev_output[0], recurrent_kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } });
}
else
{
output = h + math_ops.matmul(prev_output, recurrent_kernel.AsTensor())[0];

output = h + math_ops.matmul(prev_output, recurrent_kernel.AsTensor());
}
Console.WriteLine($"shape of output: {output.shape}");



+ 7
- 5
test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs View File

@@ -147,13 +147,15 @@ namespace Tensorflow.Keras.UnitTest.Layers
[TestMethod]
public void SimpleRNNCell()
{
var cell = keras.layers.SimpleRNNCell(64, dropout:0.5f, recurrent_dropout:0.5f);
var h0 = new Tensors { tf.zeros(new Shape(4, 64)) };
var x = tf.random.normal(new Shape(4, 100));
var cell = keras.layers.SimpleRNNCell(64);
var (y, h1) = cell.Apply(inputs:x, state:h0);
var x = tf.random.normal((4, 100));
var (y, h1) = cell.Apply(inputs: x, state: h0);
// TODO(Wanglongzhi2001): SimpleRNNCell needs to return one Tensor plus one Tensors, but a single
// Tensors cannot hold both, so the state h is explicitly cast to Tensors here as a workaround.
var h2 = (Tensors)h1;
Assert.AreEqual((4, 64), y.shape);
// This assertion cannot pass yet; SimpleRNNCell's Call method still needs to be fixed.
//Assert.AreEqual((4, 64), h1[0].shape);
Assert.AreEqual((4, 64), h2[0].shape);
}

[TestMethod, Ignore("WIP")]


Loading…
Cancel
Save