
Add new feature: add LSTMCell and test

tags/v0.110.0-LSTM-Model
Wanglongzhi2001 Yaohui Liu 2 years ago
parent commit df7d700fb1
10 changed files with 376 additions and 36 deletions:

1. src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs (+30, -2)
2. src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNCellArgs.cs (+1, -3)
3. src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs (+12, -0)
4. src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs (+1, -2)
5. src/TensorFlowNET.Core/Operations/array_ops.cs (+43, -0)
6. src/TensorFlowNET.Keras/Layers/LayersApi.cs (+28, -0)
7. src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs (+2, -2)
8. src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs (+228, -4)
9. src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs (+2, -2)
10. test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs (+29, -21)

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs (+30, -2)

@@ -1,7 +1,35 @@
-namespace Tensorflow.Keras.ArgsDefinition.Rnn
using Newtonsoft.Json;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
// TODO: complete the implementation
-public class LSTMCellArgs : LayerArgs
public class LSTMCellArgs : AutoSerializeLayerArgs
{
[JsonProperty("units")]
public int Units { get; set; }
// TODO(Rinne): lack of initialized value of Activation. Merging keras
// into tf.net could resolve it.
[JsonProperty("activation")]
public Activation Activation { get; set; }
[JsonProperty("recurrent_activation")]
public Activation RecurrentActivation { get; set; }
[JsonProperty("use_bias")]
public bool UseBias { get; set; } = true;
[JsonProperty("dropout")]
public float Dropout { get; set; } = .0f;
[JsonProperty("recurrent_dropout")]
public float RecurrentDropout { get; set; } = .0f;
[JsonProperty("kernel_initializer")]
public IInitializer KernelInitializer { get; set; }
[JsonProperty("recurrent_initializer")]
public IInitializer RecurrentInitializer { get; set; }
[JsonProperty("bias_initializer")]
public IInitializer BiasInitializer { get; set; }
[JsonProperty("unit_forget_bias")]
public bool UnitForgetBias { get; set; } = true;
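// Implementation 1 computes each gate with several smaller operations, while
// implementation 2 batches them into one large matmul plus a split, which is
// usually faster on GPU; this mirrors the Keras `implementation` argument.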
[JsonProperty("implementation")]
public int Implementation { get; set; } = 2;

}
}

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNCellArgs.cs (+1, -3)

@@ -1,7 +1,4 @@
using Newtonsoft.Json;
-using System;
-using System.Collections.Generic;
-using System.Text;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
@@ -25,5 +22,6 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
public IInitializer RecurrentInitializer { get; set; }
[JsonProperty("bias_initializer")]
public IInitializer BiasInitializer { get; set; }

}
}

src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs (+12, -0)

@@ -160,6 +160,18 @@ namespace Tensorflow.Keras.Layers
public ILayer Normalization(Shape? input_shape = null, int? axis = -1, float? mean = null, float? variance = null, bool invert = false);
public ILayer LeakyReLU(float alpha = 0.3f);

public IRnnCell LSTMCell(int units,
string activation = "tanh",
string recurrent_activation = "sigmoid",
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string recurrent_initializer = "orthogonal",
string bias_initializer = "zeros",
bool unit_forget_bias = true,
float dropout = 0f,
float recurrent_dropout = 0f,
int implementation = 2);

public ILayer LSTM(int units,
Activation activation = null,
Activation recurrent_activation = null,


src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs (+1, -2)

@@ -58,8 +58,7 @@ public class Orthogonal : IInitializer

if (num_rows < num_cols)
{
-// q = tf.linalg.matrix_transpose(q);
-throw new NotImplementedException("");
q = array_ops.matrix_transpose(q);
}

return _gain * tf.reshape(q, shape);


src/TensorFlowNET.Core/Operations/array_ops.cs (+43, -0)

@@ -971,6 +971,49 @@ namespace Tensorflow
});
}

/// <summary>
/// Transposes last two dimensions of tensor `a`.
/// For example:
/// <code>
/// x = tf.constant([[1, 2, 3], [4, 5, 6]])
/// tf.matrix_transpose(x)  # [[1, 4],
///                         #  [2, 5],
///                         #  [3, 6]]
/// </code>
/// For a matrix with two batch dimensions, e.g. x.shape = [1, 2, 3, 4],
/// tf.linalg.matrix_transpose(x) has shape [1, 2, 4, 3].
/// </summary>
/// <param name="a"></param>
/// <param name="name"></param>
/// <param name="conjugate"></param>
/// <returns></returns>
/// <exception cref="ValueError"></exception>
public static Tensor matrix_transpose(Tensor a, string name = "matrix_transpose", bool conjugate = false)
{
return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
{
var a_shape = a.shape;
var ndims = a.shape.ndim;
Axis perm;
if(ndims != 0)
{
if (ndims < 2)
{
throw new ValueError("Argument `a` should be a (batch) matrix with rank " +
$">= 2. Received `a` = {a} with shape: {a_shape}");
}
perm = new Axis(Enumerable.Range(0, ndims - 2).Concat(new int[] { ndims - 1, ndims - 2 }).ToArray());
}
else
{
var a_rank = a.rank;
perm = new Axis(Enumerable.Range(0, a_rank - 2).Concat(new int[] { a_rank - 1, a_rank - 2 }).ToArray());
}
return transpose(a, perm:perm, conjugate:conjugate);
});
}

public static Tensor[] split(Tensor value, Tensor size_splits, int axis, int num = -1,
string name = "split")
{


src/TensorFlowNET.Keras/Layers/LayersApi.cs (+28, -0)

@@ -702,6 +702,7 @@ namespace Tensorflow.Keras.Layers
UseBias = use_bias,
KernelInitializer = GetInitializerByName(kernel_initializer),
RecurrentInitializer = GetInitializerByName(recurrent_initializer),
BiasInitializer = GetInitializerByName(bias_initializer),
Dropout = dropout,
RecurrentDropout = recurrent_dropout
});
@@ -786,6 +787,33 @@ namespace Tensorflow.Keras.Layers
TimeMajor = time_major
});


public IRnnCell LSTMCell(int units,
string activation = "tanh",
string recurrent_activation = "sigmoid",
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string recurrent_initializer = "orthogonal", // TODO(Wanglongzhi2001): glorot_uniform is not implemented yet.
string bias_initializer = "zeros",
bool unit_forget_bias = true,
float dropout = 0f,
float recurrent_dropout = 0f,
int implementation = 2)
=> new LSTMCell(new LSTMCellArgs
{
Units = units,
Activation = keras.activations.GetActivationFromName(activation),
RecurrentActivation = keras.activations.GetActivationFromName(recurrent_activation),
UseBias = use_bias,
KernelInitializer = GetInitializerByName(kernel_initializer),
RecurrentInitializer = GetInitializerByName(recurrent_initializer),
BiasInitializer = GetInitializerByName(bias_initializer),
UnitForgetBias = unit_forget_bias,
Dropout = dropout,
RecurrentDropout = recurrent_dropout,
Implementation = implementation
});

/// <summary>
/// Long Short-Term Memory layer - Hochreiter 1997.
/// </summary>


src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs (+2, -2)

@@ -41,7 +41,7 @@ namespace Tensorflow.Keras.Layers.Rnn

}

-public Tensors? get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
public Tensors? get_dropout_mask_for_cell(Tensors input, bool training, int count = 1)
{
if (dropout == 0f)
return null;
@@ -53,7 +53,7 @@ namespace Tensorflow.Keras.Layers.Rnn
}

// Get the recurrent dropout mask for RNN cell.
-public Tensors? get_recurrent_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
public Tensors? get_recurrent_dropout_mask_for_cell(Tensors input, bool training, int count = 1)
{
if (recurrent_dropout == 0f)
return null;


src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs (+228, -4)

@@ -1,16 +1,240 @@
-using Tensorflow.Keras.ArgsDefinition.Rnn;
using Serilog.Core;
using System.Diagnostics;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Layers.Rnn
{
-public class LSTMCell : Layer
/// <summary>
/// Cell class for the LSTM layer.
/// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
/// for details about the usage of RNN API.
/// This class processes one step within the whole time sequence input, whereas
/// `tf.keras.layers.LSTM` processes the whole sequence.
/// </summary>
public class LSTMCell : DropoutRNNCellMixin
{
-LSTMCellArgs args;
LSTMCellArgs _args;
IVariableV1 _kernel;
IVariableV1 _recurrent_kernel;
IInitializer _bias_initializer;
IVariableV1 _bias;
GeneralizedTensorShape _state_size;
GeneralizedTensorShape _output_size;
public override GeneralizedTensorShape StateSize => _state_size;

public override GeneralizedTensorShape OutputSize => _output_size;

public override bool IsTFRnnCell => true;

public override bool SupportOptionalArgs => false;
public LSTMCell(LSTMCellArgs args)
: base(args)
{
-this.args = args;
_args = args;
if (args.Units <= 0)
{
throw new ValueError(
$"units must be a positive integer, got {args.Units}");
}
_args.Dropout = Math.Min(1f, Math.Max(0f, this._args.Dropout));
_args.RecurrentDropout = Math.Min(1f, Math.Max(0f, this._args.RecurrentDropout));
if (_args.RecurrentDropout != 0f && _args.Implementation != 1)
{
Debug.WriteLine("RNN `implementation=2` is not supported when `recurrent_dropout` is set. " +
"Using `implementation=1`.");
_args.Implementation = 1;
}

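// An LSTM cell tracks two state tensors per step, the hidden state h and the
// carry (cell) state c, each of width `units`, hence state size (units, 2).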
_state_size = new GeneralizedTensorShape(_args.Units, 2);
_output_size = new GeneralizedTensorShape(_args.Units);


}

public override void build(KerasShapesWrapper input_shape)
{
var single_shape = input_shape.ToSingleShape();
var input_dim = single_shape[-1];
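// The input kernel packs the four gate kernels (input i, forget f, cell c,
// output o) side by side, hence width units * 4; same for the recurrent kernel.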
_kernel = add_weight("kernel", (input_dim, _args.Units * 4),
initializer: _args.KernelInitializer
);

_recurrent_kernel = add_weight("recurrent_kernel", (_args.Units, _args.Units * 4),
initializer: _args.RecurrentInitializer
);

if (_args.UseBias)
{
if (_args.UnitForgetBias)
{
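// unit_forget_bias initializes the forget-gate chunk of the bias to ones
// (Jozefowicz et al., 2015) so the cell starts out remembering. The helper
// below builds that [bias; ones; bias] layout over the units * 4 bias vector
// (TODO: it is not yet passed to add_weight below).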
Tensor bias_initializer()
{
return keras.backend.concatenate(
new Tensors(
_args.BiasInitializer.Apply(new InitializerArgs(shape: (_args.Units))),
tf.ones_initializer.Apply(new InitializerArgs(shape: (_args.Units))),
_args.BiasInitializer.Apply(new InitializerArgs(shape: (_args.Units * 2)))), axis: 0);
}
}
else
{
_bias_initializer = _args.BiasInitializer;
}
_bias = add_weight("bias", (_args.Units * 4),
initializer: _args.BiasInitializer);
}
built = true;
}
protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null)
{
var h_tm1 = states[0]; // previous memory state
var c_tm1 = states[1]; // previous carry state

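// One dropout mask per gate (i, f, c, o), hence count: 4.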
var dp_mask = get_dropout_mask_for_cell(inputs, training.Value, count: 4);
var rec_dp_mask = get_recurrent_dropout_mask_for_cell(
h_tm1, training.Value, count: 4);


Tensor c;
Tensor o;
if (_args.Implementation == 1)
{
Tensor inputs_i;
Tensor inputs_f;
Tensor inputs_c;
Tensor inputs_o;
if (0f < _args.Dropout && _args.Dropout < 1f)
{
inputs_i = inputs * dp_mask[0];
inputs_f = inputs * dp_mask[1];
inputs_c = inputs * dp_mask[2];
inputs_o = inputs * dp_mask[3];
}
else
{
inputs_i = inputs;
inputs_f = inputs;
inputs_c = inputs;
inputs_o = inputs;
}
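// Split the packed kernel column-wise into the four per-gate kernels.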
var k = tf.split(_kernel.AsTensor(), num_split: 4, axis: 1);
Tensor k_i = k[0], k_f = k[1], k_c = k[2], k_o = k[3];
var x_i = math_ops.matmul(inputs_i, k_i);
var x_f = math_ops.matmul(inputs_f, k_f);
var x_c = math_ops.matmul(inputs_c, k_c);
var x_o = math_ops.matmul(inputs_o, k_o);
if(_args.UseBias)
{
var b = tf.split(_bias.AsTensor(), num_split: 4, axis: 0);
Tensor b_i = b[0], b_f = b[1], b_c = b[2], b_o = b[3];
x_i = gen_nn_ops.bias_add(x_i, b_i);
x_f = gen_nn_ops.bias_add(x_f, b_f);
x_c = gen_nn_ops.bias_add(x_c, b_c);
x_o = gen_nn_ops.bias_add(x_o, b_o);
}

Tensor h_tm1_i;
Tensor h_tm1_f;
Tensor h_tm1_c;
Tensor h_tm1_o;
if (0f < _args.RecurrentDropout && _args.RecurrentDropout < 1f)
{
h_tm1_i = h_tm1 * rec_dp_mask[0];
h_tm1_f = h_tm1 * rec_dp_mask[1];
h_tm1_c = h_tm1 * rec_dp_mask[2];
h_tm1_o = h_tm1 * rec_dp_mask[3];
}
else
{
h_tm1_i = h_tm1;
h_tm1_f = h_tm1;
h_tm1_c = h_tm1;
h_tm1_o = h_tm1;
}
var x = new Tensor[] { x_i, x_f, x_c, x_o };
var h_tm1_array = new Tensor[] { h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o };
(c, o) = _compute_carry_and_output(x, h_tm1_array, c_tm1);
}
else
{
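// Implementation 2: compute all four gate pre-activations with a single
// matmul against the packed kernels, then split into (i, f, c, o).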
if (0f < _args.Dropout && _args.Dropout < 1f)
inputs = inputs * dp_mask[0];
var z = math_ops.matmul(inputs, _kernel.AsTensor());
z += math_ops.matmul(h_tm1, _recurrent_kernel.AsTensor());
if (_args.UseBias)
{
z = tf.nn.bias_add(z, _bias);
}
var z_array = tf.split(z, num_split: 4, axis: 1);
(c, o) = _compute_carry_and_output_fused(z_array, c_tm1);
}
var h = o * _args.Activation.Apply(c);
// The Tensors constructor packs every element after the first into an array,
// so (h, h, c) is read back as output h plus the state pair (h, c).
return new Tensors(h, h, c);
}

/// <summary>
/// Computes carry and output using split kernels.
/// </summary>
/// <param name="x"></param>
/// <param name="h_tm1"></param>
/// <param name="c_tm1"></param>
/// <returns></returns>
/// <exception cref="NotImplementedException"></exception>
public Tensors _compute_carry_and_output(Tensor[] x, Tensor[] h_tm1, Tensor c_tm1)
{
Tensor x_i = x[0], x_f = x[1], x_c = x[2], x_o = x[3];
Tensor h_tm1_i = h_tm1[0], h_tm1_f = h_tm1[1], h_tm1_c = h_tm1[2],
h_tm1_o = h_tm1[3];

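// Slice the packed recurrent kernel into per-gate column blocks:
// i: [0, units), f: [units, 2*units), c: [2*units, 3*units), o: [3*units, 4*units).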
var _recurrent_kernel_tensor = _recurrent_kernel.AsTensor();
var startIndex = _recurrent_kernel_tensor.shape[0];
var endIndex = _recurrent_kernel_tensor.shape[1];
var _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
new[] { 0, 0 }, new[] { startIndex, _args.Units });
var i = _args.RecurrentActivation.Apply(
x_i + math_ops.matmul(h_tm1_i, _recurrent_kernel_slice));
_recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
new[] { 0, _args.Units }, new[] { startIndex, _args.Units * 2});
var f = _args.RecurrentActivation.Apply(
x_f + math_ops.matmul(h_tm1_f, _recurrent_kernel_slice));
_recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
new[] { 0, _args.Units * 2 }, new[] { startIndex, _args.Units * 3 });
var c = f * c_tm1 + i * _args.Activation.Apply(
x_c + math_ops.matmul(h_tm1_c, _recurrent_kernel_slice));
_recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
new[] { 0, _args.Units * 3 }, new[] { startIndex, endIndex });
var o = _args.RecurrentActivation.Apply(
x_o + math_ops.matmul(h_tm1_o, _recurrent_kernel_slice));

return new Tensors(c, o);
}

/// <summary>
/// Computes carry and output using fused kernels.
/// </summary>
/// <param name="z"></param>
/// <param name="c_tm1"></param>
/// <returns></returns>
public Tensors _compute_carry_and_output_fused(Tensor[] z, Tensor c_tm1)
{
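// z holds the fused pre-activations in gate order: z0 -> i, z1 -> f, z2 -> c, z3 -> o.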
Tensor z0 = z[0], z1 = z[1], z2 = z[2], z3 = z[3];
var i = _args.RecurrentActivation.Apply(z0);
var f = _args.RecurrentActivation.Apply(z1);
var c = f * c_tm1 + i * _args.Activation.Apply(z2);
var o = _args.RecurrentActivation.Apply(z3);
return new Tensors(c, o);
}

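// Returns zero-filled (h, c) states matching this cell's state size.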
public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null)
{
return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size.Value, dtype.Value);
}
}

}

src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs (+2, -2)

@@ -74,8 +74,8 @@ namespace Tensorflow.Keras.Layers.Rnn
{
// TODO(Rinne): check if it will have multiple tensors when not nested.
Tensors prev_output = Nest.IsNested(states) ? new Tensors(states[0]) : states;
-var dp_mask = get_dropout_maskcell_for_cell(inputs, training.Value);
-var rec_dp_mask = get_recurrent_dropout_maskcell_for_cell(prev_output, training.Value);
var dp_mask = get_dropout_mask_for_cell(inputs, training.Value);
var rec_dp_mask = get_recurrent_dropout_mask_for_cell(prev_output, training.Value);

Tensor h;
var ranks = inputs.rank;


test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs (+29, -21)

@@ -21,21 +21,6 @@ namespace Tensorflow.Keras.UnitTest.Layers
[TestMethod]
public void SimpleRNNCell()
{
-//var cell = tf.keras.layers.SimpleRNNCell(64, dropout: 0.5f, recurrent_dropout: 0.5f);
-//var h0 = new Tensors { tf.zeros(new Shape(4, 64)) };
-//var x = tf.random.normal((4, 100));
-//var (y, h1) = cell.Apply(inputs: x, states: h0);
-//var h2 = h1;
-//Assert.AreEqual((4, 64), y.shape);
-//Assert.AreEqual((4, 64), h2[0].shape);
-
-//var model = keras.Sequential(new List<ILayer>
-//{
-//    keras.layers.InputLayer(input_shape: (4,100)),
-//    keras.layers.SimpleRNNCell(64)
-//});
-//model.summary();
-
var cell = tf.keras.layers.SimpleRNNCell(64, dropout: 0.5f, recurrent_dropout: 0.5f);
var h0 = new Tensors { tf.zeros(new Shape(4, 64)) };
var x = tf.random.normal((4, 100));
@@ -59,6 +44,17 @@ namespace Tensorflow.Keras.UnitTest.Layers
Assert.AreEqual((32, 4), state[0].shape);
}

[TestMethod]
public void LSTMCell()
{
var inputs = tf.ones((2, 100));
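// An LSTM cell takes two state tensors: previous hidden state h and carry state c.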
var states = new Tensors { tf.zeros((2, 4)), tf.zeros((2, 4)) };
var rnn = tf.keras.layers.LSTMCell(4);
var (output, new_states) = rnn.Apply(inputs, states);
Assert.AreEqual((2, 4), output.shape);
Assert.AreEqual((2, 4), new_states[0].shape);
}

[TestMethod]
public void SimpleRNN()
{
@@ -105,15 +101,27 @@ namespace Tensorflow.Keras.UnitTest.Layers
}

[TestMethod]
-public void WlzTest()
public void RNNForLSTMCell()
{
-long[] b = { 1, 2, 3 };
-Shape a = new Shape(Unknown).concatenate(b);
-Console.WriteLine(a);
var inputs = tf.ones((5, 10, 8));
var rnn = tf.keras.layers.RNN(tf.keras.layers.LSTMCell(4));
var output = rnn.Apply(inputs);
Console.WriteLine($"output: {output}");
Assert.AreEqual((5, 4), output.shape);
}

[TestMethod]
public void MyTest()
{
var a = tf.zeros((2, 3));
var b = tf.ones_like(a);
var c = tf.ones((3,4));

var d = new Tensors { a, b, c };
var (A, BC) = d;
Console.WriteLine($"A:{A}");
Console.WriteLine($"BC:{BC}");
}

}
}
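
For reference, a minimal usage sketch of the API this commit adds, pieced together from the unit tests above (assumes the TensorFlow.NET `tf.keras` bindings shown in this diff; the tuple deconstruction follows the Tensors packing behavior noted in LSTMCell.Call):

using Tensorflow;
using static Tensorflow.Binding;

// Single step: an LSTM cell maps (input, [h, c]) -> (output, [h, c]).
var cell = tf.keras.layers.LSTMCell(4);
var x = tf.ones((2, 100));
var states = new Tensors { tf.zeros((2, 4)), tf.zeros((2, 4)) };
var (output, newStates) = cell.Apply(x, states);    // output shape: (2, 4)

// Whole sequence: wrap the cell in an RNN layer to iterate over the time axis.
var rnn = tf.keras.layers.RNN(tf.keras.layers.LSTMCell(4));
var seq = tf.ones((5, 10, 8));                      // (batch, time, features)
var final = rnn.Apply(seq);                         // shape: (5, 4)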
