@@ -16,10 +16,6 @@ namespace Tensorflow.Common.Types
         {
             var elem = new TensorShapeConfig() { Items = new long?[] { dim } };
             Shapes = Enumerable.Repeat(elem, size).ToArray();
-            //Shapes = new TensorShapeConfig[size];
-            //Shapes.Initialize(new TensorShapeConfig() { Items = new long?[] { dim } });
-            //Array.Initialize(Shapes, new TensorShapeConfig() { Items = new long?[] { dim } });
-            ////Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } };
         }

         public GeneralizedTensorShape(Shape shape)
@@ -1,7 +1,35 @@
-namespace Tensorflow.Keras.ArgsDefinition.Rnn
+using Newtonsoft.Json;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.Keras.ArgsDefinition.Rnn
 {
     // TODO: complete the implementation
-    public class LSTMCellArgs : LayerArgs
+    public class LSTMCellArgs : AutoSerializeLayerArgs
     {
+        [JsonProperty("units")]
+        public int Units { get; set; }
+        // TODO(Rinne): Activation lacks an initialized default value. Merging keras
+        // into tf.net could resolve it.
+        [JsonProperty("activation")]
+        public Activation Activation { get; set; }
+        [JsonProperty("recurrent_activation")]
+        public Activation RecurrentActivation { get; set; }
+        [JsonProperty("use_bias")]
+        public bool UseBias { get; set; } = true;
+        [JsonProperty("dropout")]
+        public float Dropout { get; set; } = .0f;
+        [JsonProperty("recurrent_dropout")]
+        public float RecurrentDropout { get; set; } = .0f;
+        [JsonProperty("kernel_initializer")]
+        public IInitializer KernelInitializer { get; set; }
+        [JsonProperty("recurrent_initializer")]
+        public IInitializer RecurrentInitializer { get; set; }
+        [JsonProperty("bias_initializer")]
+        public IInitializer BiasInitializer { get; set; }
+        [JsonProperty("unit_forget_bias")]
+        public bool UnitForgetBias { get; set; } = true;
+        [JsonProperty("implementation")]
+        public int Implementation { get; set; } = 2;
     }
 }
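
Note: a minimal sketch of constructing these args directly (defaults as declared above; the `GetActivationFromName` wiring appears later in this diff):

    var args = new LSTMCellArgs
    {
        Units = 4,              // required; must be a positive integer
        UseBias = true,         // default
        UnitForgetBias = true,  // gives the forget gate a +1 bias at build time
        Implementation = 2,     // fused-kernel path
    };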
@@ -1,7 +1,4 @@
 using Newtonsoft.Json;
-using System;
-using System.Collections.Generic;
-using System.Text;

 namespace Tensorflow.Keras.ArgsDefinition.Rnn
 {
@@ -25,5 +22,6 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
         public IInitializer RecurrentInitializer { get; set; }
         [JsonProperty("bias_initializer")]
         public IInitializer BiasInitializer { get; set; }
     }
 }
@@ -160,6 +160,18 @@ namespace Tensorflow.Keras.Layers
         public ILayer Normalization(Shape? input_shape = null, int? axis = -1, float? mean = null, float? variance = null, bool invert = false);
         public ILayer LeakyReLU(float alpha = 0.3f);
+        public IRnnCell LSTMCell(int units,
+            string activation = "tanh",
+            string recurrent_activation = "sigmoid",
+            bool use_bias = true,
+            string kernel_initializer = "glorot_uniform",
+            string recurrent_initializer = "orthogonal",
+            string bias_initializer = "zeros",
+            bool unit_forget_bias = true,
+            float dropout = 0f,
+            float recurrent_dropout = 0f,
+            int implementation = 2);
         public ILayer LSTM(int units,
             Activation activation = null,
             Activation recurrent_activation = null,
@@ -58,8 +58,7 @@ public class Orthogonal : IInitializer
             if (num_rows < num_cols)
             {
-                // q = tf.linalg.matrix_transpose(q);
-                throw new NotImplementedException("");
+                q = array_ops.matrix_transpose(q);
             }

             return _gain * tf.reshape(q, shape);
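
Note: QR factorization yields a `q` with orthonormal columns, so when the requested matrix is wider than it is tall (num_rows < num_cols) the factorization is done on the flipped shape and `q` has to be transposed back. The previous NotImplementedException stub is replaced by the new `array_ops.matrix_transpose` helper added below.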
@@ -947,6 +947,49 @@ namespace Tensorflow
             });
         }

+        /// <summary>
+        /// Transposes the last two dimensions of tensor `a`.
+        /// For example:
+        /// <code>
+        /// x = tf.constant([[1, 2, 3], [4, 5, 6]])
+        /// tf.matrix_transpose(x)  # [[1, 4],
+        ///                         #  [2, 5],
+        ///                         #  [3, 6]]
+        /// </code>
+        /// For a matrix with two batch dimensions: if x.shape is [1, 2, 3, 4],
+        /// then tf.linalg.matrix_transpose(x) has shape [1, 2, 4, 3].
+        /// </summary>
+        /// <param name="a">A tensor of rank >= 2.</param>
+        /// <param name="name">Name for the operation.</param>
+        /// <param name="conjugate">If true, also conjugates complex elements.</param>
+        /// <returns></returns>
+        /// <exception cref="ValueError"></exception>
+        public static Tensor matrix_transpose(Tensor a, string name = "matrix_transpose", bool conjugate = false)
+        {
+            return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
+            {
+                var a_shape = a.shape;
+                var ndims = a.shape.ndim;
+                Axis perm;
+                if (ndims != 0)
+                {
+                    if (ndims < 2)
+                    {
+                        throw new ValueError("Argument `a` should be a (batch) matrix with rank " +
+                            $">= 2. Received `a` = {a} with shape: {a_shape}");
+                    }
+                    // Swap only the last two axes, e.g. ndims == 4 gives perm = [0, 1, 3, 2].
+                    perm = new Axis(Enumerable.Range(0, ndims - 2).Concat(new int[] { ndims - 1, ndims - 2 }).ToArray());
+                }
+                else
+                {
+                    // Fall back to the runtime rank when the static shape is unknown.
+                    var a_rank = a.rank;
+                    perm = new Axis(Enumerable.Range(0, a_rank - 2).Concat(new int[] { a_rank - 1, a_rank - 2 }).ToArray());
+                }
+                return transpose(a, perm: perm, conjugate: conjugate);
+            });
+        }
+
 public static Tensor[] split(Tensor value, Tensor size_splits, int axis, int num = -1,
     string name = "split")
 {
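
Note: a quick usage sketch of the new helper (shapes follow the doc-comment example above):

    var x = tf.constant(new[,] { { 1, 2, 3 }, { 4, 5, 6 } });  // shape (2, 3)
    var y = array_ops.matrix_transpose(x);                     // shape (3, 2)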
@@ -702,6 +702,7 @@ namespace Tensorflow.Keras.Layers
                 UseBias = use_bias,
                 KernelInitializer = GetInitializerByName(kernel_initializer),
                 RecurrentInitializer = GetInitializerByName(recurrent_initializer),
+                BiasInitializer = GetInitializerByName(bias_initializer),
                 Dropout = dropout,
                 RecurrentDropout = recurrent_dropout
             });
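
Note: this one-line fix matters because SimpleRNNCell previously dropped the `bias_initializer` argument passed by callers; it is now forwarded into the args like the other initializers.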
@@ -786,6 +787,33 @@ namespace Tensorflow.Keras.Layers
                 TimeMajor = time_major
             });

+        public IRnnCell LSTMCell(int units,
+            string activation = "tanh",
+            string recurrent_activation = "sigmoid",
+            bool use_bias = true,
+            string kernel_initializer = "glorot_uniform",
+            string recurrent_initializer = "orthogonal", // TODO(Wanglongzhi2001): glorot_uniform has not been implemented yet.
+            string bias_initializer = "zeros",
+            bool unit_forget_bias = true,
+            float dropout = 0f,
+            float recurrent_dropout = 0f,
+            int implementation = 2)
+            => new LSTMCell(new LSTMCellArgs
+            {
+                Units = units,
+                Activation = keras.activations.GetActivationFromName(activation),
+                RecurrentActivation = keras.activations.GetActivationFromName(recurrent_activation),
+                UseBias = use_bias,
+                KernelInitializer = GetInitializerByName(kernel_initializer),
+                RecurrentInitializer = GetInitializerByName(recurrent_initializer),
+                BiasInitializer = GetInitializerByName(bias_initializer),
+                UnitForgetBias = unit_forget_bias,
+                Dropout = dropout,
+                RecurrentDropout = recurrent_dropout,
+                Implementation = implementation
+            });

         /// <summary>
         /// Long Short-Term Memory layer - Hochreiter 1997.
         /// </summary>
@@ -32,7 +32,7 @@ namespace Tensorflow.Keras.Layers.Rnn
         }

-        public Tensors? get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
+        public Tensors? get_dropout_mask_for_cell(Tensors input, bool training, int count = 1)
         {
             if (dropout == 0f)
                 return null;
@@ -44,7 +44,7 @@ namespace Tensorflow.Keras.Layers.Rnn
         }

         // Get the recurrent dropout mask for RNN cell.
-        public Tensors? get_recurrent_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
+        public Tensors? get_recurrent_dropout_mask_for_cell(Tensors input, bool training, int count = 1)
         {
             if (dropout == 0f)
                 return null;
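
Note: the LSTMCell below requests these masks with `count: 4`, one per gate (i, f, c, o):

    // Illustrative only: with inputs of shape (batch, features) and 0 < dropout < 1,
    // this returns four mask tensors of that shape; it returns null when dropout == 0.
    var dp_mask = get_dropout_mask_for_cell(inputs, training: true, count: 4);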
@@ -1,16 +1,240 @@
+using Serilog.Core;
+using System.Diagnostics;
+using Tensorflow.Common.Types;
 using Tensorflow.Keras.ArgsDefinition.Rnn;
 using Tensorflow.Keras.Engine;
+using Tensorflow.Keras.Saving;
+using Tensorflow.Keras.Utils;

 namespace Tensorflow.Keras.Layers.Rnn
 {
-    public class LSTMCell : Layer
+    /// <summary>
+    /// Cell class for the LSTM layer.
+    /// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
+    /// for details about the usage of the RNN API.
+    /// This class processes one step within the whole time sequence input, whereas
+    /// `tf.keras.layer.LSTM` processes the whole sequence.
+    /// </summary>
+    public class LSTMCell : DropoutRNNCellMixin
     {
-        LSTMCellArgs args;
+        LSTMCellArgs _args;
+        IVariableV1 _kernel;
+        IVariableV1 _recurrent_kernel;
+        IInitializer _bias_initializer;
+        IVariableV1 _bias;
+        GeneralizedTensorShape _state_size;
+        GeneralizedTensorShape _output_size;
+
+        public override GeneralizedTensorShape StateSize => _state_size;
+        public override GeneralizedTensorShape OutputSize => _output_size;
+        public override bool IsTFRnnCell => true;
+        public override bool SupportOptionalArgs => false;
         public LSTMCell(LSTMCellArgs args)
             : base(args)
         {
-            this.args = args;
+            _args = args;
+            if (args.Units <= 0)
+            {
+                throw new ValueError(
+                    $"units must be a positive integer, got {args.Units}");
+            }
+            _args.Dropout = Math.Min(1f, Math.Max(0f, this._args.Dropout));
+            _args.RecurrentDropout = Math.Min(1f, Math.Max(0f, this._args.RecurrentDropout));
+            if (_args.RecurrentDropout != 0f && _args.Implementation != 1)
+            {
+                Debug.WriteLine("RNN `implementation=2` is not supported when `recurrent_dropout` is set. " +
+                    "Using `implementation=1`.");
+                _args.Implementation = 1;
+            }
+            // An LSTM step carries two states, h and c, each of width `units`.
+            _state_size = new GeneralizedTensorShape(_args.Units, 2);
+            _output_size = new GeneralizedTensorShape(_args.Units);
+        }
+        public override void build(KerasShapesWrapper input_shape)
+        {
+            var single_shape = input_shape.ToSingleShape();
+            var input_dim = single_shape[-1];
+            // The four gate kernels (i, f, c, o) are packed into one weight of width units * 4.
+            _kernel = add_weight("kernel", (input_dim, _args.Units * 4),
+                initializer: _args.KernelInitializer
+            );
+            _recurrent_kernel = add_weight("recurrent_kernel", (_args.Units, _args.Units * 4),
+                initializer: _args.RecurrentInitializer
+            );
+            if (_args.UseBias)
+            {
+                if (_args.UnitForgetBias)
+                {
+                    // TODO: this local initializer (bias | ones | bias * 2, giving the forget
+                    // gate a +1 bias) is not wired into add_weight yet; _args.BiasInitializer
+                    // is used below in the meantime.
+                    Tensor bias_initializer()
+                    {
+                        return keras.backend.concatenate(
+                            new Tensors(
+                                _args.BiasInitializer.Apply(new InitializerArgs(shape: (_args.Units))),
+                                tf.ones_initializer.Apply(new InitializerArgs(shape: (_args.Units))),
+                                _args.BiasInitializer.Apply(new InitializerArgs(shape: (_args.Units * 2)))), axis: 0);
+                    }
+                }
+                else
+                {
+                    _bias_initializer = _args.BiasInitializer;
+                }
+                _bias = add_weight("bias", (_args.Units * 4),
+                    initializer: _args.BiasInitializer);
+            }
+            built = true;
+        }
+        protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null)
+        {
+            var h_tm1 = states[0]; // previous memory state
+            var c_tm1 = states[1]; // previous carry state
+            var dp_mask = get_dropout_mask_for_cell(inputs, training.Value, count: 4);
+            var rec_dp_mask = get_recurrent_dropout_mask_for_cell(
+                h_tm1, training.Value, count: 4);
+
+            Tensor c;
+            Tensor o;
+            if (_args.Implementation == 1)
+            {
+                Tensor inputs_i;
+                Tensor inputs_f;
+                Tensor inputs_c;
+                Tensor inputs_o;
+                if (0f < _args.Dropout && _args.Dropout < 1f)
+                {
+                    inputs_i = inputs * dp_mask[0];
+                    inputs_f = inputs * dp_mask[1];
+                    inputs_c = inputs * dp_mask[2];
+                    inputs_o = inputs * dp_mask[3];
+                }
+                else
+                {
+                    inputs_i = inputs;
+                    inputs_f = inputs;
+                    inputs_c = inputs;
+                    inputs_o = inputs;
+                }
+                // The kernel packs the four gate kernels [i | f | c | o] along axis 1.
+                var k = tf.split(_kernel.AsTensor(), num_split: 4, axis: 1);
+                Tensor k_i = k[0], k_f = k[1], k_c = k[2], k_o = k[3];
+                var x_i = math_ops.matmul(inputs_i, k_i);
+                var x_f = math_ops.matmul(inputs_f, k_f);
+                var x_c = math_ops.matmul(inputs_c, k_c);
+                var x_o = math_ops.matmul(inputs_o, k_o);
+                if (_args.UseBias)
+                {
+                    var b = tf.split(_bias.AsTensor(), num_split: 4, axis: 0);
+                    Tensor b_i = b[0], b_f = b[1], b_c = b[2], b_o = b[3];
+                    x_i = gen_nn_ops.bias_add(x_i, b_i);
+                    x_f = gen_nn_ops.bias_add(x_f, b_f);
+                    x_c = gen_nn_ops.bias_add(x_c, b_c);
+                    x_o = gen_nn_ops.bias_add(x_o, b_o);
+                }
+
+                Tensor h_tm1_i;
+                Tensor h_tm1_f;
+                Tensor h_tm1_c;
+                Tensor h_tm1_o;
+                if (0f < _args.RecurrentDropout && _args.RecurrentDropout < 1f)
+                {
+                    h_tm1_i = h_tm1 * rec_dp_mask[0];
+                    h_tm1_f = h_tm1 * rec_dp_mask[1];
+                    h_tm1_c = h_tm1 * rec_dp_mask[2];
+                    h_tm1_o = h_tm1 * rec_dp_mask[3];
+                }
+                else
+                {
+                    h_tm1_i = h_tm1;
+                    h_tm1_f = h_tm1;
+                    h_tm1_c = h_tm1;
+                    h_tm1_o = h_tm1;
+                }
+                var x = new Tensor[] { x_i, x_f, x_c, x_o };
+                var h_tm1_array = new Tensor[] { h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o };
+                (c, o) = _compute_carry_and_output(x, h_tm1_array, c_tm1);
+            }
+            else
+            {
+                if (0f < _args.Dropout && _args.Dropout < 1f)
+                    inputs = inputs * dp_mask[0];
+                var z = math_ops.matmul(inputs, _kernel.AsTensor());
+                z += math_ops.matmul(h_tm1, _recurrent_kernel.AsTensor());
+                if (_args.UseBias)
+                {
+                    z = tf.nn.bias_add(z, _bias);
+                }
+                var z_array = tf.split(z, num_split: 4, axis: 1);
+                (c, o) = _compute_carry_and_output_fused(z_array, c_tm1);
+            }
+            var h = o * _args.Activation.Apply(c);
+            // The Tensors constructor packs every element after the first into an
+            // array, so (h, h, c) is read back as output h plus states [h, c].
+            return new Tensors(h, h, c);
+        }
+        /// <summary>
+        /// Computes carry and output using split kernels.
+        /// </summary>
+        /// <param name="x">Gate inputs x_i, x_f, x_c, x_o.</param>
+        /// <param name="h_tm1">Previous hidden state, one tensor per gate.</param>
+        /// <param name="c_tm1">Previous carry state.</param>
+        /// <returns></returns>
+        public Tensors _compute_carry_and_output(Tensor[] x, Tensor[] h_tm1, Tensor c_tm1)
+        {
+            Tensor x_i = x[0], x_f = x[1], x_c = x[2], x_o = x[3];
+            Tensor h_tm1_i = h_tm1[0], h_tm1_f = h_tm1[1], h_tm1_c = h_tm1[2],
+                h_tm1_o = h_tm1[3];
+            var _recurrent_kernel_tensor = _recurrent_kernel.AsTensor();
+            var startIndex = _recurrent_kernel_tensor.shape[0];
+            // tf.slice takes (begin, size); each gate's recurrent kernel is `units` columns wide.
+            var _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
+                new[] { 0, 0 }, new[] { startIndex, _args.Units });
+            var i = _args.RecurrentActivation.Apply(
+                x_i + math_ops.matmul(h_tm1_i, _recurrent_kernel_slice));
+            _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
+                new[] { 0, _args.Units }, new[] { startIndex, _args.Units });
+            var f = _args.RecurrentActivation.Apply(
+                x_f + math_ops.matmul(h_tm1_f, _recurrent_kernel_slice));
+            _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
+                new[] { 0, _args.Units * 2 }, new[] { startIndex, _args.Units });
+            var c = f * c_tm1 + i * _args.Activation.Apply(
+                x_c + math_ops.matmul(h_tm1_c, _recurrent_kernel_slice));
+            _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
+                new[] { 0, _args.Units * 3 }, new[] { startIndex, _args.Units });
+            var o = _args.RecurrentActivation.Apply(
+                x_o + math_ops.matmul(h_tm1_o, _recurrent_kernel_slice));
+            return new Tensors(c, o);
+        }
+        /// <summary>
+        /// Computes carry and output using fused kernels.
+        /// </summary>
+        /// <param name="z">The four fused pre-activations z0..z3.</param>
+        /// <param name="c_tm1">Previous carry state.</param>
+        /// <returns></returns>
+        public Tensors _compute_carry_and_output_fused(Tensor[] z, Tensor c_tm1)
+        {
+            Tensor z0 = z[0], z1 = z[1], z2 = z[2], z3 = z[3];
+            var i = _args.RecurrentActivation.Apply(z0);
+            var f = _args.RecurrentActivation.Apply(z1);
+            // The candidate cell state uses the main activation (tanh by default),
+            // not the recurrent (gate) activation.
+            var c = f * c_tm1 + i * _args.Activation.Apply(z2);
+            var o = _args.RecurrentActivation.Apply(z3);
+            return new Tensors(c, o);
+        }
+        public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null)
+        {
+            return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size.Value, dtype.Value);
+        }
     }
 }
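
Note: for orientation, the step computed above is the standard LSTM update (with the default activations): i = sigmoid(z0), f = sigmoid(z1), c = f * c_tm1 + i * tanh(z2), o = sigmoid(z3), h = o * tanh(c). A minimal sketch of driving the cell directly, mirroring the LSTMCell unit test added below:

    var cell = tf.keras.layers.LSTMCell(4);
    var states = new Tensors { tf.zeros((2, 4)), tf.zeros((2, 4)) };  // initial h and c
    var (output, new_states) = cell.Apply(tf.ones((2, 100)), states);
    // output.shape == (2, 4); new_states[0].shape == (2, 4)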
@@ -74,8 +74,8 @@ namespace Tensorflow.Keras.Layers.Rnn
         {
             // TODO(Rinne): check if it will have multiple tensors when not nested.
             Tensors prev_output = Nest.IsNested(states) ? new Tensors(states[0]) : states;
-            var dp_mask = get_dropout_maskcell_for_cell(inputs, training.Value);
-            var rec_dp_mask = get_recurrent_dropout_maskcell_for_cell(prev_output, training.Value);
+            var dp_mask = get_dropout_mask_for_cell(inputs, training.Value);
+            var rec_dp_mask = get_recurrent_dropout_mask_for_cell(prev_output, training.Value);

             Tensor h;
             var ranks = inputs.rank;
@@ -21,21 +21,6 @@ namespace Tensorflow.Keras.UnitTest.Layers
         [TestMethod]
         public void SimpleRNNCell()
         {
-            //var cell = tf.keras.layers.SimpleRNNCell(64, dropout: 0.5f, recurrent_dropout: 0.5f);
-            //var h0 = new Tensors { tf.zeros(new Shape(4, 64)) };
-            //var x = tf.random.normal((4, 100));
-            //var (y, h1) = cell.Apply(inputs: x, states: h0);
-            //var h2 = h1;
-            //Assert.AreEqual((4, 64), y.shape);
-            //Assert.AreEqual((4, 64), h2[0].shape);
-
-            //var model = keras.Sequential(new List<ILayer>
-            //{
-            //    keras.layers.InputLayer(input_shape: (4,100)),
-            //    keras.layers.SimpleRNNCell(64)
-            //});
-            //model.summary();
             var cell = tf.keras.layers.SimpleRNNCell(64, dropout: 0.5f, recurrent_dropout: 0.5f);
             var h0 = new Tensors { tf.zeros(new Shape(4, 64)) };
             var x = tf.random.normal((4, 100));
@@ -59,6 +44,17 @@ namespace Tensorflow.Keras.UnitTest.Layers
             Assert.AreEqual((32, 4), state[0].shape);
         }

+        [TestMethod]
+        public void LSTMCell()
+        {
+            var inputs = tf.ones((2, 100));
+            var states = new Tensors { tf.zeros((2, 4)), tf.zeros((2, 4)) };
+            var rnn = tf.keras.layers.LSTMCell(4);
+            var (output, new_states) = rnn.Apply(inputs, states);
+            Assert.AreEqual((2, 4), output.shape);
+            Assert.AreEqual((2, 4), new_states[0].shape);
+        }
+
         [TestMethod]
         public void SimpleRNN()
         {
@@ -99,6 +95,28 @@ namespace Tensorflow.Keras.UnitTest.Layers
             Assert.AreEqual((32, 5), output.shape);
         }

+        [TestMethod]
+        public void RNNForLSTMCell()
+        {
+            var inputs = tf.ones((5, 10, 8));
+            var rnn = tf.keras.layers.RNN(tf.keras.layers.LSTMCell(4));
+            var output = rnn.Apply(inputs);
+            Console.WriteLine($"output: {output}");
+            Assert.AreEqual((5, 4), output.shape);
+        }
+
+        [TestMethod]
+        public void TensorsDeconstruct()
+        {
+            // Deconstructing a Tensors splits it into the first tensor and the rest.
+            var a = tf.zeros((2, 3));
+            var b = tf.ones_like(a);
+            var c = tf.ones((3, 4));
+            var d = new Tensors { a, b, c };
+            var (A, BC) = d;
+            Console.WriteLine($"A:{A}");
+            Console.WriteLine($"BC:{BC}");
+        }
     }
 }