From db8e43b241cbc86a707bab7f0da5d4a0861820ec Mon Sep 17 00:00:00 2001
From: Wanglongzhi2001 <583087864@qq.com>
Date: Mon, 12 Jun 2023 17:59:07 +0800
Subject: [PATCH] Add feature (not completed): add SimpleRNNCell, StackedRNNCell,
RNN and test
---
.../Common/Types/GeneralizedTensorShape.cs | 14 +-
.../Keras/ArgsDefinition/Rnn/RNNArgs.cs | 3 +
.../ArgsDefinition/Rnn/StackedRNNCellsArgs.cs | 3 +-
.../Keras/Layers/ILayersApi.cs | 34 ++++
.../Operations/_EagerTensorArray.cs | 14 +-
.../Operations/_GraphTensorArray.cs | 5 +-
src/TensorFlowNET.Keras/BackendImpl.cs | 27 +--
src/TensorFlowNET.Keras/Layers/LayersApi.cs | 77 +++++++++
.../Layers/Rnn/DropoutRNNCellMixin.cs | 15 ++
src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs | 76 ++++++---
.../Layers/Rnn/SimpleRNNCell.cs | 10 +-
.../Layers/Rnn/StackedRNNCells.cs | 159 +++++++++++-------
.../Callbacks/EarlystoppingTest.cs | 25 ++-
.../Layers/Rnn.Test.cs | 102 ++++++++++-
14 files changed, 445 insertions(+), 119 deletions(-)
diff --git a/src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs b/src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs
index e05d3deb..c61d04b2 100644
--- a/src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs
+++ b/src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs
@@ -12,9 +12,14 @@ namespace Tensorflow.Common.Types
/// create a single-dim generalized Tensor shape.
///
///
- public GeneralizedTensorShape(int dim)
+ public GeneralizedTensorShape(int dim, int size = 1)
{
- Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } };
+ // Build a distinct TensorShapeConfig (and Items array) per slot; Enumerable.Repeat would make every entry alias one shared instance.
+ Shapes = Enumerable.Range(0, size).Select(_ => new TensorShapeConfig() { Items = new long?[] { dim } }).ToArray();
+ //Shapes = new TensorShapeConfig[size];
+ //Shapes.Initialize(new TensorShapeConfig() { Items = new long?[] { dim } });
+ //Array.Initialize(Shapes, new TensorShapeConfig() { Items = new long?[] { dim } });
+ ////Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } };
}
public GeneralizedTensorShape(Shape shape)
@@ -113,6 +118,11 @@ namespace Tensorflow.Common.Types
return new Nest(Shapes.Select(s => DealWithSingleShape(s)));
}
}
+
+
+
+ public static implicit operator GeneralizedTensorShape(int dims)
+ => new GeneralizedTensorShape(dims);
public IEnumerator GetEnumerator()
{
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs
index ed5a1d6d..116ff7a2 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs
@@ -10,6 +10,9 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
[JsonProperty("cell")]
// TODO: the cell should be serialized with `serialize_keras_object`.
public IRnnCell Cell { get; set; } = null;
+ [JsonProperty("cells")]
+ public IList Cells { get; set; } = null;
+
[JsonProperty("return_sequences")]
public bool ReturnSequences { get; set; } = false;
[JsonProperty("return_state")]
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs
index fdfadab8..ea6f830b 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs
@@ -1,10 +1,11 @@
using System.Collections.Generic;
+using Tensorflow.Keras.Layers.Rnn;
namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
public class StackedRNNCellsArgs : LayerArgs
{
- public IList Cells { get; set; }
+ public IList Cells { get; set; }
public Dictionary Kwargs { get; set; } = null;
}
}
diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
index 6a29f9e5..3b223816 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
@@ -1,5 +1,6 @@
using System;
using Tensorflow.Framework.Models;
+using Tensorflow.Keras.Layers.Rnn;
using Tensorflow.NumPy;
using static Google.Protobuf.Reflection.FieldDescriptorProto.Types;
@@ -192,6 +193,19 @@ namespace Tensorflow.Keras.Layers
float offset = 0,
Shape input_shape = null);
+ public IRnnCell SimpleRNNCell(
+ int units,
+ string activation = "tanh",
+ bool use_bias = true,
+ string kernel_initializer = "glorot_uniform",
+ string recurrent_initializer = "orthogonal",
+ string bias_initializer = "zeros",
+ float dropout = 0f,
+ float recurrent_dropout = 0f);
+
+ public IRnnCell StackedRNNCells(
+ IEnumerable cells);
+
public ILayer SimpleRNN(int units,
string activation = "tanh",
string kernel_initializer = "glorot_uniform",
@@ -200,6 +214,26 @@ namespace Tensorflow.Keras.Layers
bool return_sequences = false,
bool return_state = false);
+ public ILayer RNN(
+ IRnnCell cell,
+ bool return_sequences = false,
+ bool return_state = false,
+ bool go_backwards = false,
+ bool stateful = false,
+ bool unroll = false,
+ bool time_major = false
+ );
+
+ public ILayer RNN(
+ IEnumerable cell,
+ bool return_sequences = false,
+ bool return_state = false,
+ bool go_backwards = false,
+ bool stateful = false,
+ bool unroll = false,
+ bool time_major = false
+ );
+
public ILayer Subtract();
}
}
diff --git a/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs b/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs
index ed65a08d..08e73fe6 100644
--- a/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs
+++ b/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs
@@ -109,7 +109,19 @@ namespace Tensorflow.Operations
return ta;
});*/
- throw new NotImplementedException("");
+ //if (indices is EagerTensor)
+ //{
+ // indices = indices as EagerTensor;
+ // indices = indices.numpy();
+ //}
+
+ //foreach (var (index, val) in zip(indices.ToArray(), array_ops.unstack(value)))
+ //{
+ // this.write(index, val);
+ //}
+ //return base;
+ //throw new NotImplementedException("");
+ return this;
}
public void _merge_element_shape(Shape shape)
diff --git a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs
index 16870e9f..dde2624a 100644
--- a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs
+++ b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs
@@ -17,6 +17,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
+using Tensorflow.Eager;
using static Tensorflow.Binding;
namespace Tensorflow.Operations
@@ -146,7 +147,9 @@ namespace Tensorflow.Operations
return ta;
});*/
- throw new NotImplementedException("");
+
+ //throw new NotImplementedException("");
+ return this;
}
public void _merge_element_shape(Shape shape)
diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs
index 14491066..1336e9af 100644
--- a/src/TensorFlowNET.Keras/BackendImpl.cs
+++ b/src/TensorFlowNET.Keras/BackendImpl.cs
@@ -510,7 +510,7 @@ namespace Tensorflow.Keras
}
}
-
+
// tf.where needs its condition tensor to be the same shape as its two
// result tensors, but in our case the condition (mask) tensor is
// (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
@@ -535,7 +535,7 @@ namespace Tensorflow.Keras
{
mask_t = tf.expand_dims(mask_t, -1);
}
- var multiples = Enumerable.Repeat(1, fixed_dim).ToArray().concat(input_t.shape.as_int_list().ToList().GetRange(fixed_dim, input_t.rank));
+ var multiples = Enumerable.Repeat(1, fixed_dim).ToArray().concat(input_t.shape.as_int_list().Skip(fixed_dim).ToArray());
return tf.tile(mask_t, multiples);
}
@@ -570,9 +570,6 @@ namespace Tensorflow.Keras
// individually. The result of this will be a tuple of lists, each of
// the item in tuple is list of the tensor with shape (batch, feature)
-
-
-
Tensors _process_single_input_t(Tensor input_t)
{
var unstaked_input_t = array_ops.unstack(input_t); // unstack for time_step dim
@@ -609,7 +606,7 @@ namespace Tensorflow.Keras
var mask_list = tf.unstack(mask);
if (go_backwards)
{
- mask_list.Reverse();
+ mask_list = mask_list.Reverse().ToArray(); // Enumerable.Reverse returns a new sequence and does not mutate; keep the reversed result
}
for (int i = 0; i < time_steps; i++)
@@ -629,9 +626,10 @@ namespace Tensorflow.Keras
}
else
{
- prev_output = successive_outputs[successive_outputs.Length - 1];
+ prev_output = successive_outputs.Last();
}
+ // output could be a tensor
output = tf.where(tiled_mask_t, output, prev_output);
var flat_states = Nest.Flatten(states).ToList();
@@ -661,13 +659,13 @@ namespace Tensorflow.Keras
}
}
- last_output = successive_outputs[successive_outputs.Length - 1];
- new_states = successive_states[successive_states.Length - 1];
+ last_output = successive_outputs.Last();
+ new_states = successive_states.Last();
outputs = tf.stack(successive_outputs);
if (zero_output_for_mask)
{
- last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output));
+ last_output = tf.where(_expand_mask(mask_list.Last(), last_output), last_output, tf.zeros_like(last_output));
outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs));
}
else // mask is null
@@ -689,8 +687,8 @@ namespace Tensorflow.Keras
successive_states = new Tensors { newStates };
}
}
- last_output = successive_outputs[successive_outputs.Length - 1];
- new_states = successive_states[successive_states.Length - 1];
+ last_output = successive_outputs.Last();
+ new_states = successive_states.Last();
outputs = tf.stack(successive_outputs);
}
}
@@ -701,6 +699,8 @@ namespace Tensorflow.Keras
// Create input tensor array, if the inputs is nested tensors, then it
// will be flattened first, and tensor array will be created one per
// flattened tensor.
+
+
var input_ta = new List();
for (int i = 0; i < flatted_inptus.Count; i++)
{
@@ -719,6 +719,7 @@ namespace Tensorflow.Keras
}
}
+
// Get the time(0) input and compute the output for that, the output will
// be used to determine the dtype of output tensor array. Don't read from
// input_ta due to TensorArray clear_after_read default to True.
@@ -773,7 +774,7 @@ namespace Tensorflow.Keras
return res;
};
}
- // TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor)?
+ // TODO(Wanglongzhi2001): settle the type of input_length; it could be either an integer or a single tensor.
else if (input_length is Tensor)
{
if (go_backwards)
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
index 3b095bc2..dd25122d 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
@@ -685,6 +685,34 @@ namespace Tensorflow.Keras.Layers
Alpha = alpha
});
+
+ public IRnnCell SimpleRNNCell(
+ int units,
+ string activation = "tanh",
+ bool use_bias = true,
+ string kernel_initializer = "glorot_uniform",
+ string recurrent_initializer = "orthogonal",
+ string bias_initializer = "zeros",
+ float dropout = 0f,
+ float recurrent_dropout = 0f)
+ => new SimpleRNNCell(new SimpleRNNCellArgs
+ {
+ Units = units,
+ Activation = keras.activations.GetActivationFromName(activation),
+ UseBias = use_bias,
+ KernelInitializer = GetInitializerByName(kernel_initializer),
+ RecurrentInitializer = GetInitializerByName(recurrent_initializer),
+ Dropout = dropout,
+ RecurrentDropout = recurrent_dropout
+ });
+
+ public IRnnCell StackedRNNCells(
+ IEnumerable cells)
+ => new StackedRNNCells(new StackedRNNCellsArgs
+ {
+ Cells = cells.ToList()
+ });
+
///
///
///
@@ -709,6 +737,55 @@ namespace Tensorflow.Keras.Layers
ReturnState = return_state
});
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public ILayer RNN(
+ IRnnCell cell,
+ bool return_sequences = false,
+ bool return_state = false,
+ bool go_backwards = false,
+ bool stateful = false,
+ bool unroll = false,
+ bool time_major = false)
+ => new RNN(new RNNArgs
+ {
+ Cell = cell,
+ ReturnSequences = return_sequences,
+ ReturnState = return_state,
+ GoBackwards = go_backwards,
+ Stateful = stateful,
+ Unroll = unroll,
+ TimeMajor = time_major
+ });
+
+ public ILayer RNN(
+ IEnumerable cell,
+ bool return_sequences = false,
+ bool return_state = false,
+ bool go_backwards = false,
+ bool stateful = false,
+ bool unroll = false,
+ bool time_major = false)
+ => new RNN(new RNNArgs
+ {
+ Cells = cell.ToList(),
+ ReturnSequences = return_sequences,
+ ReturnState = return_state,
+ GoBackwards = go_backwards,
+ Stateful = stateful,
+ Unroll = unroll,
+ TimeMajor = time_major
+ });
+
///
/// Long Short-Term Memory layer - Hochreiter 1997.
///
diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs b/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs
index 21396853..78d3dac9 100644
--- a/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs
@@ -17,6 +17,21 @@ namespace Tensorflow.Keras.Layers.Rnn
}
+ protected void _create_non_trackable_mask_cache()
+ {
+
+ }
+
+ public void reset_dropout_mask()
+ {
+
+ }
+
+ public void reset_recurrent_dropout_mask()
+ {
+
+ }
+
public Tensors? get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
{
if (dropout == 0f)
diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
index ab4cef12..0ebd7362 100644
--- a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
@@ -38,7 +38,17 @@ namespace Tensorflow.Keras.Layers.Rnn
SupportsMasking = true;
// if is StackedRnncell
- _cell = args.Cell;
+ if (args.Cells != null)
+ {
+ _cell = new StackedRNNCells(new StackedRNNCellsArgs
+ {
+ Cells = args.Cells
+ });
+ }
+ else
+ {
+ _cell = args.Cell;
+ }
// get input_shape
_args = PreConstruct(args);
@@ -122,6 +132,8 @@ namespace Tensorflow.Keras.Layers.Rnn
var state_shape = new int[] { (int)batch }.concat(flat_state.as_int_list());
return new Shape(state_shape);
};
+
+
var state_shape = _get_state_shape(state_size);
return new List { output_shape, state_shape };
@@ -240,7 +252,7 @@ namespace Tensorflow.Keras.Layers.Rnn
if (_cell is StackedRNNCells)
{
var stack_cell = _cell as StackedRNNCells;
- foreach (var cell in stack_cell.Cells)
+ foreach (IRnnCell cell in stack_cell.Cells)
{
_maybe_reset_cell_dropout_mask(cell);
}
@@ -253,7 +265,7 @@ namespace Tensorflow.Keras.Layers.Rnn
}
Shape input_shape;
- if (!inputs.IsSingle())
+ if (!inputs.IsNested())
{
// In the case of nested input, use the first element for shape check
// input_shape = nest.flatten(inputs)[0].shape;
@@ -267,7 +279,7 @@ namespace Tensorflow.Keras.Layers.Rnn
var timesteps = _args.TimeMajor ? input_shape[0] : input_shape[1];
- if (_args.Unroll && timesteps != null)
+ if (_args.Unroll && timesteps == null)
{
throw new ValueError(
"Cannot unroll a RNN if the " +
@@ -302,7 +314,6 @@ namespace Tensorflow.Keras.Layers.Rnn
states = new Tensors(states.SkipLast(_num_constants));
states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states;
var (output, new_states) = _cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants });
- // TODO(Wanglongzhi2001),should cell_call_fn's return value be Tensors, Tensors?
return (output, new_states.Single);
};
}
@@ -310,13 +321,14 @@ namespace Tensorflow.Keras.Layers.Rnn
{
step = (inputs, states) =>
{
- states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states;
+ states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states.First()) : states;
var (output, new_states) = _cell.Apply(inputs, states);
- return (output, new_states.Single);
+ return (output, new_states);
};
}
-
- var (last_output, outputs, states) = keras.backend.rnn(step,
+
+ var (last_output, outputs, states) = keras.backend.rnn(
+ step,
inputs,
initial_state,
constants: constants,
@@ -394,6 +406,7 @@ namespace Tensorflow.Keras.Layers.Rnn
initial_state = null;
inputs = inputs[0];
}
+
if (_args.Stateful)
{
@@ -402,7 +415,7 @@ namespace Tensorflow.Keras.Layers.Rnn
var tmp = new Tensor[] { };
foreach (var s in nest.flatten(States))
{
- tmp.add(tf.math.count_nonzero((Tensor)s));
+ tmp.add(tf.math.count_nonzero(s.Single()));
}
var non_zero_count = tf.add_n(tmp);
//initial_state = tf.cond(non_zero_count > 0, () => States, () => initial_state);
@@ -415,6 +428,15 @@ namespace Tensorflow.Keras.Layers.Rnn
{
initial_state = States;
}
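+ // TODO(Wanglongzhi2001): port the Python snippet below, which casts initial_state to the layer's (or the cell's) compute dtype.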
+// initial_state = tf.nest.map_structure(
+//# When the layer has a inferred dtype, use the dtype from the
+//# cell.
+// lambda v: tf.cast(
+// v, self.compute_dtype or self.cell.compute_dtype
+// ),
+// initial_state,
+// )
}
else if (initial_state is null)
@@ -424,10 +446,9 @@ namespace Tensorflow.Keras.Layers.Rnn
if (initial_state.Length != States.Length)
{
- throw new ValueError(
- $"Layer {this} expects {States.Length} state(s), " +
- $"but it received {initial_state.Length} " +
- $"initial state(s). Input received: {inputs}");
+ throw new ValueError($"Layer {this} expects {States.Length} state(s), " +
+ $"but it received {initial_state.Length} " +
+ $"initial state(s). Input received: {inputs}");
}
return (inputs, initial_state, constants);
@@ -458,11 +479,11 @@ namespace Tensorflow.Keras.Layers.Rnn
void _maybe_reset_cell_dropout_mask(ILayer cell)
{
- //if (cell is DropoutRNNCellMixin)
- //{
- // cell.reset_dropout_mask();
- // cell.reset_recurrent_dropout_mask();
- //}
+ if (cell is DropoutRNNCellMixin CellDRCMixin)
+ {
+ CellDRCMixin.reset_dropout_mask();
+ CellDRCMixin.reset_recurrent_dropout_mask();
+ }
}
private static RNNArgs PreConstruct(RNNArgs args)
@@ -537,15 +558,24 @@ namespace Tensorflow.Keras.Layers.Rnn
protected Tensors get_initial_state(Tensors inputs)
{
+ var get_initial_state_fn = _cell.GetType().GetMethod("get_initial_state");
+
var input = inputs[0];
- var input_shape = input.shape;
+ var input_shape = inputs.shape;
var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0];
var dtype = input.dtype;
- Tensors init_state;
- if (_cell is RnnCellBase rnn_base_cell)
+
+ Tensors init_state = new Tensors();
+
+ if(get_initial_state_fn != null)
{
- init_state = rnn_base_cell.GetInitialState(null, batch_size, dtype);
+ init_state = (Tensors)get_initial_state_fn.Invoke(_cell, new object[] { inputs, batch_size, dtype });
+
}
+ //if (_cell is RnnCellBase rnn_base_cell)
+ //{
+ // init_state = rnn_base_cell.GetInitialState(null, batch_size, dtype);
+ //}
else
{
init_state = RnnUtils.generate_zero_filled_state(batch_size, _cell.StateSize, dtype);
diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
index f0b2ed4d..39610ff5 100644
--- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
@@ -6,6 +6,7 @@ using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Common.Types;
using Tensorflow.Common.Extensions;
+using Tensorflow.Keras.Utils;
namespace Tensorflow.Keras.Layers.Rnn
{
@@ -77,8 +78,10 @@ namespace Tensorflow.Keras.Layers.Rnn
var rec_dp_mask = get_recurrent_dropout_maskcell_for_cell(prev_output, training.Value);
Tensor h;
+ var ranks = inputs.rank;
if (dp_mask != null)
{
+
h = math_ops.matmul(math_ops.multiply(inputs.Single, dp_mask.Single), _kernel.AsTensor());
}
else
@@ -95,7 +98,7 @@ namespace Tensorflow.Keras.Layers.Rnn
{
prev_output = math_ops.multiply(prev_output, rec_dp_mask);
}
-
+ var tmp = _recurrent_kernel.AsTensor();
Tensor output = h + math_ops.matmul(prev_output, _recurrent_kernel.AsTensor());
if (_args.Activation != null)
@@ -113,5 +116,10 @@ namespace Tensorflow.Keras.Layers.Rnn
return new Tensors(output, output);
}
}
+
+ public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null)
+ {
+ return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size.Value, dtype.Value);
+ }
}
}
diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs
index 0b92fd3c..56634853 100644
--- a/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs
@@ -1,17 +1,20 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
+using System.Linq;
+using Tensorflow.Common.Extensions;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
+using Tensorflow.Keras.Utils;
namespace Tensorflow.Keras.Layers.Rnn
{
public class StackedRNNCells : Layer, IRnnCell
{
- public IList Cells { get; set; }
+ public IList Cells { get; set; }
public bool reverse_state_order;
public StackedRNNCells(StackedRNNCellsArgs args) : base(args)
@@ -20,8 +23,19 @@ namespace Tensorflow.Keras.Layers.Rnn
{
args.Kwargs = new Dictionary();
}
-
+ foreach (var cell in args.Cells)
+ {
+ //Type type = cell.GetType();
+ //var CallMethodInfo = type.GetMethod("Call");
+ //if (CallMethodInfo == null)
+ //{
+ // throw new ValueError(
+ // "All cells must have a `Call` method. " +
+ // $"Received cell without a `Call` method: {cell}");
+ //}
+ }
Cells = args.Cells;
+
reverse_state_order = (bool)args.Kwargs.Get("reverse_state_order", false);
if (reverse_state_order)
@@ -33,91 +47,112 @@ namespace Tensorflow.Keras.Layers.Rnn
}
}
- public object state_size
+ public GeneralizedTensorShape StateSize
{
- get => throw new NotImplementedException();
- //@property
- //def state_size(self) :
- // return tuple(c.state_size for c in
- // (self.cells[::- 1] if self.reverse_state_order else self.cells))
+ get
+ {
+ GeneralizedTensorShape state_size = new GeneralizedTensorShape(1, Cells.Count);
+ if (reverse_state_order && Cells.Count > 0)
+ {
+ var idxAndCell = Cells.Reverse().Select((cell, idx) => (idx, cell));
+ foreach (var cell in idxAndCell)
+ {
+ state_size.Shapes[cell.idx] = cell.cell.StateSize.Shapes.First();
+ }
+ }
+ else
+ {
+ //foreach (var cell in Cells)
+ //{
+ // state_size.Shapes.add(cell.StateSize.Shapes.First());
+
+ //}
+ var idxAndCell = Cells.Select((cell, idx) => (idx, cell));
+ foreach (var cell in idxAndCell)
+ {
+ state_size.Shapes[cell.idx] = cell.cell.StateSize.Shapes.First();
+ }
+ }
+ return state_size;
+ }
}
public object output_size
{
get
{
- var lastCell = Cells[Cells.Count - 1];
-
- if (lastCell.output_size != -1)
+ var lastCell = Cells.LastOrDefault();
+ if (lastCell.OutputSize.ToSingleShape() != -1)
{
- return lastCell.output_size;
+ return lastCell.OutputSize;
}
else if (RNN.is_multiple_state(lastCell.StateSize))
{
- // return ((dynamic)Cells[-1].state_size)[0];
- throw new NotImplementedException("");
+ return lastCell.StateSize.First();
+ //throw new NotImplementedException("");
}
else
{
- return Cells[-1].state_size;
+ return lastCell.StateSize;
}
}
}
- public object get_initial_state()
+ public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null)
{
- throw new NotImplementedException();
- // def get_initial_state(self, inputs= None, batch_size= None, dtype= None) :
- // initial_states = []
- // for cell in self.cells[::- 1] if self.reverse_state_order else self.cells:
- // get_initial_state_fn = getattr(cell, 'get_initial_state', None)
- // if get_initial_state_fn:
- // initial_states.append(get_initial_state_fn(
- // inputs=inputs, batch_size=batch_size, dtype=dtype))
- // else:
- // initial_states.append(_generate_zero_filled_state_for_cell(
- // cell, inputs, batch_size, dtype))
-
- // return tuple(initial_states)
+ var cells = reverse_state_order ? Cells.Reverse() : Cells;
+ Tensors initial_states = new Tensors();
+ foreach (var cell in cells)
+ {
+ var get_initial_state_fn = cell.GetType().GetMethod("get_initial_state");
+ if (get_initial_state_fn != null)
+ {
+ var result = (Tensors)get_initial_state_fn.Invoke(cell, new object[] { inputs, batch_size, dtype });
+ initial_states.Add(result);
+ }
+ else
+ {
+ initial_states.Add(RnnUtils.generate_zero_filled_state_for_cell(cell, inputs, batch_size.Value, dtype.Value));
+ }
+ }
+ return initial_states;
}
- public object call()
+ protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
- throw new NotImplementedException();
- // def call(self, inputs, states, constants= None, training= None, ** kwargs):
- // # Recover per-cell states.
- // state_size = (self.state_size[::- 1]
- // if self.reverse_state_order else self.state_size)
- // nested_states = nest.pack_sequence_as(state_size, nest.flatten(states))
-
- // # Call the cells in order and store the returned states.
- // new_nested_states = []
- // for cell, states in zip(self.cells, nested_states) :
- // states = states if nest.is_nested(states) else [states]
- //# TF cell does not wrap the state into list when there is only one state.
- // is_tf_rnn_cell = getattr(cell, '_is_tf_rnn_cell', None) is not None
- // states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
- // if generic_utils.has_arg(cell.call, 'training'):
- // kwargs['training'] = training
- // else:
- // kwargs.pop('training', None)
- // # Use the __call__ function for callable objects, eg layers, so that it
- // # will have the proper name scopes for the ops, etc.
- // cell_call_fn = cell.__call__ if callable(cell) else cell.call
- // if generic_utils.has_arg(cell.call, 'constants'):
- // inputs, states = cell_call_fn(inputs, states,
- // constants= constants, ** kwargs)
- // else:
- // inputs, states = cell_call_fn(inputs, states, ** kwargs)
- // new_nested_states.append(states)
+ // Recover per-cell states.
+ var state_size = reverse_state_order ? StateSize.Reverse() : StateSize;
+ var nested_states = reverse_state_order ? state.Flatten().Reverse() : state.Flatten();
- // return inputs, nest.pack_sequence_as(state_size,
- // nest.flatten(new_nested_states))
+
+ var new_nest_states = new Tensors();
+ // Call the cells in order and store the returned states.
+ foreach (var (cell, states) in zip(Cells, nested_states))
+ {
+ // states = states if tf.nest.is_nested(states) else [states]
+ var type = cell.GetType();
+ bool IsTFRnnCell = type.GetProperty("IsTFRnnCell") != null;
+ state = len(state) == 1 && IsTFRnnCell ? state.FirstOrDefault() : state;
+
+ RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs;
+ Tensors? constants = rnn_optional_args?.Constants;
+
+ Tensors new_states;
+ (inputs, new_states) = cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants });
+
+ new_nest_states.Add(new_states);
+ }
+ new_nest_states = reverse_state_order ? new_nest_states.Reverse().ToArray() : new_nest_states.ToArray();
+ return new Nest(new List> {
+ new Nest(new List> { new Nest(inputs.Single()) }), new Nest(new_nest_states) })
+ .ToTensors();
}
+
+
public void build()
{
- throw new NotImplementedException();
+ built = true;
// @tf_utils.shape_type_conversion
// def build(self, input_shape) :
// if isinstance(input_shape, list) :
@@ -168,9 +203,9 @@ namespace Tensorflow.Keras.Layers.Rnn
{
throw new NotImplementedException();
}
- public GeneralizedTensorShape StateSize => throw new NotImplementedException();
+
public GeneralizedTensorShape OutputSize => throw new NotImplementedException();
- public bool IsTFRnnCell => throw new NotImplementedException();
+ public bool IsTFRnnCell => true;
public bool SupportOptionalArgs => throw new NotImplementedException();
}
}
diff --git a/test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs b/test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs
index ac5ba15e..29648790 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs
@@ -2,6 +2,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Collections.Generic;
using Tensorflow.Keras.Callbacks;
using Tensorflow.Keras.Engine;
+using Tensorflow.NumPy;
using static Tensorflow.KerasApi;
@@ -18,7 +19,7 @@ namespace Tensorflow.Keras.UnitTest.Callbacks
var layers = keras.layers;
var model = keras.Sequential(new List
{
- layers.Rescaling(1.0f / 255, input_shape: (32, 32, 3)),
+ layers.Rescaling(1.0f / 255, input_shape: (28, 28, 1)),
layers.Conv2D(32, 3, padding: "same", activation: keras.activations.Relu),
layers.MaxPooling2D(),
layers.Flatten(),
@@ -36,8 +37,20 @@ namespace Tensorflow.Keras.UnitTest.Callbacks
var num_epochs = 3;
var batch_size = 8;
- var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data();
- x_train = x_train / 255.0f;
+ var data_loader = new MnistModelLoader();
+
+ var dataset = data_loader.LoadAsync(new ModelLoadSetting
+ {
+ TrainDir = "mnist",
+ OneHot = false,
+ ValidationSize = 59900,
+ }).Result;
+
+ NDArray x1 = np.reshape(dataset.Train.Data, (dataset.Train.Data.shape[0], 28, 28, 1));
+ NDArray x2 = x1;
+
+ var x = new NDArray[] { x1, x2 };
+
// define a CallbackParams first, the parameters you pass al least contain Model and Epochs.
CallbackParams callback_parameters = new CallbackParams
{
@@ -47,10 +60,8 @@ namespace Tensorflow.Keras.UnitTest.Callbacks
// define your earlystop
ICallback earlystop = new EarlyStopping(callback_parameters, "accuracy");
// define a callbcaklist, then add the earlystopping to it.
- var callbacks = new List();
- callbacks.add(earlystop);
-
- model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)], batch_size, num_epochs, callbacks: callbacks);
+ var callbacks = new List{ earlystop};
+ model.fit(x, dataset.Train.Labels, batch_size, num_epochs, callbacks: callbacks);
}
}
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
index 55663d41..28a16ad4 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
@@ -4,25 +4,111 @@ using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
+using Tensorflow.Common.Types;
+using Tensorflow.Keras.Engine;
+using Tensorflow.Keras.Layers.Rnn;
+using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;
+using Tensorflow.Train;
using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
namespace Tensorflow.Keras.UnitTest.Layers
{
[TestClass]
public class Rnn
{
+ [TestMethod]
+ public void SimpleRNNCell()
+ {
+ //var cell = tf.keras.layers.SimpleRNNCell(64, dropout: 0.5f, recurrent_dropout: 0.5f);
+ //var h0 = new Tensors { tf.zeros(new Shape(4, 64)) };
+ //var x = tf.random.normal((4, 100));
+ //var (y, h1) = cell.Apply(inputs: x, states: h0);
+ //var h2 = h1;
+ //Assert.AreEqual((4, 64), y.shape);
+ //Assert.AreEqual((4, 64), h2[0].shape);
+
+ //var model = keras.Sequential(new List
+ //{
+ // keras.layers.InputLayer(input_shape: (4,100)),
+ // keras.layers.SimpleRNNCell(64)
+ //});
+ //model.summary();
+
+ var cell = tf.keras.layers.SimpleRNNCell(64, dropout: 0.5f, recurrent_dropout: 0.5f);
+ var h0 = new Tensors { tf.zeros(new Shape(4, 64)) };
+ var x = tf.random.normal((4, 100));
+ var (y, h1) = cell.Apply(inputs: x, states: h0);
+ var h2 = h1;
+ Assert.AreEqual((4, 64), y.shape);
+ Assert.AreEqual((4, 64), h2[0].shape);
+ }
+
+ [TestMethod]
+ public void StackedRNNCell()
+ {
+ var inputs = tf.ones((32, 10));
+ var states = new Tensors { tf.zeros((32, 4)), tf.zeros((32, 5)) };
+ var cells = new IRnnCell[] { tf.keras.layers.SimpleRNNCell(4), tf.keras.layers.SimpleRNNCell(5) };
+ var stackedRNNCell = tf.keras.layers.StackedRNNCells(cells);
+ var (output, state) = stackedRNNCell.Apply(inputs, states);
+ Console.WriteLine(output);
+ Console.WriteLine(state.shape);
+ Assert.AreEqual((32, 5), output.shape);
+ Assert.AreEqual((32, 4), state[0].shape);
+ }
+
[TestMethod]
public void SimpleRNN()
{
- var inputs = np.arange(6 * 10 * 8).reshape((6, 10, 8)).astype(np.float32);
- /*var simple_rnn = keras.layers.SimpleRNN(4);
- var output = simple_rnn.Apply(inputs);
- Assert.AreEqual((32, 4), output.shape);*/
- var simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences: true, return_state: true);
- var (whole_sequence_output, final_state) = simple_rnn.Apply(inputs);
- Console.WriteLine(whole_sequence_output);
- Console.WriteLine(final_state);
+ //var inputs = np.arange(6 * 10 * 8).reshape((6, 10, 8)).astype(np.float32);
+ ///*var simple_rnn = keras.layers.SimpleRNN(4);
+ //var output = simple_rnn.Apply(inputs);
+ //Assert.AreEqual((32, 4), output.shape);*/
+
+ //var simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences: true, return_state: true);
+ //var (whole_sequence_output, final_state) = simple_rnn.Apply(inputs);
+ //Assert.AreEqual((6, 10, 4), whole_sequence_output.shape);
+ //Assert.AreEqual((6, 4), final_state.shape);
+
+ var inputs = keras.Input(shape: (10, 8));
+ var x = keras.layers.SimpleRNN(4).Apply(inputs);
+ var output = keras.layers.Dense(10).Apply(x);
+ var model = keras.Model(inputs, output);
+ model.summary();
+ }
+ [TestMethod]
+ public void RNNForSimpleRNNCell()
+ {
+ var inputs = tf.random.normal((32, 10, 8));
+ var cell = tf.keras.layers.SimpleRNNCell(10, dropout: 0.5f, recurrent_dropout: 0.5f);
+ var rnn = tf.keras.layers.RNN(cell: cell);
+ var output = rnn.Apply(inputs);
+ Assert.AreEqual((32, 10), output.shape);
+
}
+ [TestMethod]
+ public void RNNForStackedRNNCell()
+ {
+ var inputs = tf.random.normal((32, 10, 8));
+ var cells = new IRnnCell[] { tf.keras.layers.SimpleRNNCell(4), tf.keras.layers.SimpleRNNCell(5) };
+ var stackedRNNCell = tf.keras.layers.StackedRNNCells(cells);
+ var rnn = tf.keras.layers.RNN(cell: stackedRNNCell);
+ var output = rnn.Apply(inputs);
+ Assert.AreEqual((32, 5), output.shape);
+ }
+
+ [TestMethod]
+ public void WlzTest()
+ {
+ long[] b = { 1, 2, 3 };
+
+ Shape a = new Shape(Unknown).concatenate(b);
+ Console.WriteLine(a);
+
+ }
+
+
}
}