
Add feature (not completed): add SimpleRNNCell, StackedRNNCells, RNN, and tests

pull/1106/head
Wanglongzhi2001 · 2 years ago · commit 5a87a56e83
14 changed files with 445 additions and 119 deletions

1. src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs (+12 -2)
2. src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs (+3 -0)
3. src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs (+2 -1)
4. src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs (+34 -0)
5. src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs (+13 -1)
6. src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs (+4 -1)
7. src/TensorFlowNET.Keras/BackendImpl.cs (+14 -13)
8. src/TensorFlowNET.Keras/Layers/LayersApi.cs (+77 -0)
9. src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs (+15 -0)
10. src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs (+53 -23)
11. src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs (+9 -1)
12. src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs (+97 -62)
13. test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs (+18 -7)
14. test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs (+94 -8)

src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs (+12 -2)

@@ -12,9 +12,14 @@ namespace Tensorflow.Common.Types
/// create a single-dim generalized Tensor shape.
/// </summary>
/// <param name="dim"></param>
- public GeneralizedTensorShape(int dim)
+ public GeneralizedTensorShape(int dim, int size = 1)
{
- Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } };
+ var elem = new TensorShapeConfig() { Items = new long?[] { dim } };
+ Shapes = Enumerable.Repeat(elem, size).ToArray();
+ //Shapes = new TensorShapeConfig[size];
+ //Shapes.Initialize(new TensorShapeConfig() { Items = new long?[] { dim } });
+ //Array.Initialize(Shapes, new TensorShapeConfig() { Items = new long?[] { dim } });
+ ////Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } };
}

public GeneralizedTensorShape(Shape shape)
@@ -113,6 +118,11 @@ namespace Tensorflow.Common.Types
return new Nest<long?>(Shapes.Select(s => DealWithSingleShape(s)));
}
}


public static implicit operator GeneralizedTensorShape(int dims)
=> new GeneralizedTensorShape(dims);

public IEnumerator<long?[]> GetEnumerator()
{
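A quick sketch (illustrative only, not part of the diff) of what the widened constructor and the new implicit conversion allow:

// Two identical single-dim shapes, e.g. one per cell of a two-cell stacked RNN.
var state_size = new GeneralizedTensorShape(64, size: 2);
// The new implicit operator builds a single-dim shape straight from an int.
GeneralizedTensorShape output_size = 64;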


src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs (+3 -0)

@@ -10,6 +10,9 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
[JsonProperty("cell")]
// TODO: the cell should be serialized with `serialize_keras_object`.
public IRnnCell Cell { get; set; } = null;
[JsonProperty("cells")]
public IList<IRnnCell> Cells { get; set; } = null;

[JsonProperty("return_sequences")]
public bool ReturnSequences { get; set; } = false;
[JsonProperty("return_state")]


src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs (+2 -1)

@@ -1,10 +1,11 @@
using System.Collections.Generic;
using Tensorflow.Keras.Layers.Rnn;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
public class StackedRNNCellsArgs : LayerArgs
{
- public IList<RnnCell> Cells { get; set; }
+ public IList<IRnnCell> Cells { get; set; }
public Dictionary<string, object> Kwargs { get; set; } = null;
}
}

src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs (+34 -0)

@@ -1,5 +1,6 @@
using System;
using Tensorflow.Framework.Models;
using Tensorflow.Keras.Layers.Rnn;
using Tensorflow.NumPy;
using static Google.Protobuf.Reflection.FieldDescriptorProto.Types;

@@ -192,6 +193,19 @@ namespace Tensorflow.Keras.Layers
float offset = 0,
Shape input_shape = null);

public IRnnCell SimpleRNNCell(
int units,
string activation = "tanh",
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string recurrent_initializer = "orthogonal",
string bias_initializer = "zeros",
float dropout = 0f,
float recurrent_dropout = 0f);

public IRnnCell StackedRNNCells(
IEnumerable<IRnnCell> cells);

public ILayer SimpleRNN(int units,
string activation = "tanh",
string kernel_initializer = "glorot_uniform",
@@ -200,6 +214,26 @@ namespace Tensorflow.Keras.Layers
bool return_sequences = false,
bool return_state = false);

public ILayer RNN(
IRnnCell cell,
bool return_sequences = false,
bool return_state = false,
bool go_backwards = false,
bool stateful = false,
bool unroll = false,
bool time_major = false
);

public ILayer RNN(
IEnumerable<IRnnCell> cell,
bool return_sequences = false,
bool return_state = false,
bool go_backwards = false,
bool stateful = false,
bool unroll = false,
bool time_major = false
);

public ILayer Subtract();
}
}
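The new interface members surface the cells through the public layers API; a minimal usage sketch, mirroring the unit tests added later in this commit:

var cell = tf.keras.layers.SimpleRNNCell(64, dropout: 0.5f, recurrent_dropout: 0.5f);
var stacked = tf.keras.layers.StackedRNNCells(new IRnnCell[] {
    tf.keras.layers.SimpleRNNCell(4),
    tf.keras.layers.SimpleRNNCell(5)
});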

src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs (+13 -1)

@@ -109,7 +109,19 @@ namespace Tensorflow.Operations

return ta;
});*/
- throw new NotImplementedException("");
+ //if (indices is EagerTensor)
+ //{
+ //    indices = indices as EagerTensor;
+ //    indices = indices.numpy();
+ //}
+
+ //foreach (var (index, val) in zip(indices.ToArray<int>(), array_ops.unstack(value)))
+ //{
+ //    this.write(index, val);
+ //}
+ //return base;
+ //throw new NotImplementedException("");
+ return this;
}

public void _merge_element_shape(Shape shape)


src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs (+4 -1)

@@ -17,6 +17,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Eager;
using static Tensorflow.Binding;

namespace Tensorflow.Operations
@@ -146,7 +147,9 @@ namespace Tensorflow.Operations

return ta;
});*/
- throw new NotImplementedException("");
+
+ //throw new NotImplementedException("");
+ return this;
}

public void _merge_element_shape(Shape shape)


src/TensorFlowNET.Keras/BackendImpl.cs (+14 -13)

@@ -510,7 +510,7 @@ namespace Tensorflow.Keras
}

}
// tf.where needs its condition tensor to be the same shape as its two
// result tensors, but in our case the condition (mask) tensor is
// (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
@@ -535,7 +535,7 @@ namespace Tensorflow.Keras
{
mask_t = tf.expand_dims(mask_t, -1);
}
- var multiples = Enumerable.Repeat(1, fixed_dim).ToArray().concat(input_t.shape.as_int_list().ToList().GetRange(fixed_dim, input_t.rank));
+ var multiples = Enumerable.Repeat(1, fixed_dim).ToArray().concat(input_t.shape.as_int_list().Skip(fixed_dim).ToArray());
return tf.tile(mask_t, multiples);
}
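The corrected multiples fixes an out-of-range slice: GetRange(fixed_dim, input_t.rank) asks for rank elements starting at fixed_dim and overruns the list, while Skip(fixed_dim) takes only the trailing dimensions. A worked example with illustrative shapes:

// input_t: (32, 10, 8), mask_t: (32, 10, 1), fixed_dim = 2
// multiples = { 1, 1 } concat { 8 } = { 1, 1, 8 }
// tf.tile(mask_t, multiples) -> mask of shape (32, 10, 8)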

@@ -570,9 +570,6 @@ namespace Tensorflow.Keras
// individually. The result will be a tuple of lists; each item in the tuple
// is a list of tensors with shape (batch, feature).




Tensors _process_single_input_t(Tensor input_t)
{
var unstaked_input_t = array_ops.unstack(input_t); // unstack for time_step dim
@@ -609,7 +606,7 @@ namespace Tensorflow.Keras
var mask_list = tf.unstack(mask);
if (go_backwards)
{
- mask_list.Reverse();
+ mask_list.Reverse().ToArray();
}

for (int i = 0; i < time_steps; i++)
@@ -629,9 +626,10 @@ namespace Tensorflow.Keras
}
else
{
- prev_output = successive_outputs[successive_outputs.Length - 1];
+ prev_output = successive_outputs.Last();
}

// output could be a tensor
output = tf.where(tiled_mask_t, output, prev_output);

var flat_states = Nest.Flatten(states).ToList();
@@ -661,13 +659,13 @@ namespace Tensorflow.Keras
}

}
- last_output = successive_outputs[successive_outputs.Length - 1];
- new_states = successive_states[successive_states.Length - 1];
+ last_output = successive_outputs.Last();
+ new_states = successive_states.Last();
outputs = tf.stack(successive_outputs);

if (zero_output_for_mask)
{
- last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output));
+ last_output = tf.where(_expand_mask(mask_list.Last(), last_output), last_output, tf.zeros_like(last_output));
outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs));
}
else // mask is null
@@ -689,8 +687,8 @@ namespace Tensorflow.Keras
successive_states = new Tensors { newStates };
}
}
- last_output = successive_outputs[successive_outputs.Length - 1];
- new_states = successive_states[successive_states.Length - 1];
+ last_output = successive_outputs.Last();
+ new_states = successive_states.Last();
outputs = tf.stack(successive_outputs);
}
}
@@ -701,6 +699,8 @@ namespace Tensorflow.Keras
// Create input tensor array, if the inputs is nested tensors, then it
// will be flattened first, and tensor array will be created one per
// flattened tensor.


var input_ta = new List<TensorArray>();
for (int i = 0; i < flatted_inptus.Count; i++)
{
@@ -719,6 +719,7 @@ namespace Tensorflow.Keras
}
}


// Get the time(0) input and compute the output for that, the output will
// be used to determine the dtype of output tensor array. Don't read from
// input_ta due to TensorArray clear_after_read default to True.
@@ -773,7 +774,7 @@ namespace Tensorflow.Keras
return res;
};
}
- // TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor)?
+ // TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor), it could be an integer or tensor
else if (input_length is Tensor)
{
if (go_backwards)


src/TensorFlowNET.Keras/Layers/LayersApi.cs (+77 -0)

@@ -685,6 +685,34 @@ namespace Tensorflow.Keras.Layers
Alpha = alpha
});


public IRnnCell SimpleRNNCell(
int units,
string activation = "tanh",
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string recurrent_initializer = "orthogonal",
string bias_initializer = "zeros",
float dropout = 0f,
float recurrent_dropout = 0f)
=> new SimpleRNNCell(new SimpleRNNCellArgs
{
Units = units,
Activation = keras.activations.GetActivationFromName(activation),
UseBias = use_bias,
KernelInitializer = GetInitializerByName(kernel_initializer),
RecurrentInitializer = GetInitializerByName(recurrent_initializer),
Dropout = dropout,
RecurrentDropout = recurrent_dropout
});

public IRnnCell StackedRNNCells(
IEnumerable<IRnnCell> cells)
=> new StackedRNNCells(new StackedRNNCellsArgs
{
Cells = cells.ToList()
});

/// <summary>
///
/// </summary>
@@ -709,6 +737,55 @@ namespace Tensorflow.Keras.Layers
ReturnState = return_state
});

/// <summary>
///
/// </summary>
/// <param name="cell"></param>
/// <param name="return_sequences"></param>
/// <param name="return_state"></param>
/// <param name="go_backwards"></param>
/// <param name="stateful"></param>
/// <param name="unroll"></param>
/// <param name="time_major"></param>
/// <returns></returns>
public ILayer RNN(
IRnnCell cell,
bool return_sequences = false,
bool return_state = false,
bool go_backwards = false,
bool stateful = false,
bool unroll = false,
bool time_major = false)
=> new RNN(new RNNArgs
{
Cell = cell,
ReturnSequences = return_sequences,
ReturnState = return_state,
GoBackwards = go_backwards,
Stateful = stateful,
Unroll = unroll,
TimeMajor = time_major
});

public ILayer RNN(
IEnumerable<IRnnCell> cell,
bool return_sequences = false,
bool return_state = false,
bool go_backwards = false,
bool stateful = false,
bool unroll = false,
bool time_major = false)
=> new RNN(new RNNArgs
{
Cells = cell.ToList(),
ReturnSequences = return_sequences,
ReturnState = return_state,
GoBackwards = go_backwards,
Stateful = stateful,
Unroll = unroll,
TimeMajor = time_major
});

/// <summary>
/// Long Short-Term Memory layer - Hochreiter 1997.
/// </summary>
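Usage sketch for the two new RNN factory overloads, with shapes taken from the unit tests in this commit (the list overload wraps the cells via the Cells argument, which RNN turns into a StackedRNNCells internally):

var inputs = tf.random.normal((32, 10, 8));
var rnn = tf.keras.layers.RNN(cell: tf.keras.layers.SimpleRNNCell(10));
var output = rnn.Apply(inputs);           // (32, 10)
var stacked_rnn = tf.keras.layers.RNN(new IRnnCell[] {
    tf.keras.layers.SimpleRNNCell(4),
    tf.keras.layers.SimpleRNNCell(5)
});
var stacked_output = stacked_rnn.Apply(inputs);   // (32, 5)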


src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs (+15 -0)

@@ -17,6 +17,21 @@ namespace Tensorflow.Keras.Layers.Rnn

}

protected void _create_non_trackable_mask_cache()
{
}

public void reset_dropout_mask()
{

}

public void reset_recurrent_dropout_mask()
{

}

public Tensors? get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
{
if (dropout == 0f)


src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs (+53 -23)

@@ -38,7 +38,17 @@ namespace Tensorflow.Keras.Layers.Rnn
SupportsMasking = true;

// if is StackedRnncell
- _cell = args.Cell;
+ if (args.Cells != null)
+ {
+     _cell = new StackedRNNCells(new StackedRNNCellsArgs
+     {
+         Cells = args.Cells
+     });
+ }
+ else
+ {
+     _cell = args.Cell;
+ }

// get input_shape
_args = PreConstruct(args);
@@ -122,6 +132,8 @@ namespace Tensorflow.Keras.Layers.Rnn
var state_shape = new int[] { (int)batch }.concat(flat_state.as_int_list());
return new Shape(state_shape);
};


var state_shape = _get_state_shape(state_size);

return new List<Shape> { output_shape, state_shape };
@@ -240,7 +252,7 @@ namespace Tensorflow.Keras.Layers.Rnn
if (_cell is StackedRNNCells)
{
var stack_cell = _cell as StackedRNNCells;
- foreach (var cell in stack_cell.Cells)
+ foreach (IRnnCell cell in stack_cell.Cells)
{
_maybe_reset_cell_dropout_mask(cell);
}
@@ -253,7 +265,7 @@ namespace Tensorflow.Keras.Layers.Rnn
}

Shape input_shape;
- if (!inputs.IsSingle())
+ if (!inputs.IsNested())
{
// In the case of nested input, use the first element for shape check
// input_shape = nest.flatten(inputs)[0].shape;
@@ -267,7 +279,7 @@ namespace Tensorflow.Keras.Layers.Rnn

var timesteps = _args.TimeMajor ? input_shape[0] : input_shape[1];

- if (_args.Unroll && timesteps != null)
+ if (_args.Unroll && timesteps == null)
{
throw new ValueError(
"Cannot unroll a RNN if the " +
@@ -302,7 +314,6 @@ namespace Tensorflow.Keras.Layers.Rnn
states = new Tensors(states.SkipLast(_num_constants));
states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states;
var (output, new_states) = _cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants });
// TODO(Wanglongzhi2001),should cell_call_fn's return value be Tensors, Tensors?
return (output, new_states.Single);
};
}
@@ -310,13 +321,14 @@ namespace Tensorflow.Keras.Layers.Rnn
{
step = (inputs, states) =>
{
- states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states;
+ states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states.First()) : states;
var (output, new_states) = _cell.Apply(inputs, states);
- return (output, new_states.Single);
+ return (output, new_states);
};
}

- var (last_output, outputs, states) = keras.backend.rnn(step,
+ var (last_output, outputs, states) = keras.backend.rnn(
+     step,
inputs,
initial_state,
constants: constants,
@@ -394,6 +406,7 @@ namespace Tensorflow.Keras.Layers.Rnn
initial_state = null;
inputs = inputs[0];
}

if (_args.Stateful)
{
@@ -402,7 +415,7 @@ namespace Tensorflow.Keras.Layers.Rnn
var tmp = new Tensor[] { };
foreach (var s in nest.flatten(States))
{
- tmp.add(tf.math.count_nonzero((Tensor)s));
+ tmp.add(tf.math.count_nonzero(s.Single()));
}
var non_zero_count = tf.add_n(tmp);
//initial_state = tf.cond(non_zero_count > 0, () => States, () => initial_state);
@@ -415,6 +428,15 @@ namespace Tensorflow.Keras.Layers.Rnn
{
initial_state = States;
}
// TODO(Wanglongzhi2001),
// initial_state = tf.nest.map_structure(
//# When the layer has a inferred dtype, use the dtype from the
//# cell.
// lambda v: tf.cast(
// v, self.compute_dtype or self.cell.compute_dtype
// ),
// initial_state,
// )

}
else if (initial_state is null)
@@ -424,10 +446,9 @@ namespace Tensorflow.Keras.Layers.Rnn

if (initial_state.Length != States.Length)
{
- throw new ValueError(
-     $"Layer {this} expects {States.Length} state(s), " +
-     $"but it received {initial_state.Length} " +
-     $"initial state(s). Input received: {inputs}");
+ throw new ValueError($"Layer {this} expects {States.Length} state(s), " +
+     $"but it received {initial_state.Length} " +
+     $"initial state(s). Input received: {inputs}");
}

return (inputs, initial_state, constants);
@@ -458,11 +479,11 @@ namespace Tensorflow.Keras.Layers.Rnn

void _maybe_reset_cell_dropout_mask(ILayer cell)
{
- //if (cell is DropoutRNNCellMixin)
- //{
- //    cell.reset_dropout_mask();
- //    cell.reset_recurrent_dropout_mask();
- //}
+ if (cell is DropoutRNNCellMixin CellDRCMixin)
+ {
+     CellDRCMixin.reset_dropout_mask();
+     CellDRCMixin.reset_recurrent_dropout_mask();
+ }
}

private static RNNArgs PreConstruct(RNNArgs args)
@@ -537,15 +558,24 @@ namespace Tensorflow.Keras.Layers.Rnn

protected Tensors get_initial_state(Tensors inputs)
{
+ var get_initial_state_fn = _cell.GetType().GetMethod("get_initial_state");

var input = inputs[0];
- var input_shape = input.shape;
+ var input_shape = inputs.shape;
var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0];
var dtype = input.dtype;
- Tensors init_state;
- if (_cell is RnnCellBase rnn_base_cell)
+ Tensors init_state = new Tensors();
+
+ if (get_initial_state_fn != null)
{
- init_state = rnn_base_cell.GetInitialState(null, batch_size, dtype);
+ init_state = (Tensors)get_initial_state_fn.Invoke(_cell, new object[] { inputs, batch_size, dtype });
}
+ //if (_cell is RnnCellBase rnn_base_cell)
+ //{
+ //    init_state = rnn_base_cell.GetInitialState(null, batch_size, dtype);
+ //}
else
{
init_state = RnnUtils.generate_zero_filled_state(batch_size, _cell.StateSize, dtype);


src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs (+9 -1)

@@ -6,6 +6,7 @@ using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Common.Types;
using Tensorflow.Common.Extensions;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Layers.Rnn
{
@@ -77,8 +78,10 @@ namespace Tensorflow.Keras.Layers.Rnn
var rec_dp_mask = get_recurrent_dropout_maskcell_for_cell(prev_output, training.Value);

Tensor h;
var ranks = inputs.rank;
if (dp_mask != null)
{

h = math_ops.matmul(math_ops.multiply(inputs.Single, dp_mask.Single), _kernel.AsTensor());
}
else
@@ -95,7 +98,7 @@ namespace Tensorflow.Keras.Layers.Rnn
{
prev_output = math_ops.multiply(prev_output, rec_dp_mask);
}
var tmp = _recurrent_kernel.AsTensor();
Tensor output = h + math_ops.matmul(prev_output, _recurrent_kernel.AsTensor());

if (_args.Activation != null)
@@ -113,5 +116,10 @@ namespace Tensorflow.Keras.Layers.Rnn
return new Tensors(output, output);
}
}

public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null)
{
return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size.Value, dtype.Value);
}
}
}
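A sketch of the new get_initial_state entry point. It is located by reflection in RNN rather than declared on IRnnCell, so the cast to the concrete cell type here is an assumption; note the current implementation dereferences batch_size.Value and dtype.Value, so both arguments must be supplied:

var cell = (SimpleRNNCell)tf.keras.layers.SimpleRNNCell(64);
var init_state = cell.get_initial_state(batch_size: 4, dtype: tf.float32);
// expected: a zero-filled state of shape (4, 64)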

src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs (+97 -62)

@@ -1,17 +1,20 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using Tensorflow.Common.Extensions;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Layers.Rnn
{
public class StackedRNNCells : Layer, IRnnCell
{
- public IList<RnnCell> Cells { get; set; }
+ public IList<IRnnCell> Cells { get; set; }
public bool reverse_state_order;

public StackedRNNCells(StackedRNNCellsArgs args) : base(args)
@@ -20,8 +23,19 @@ namespace Tensorflow.Keras.Layers.Rnn
{
args.Kwargs = new Dictionary<string, object>();
}

foreach (var cell in args.Cells)
{
//Type type = cell.GetType();
//var CallMethodInfo = type.GetMethod("Call");
//if (CallMethodInfo == null)
//{
// throw new ValueError(
// "All cells must have a `Call` method. " +
// $"Received cell without a `Call` method: {cell}");
//}
}
Cells = args.Cells;
reverse_state_order = (bool)args.Kwargs.Get("reverse_state_order", false);

if (reverse_state_order)
@@ -33,91 +47,112 @@ namespace Tensorflow.Keras.Layers.Rnn
}
}

- public object state_size
+ public GeneralizedTensorShape StateSize
{
- get => throw new NotImplementedException();
- //@property
- //def state_size(self) :
- //    return tuple(c.state_size for c in
- //        (self.cells[::- 1] if self.reverse_state_order else self.cells))
+ get
+ {
+     GeneralizedTensorShape state_size = new GeneralizedTensorShape(1, Cells.Count);
+     if (reverse_state_order && Cells.Count > 0)
+     {
+         var idxAndCell = Cells.Reverse().Select((cell, idx) => (idx, cell));
+         foreach (var cell in idxAndCell)
+         {
+             state_size.Shapes[cell.idx] = cell.cell.StateSize.Shapes.First();
+         }
+     }
+     else
+     {
+         //foreach (var cell in Cells)
+         //{
+         //    state_size.Shapes.add(cell.StateSize.Shapes.First());
+         //}
+         var idxAndCell = Cells.Select((cell, idx) => (idx, cell));
+         foreach (var cell in idxAndCell)
+         {
+             state_size.Shapes[cell.idx] = cell.cell.StateSize.Shapes.First();
+         }
+     }
+     return state_size;
+ }
}

public object output_size
{
get
{
- var lastCell = Cells[Cells.Count - 1];
- if (lastCell.output_size != -1)
+ var lastCell = Cells.LastOrDefault();
+ if (lastCell.OutputSize.ToSingleShape() != -1)
{
- return lastCell.output_size;
+ return lastCell.OutputSize;
}
else if (RNN.is_multiple_state(lastCell.StateSize))
{
- // return ((dynamic)Cells[-1].state_size)[0];
- throw new NotImplementedException("");
+ return lastCell.StateSize.First();
+ //throw new NotImplementedException("");
}
else
{
- return Cells[-1].state_size;
+ return lastCell.StateSize;
}
}
}

- public object get_initial_state()
+ public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null)
{
- throw new NotImplementedException();
- // def get_initial_state(self, inputs= None, batch_size= None, dtype= None) :
- //     initial_states = []
- //     for cell in self.cells[::- 1] if self.reverse_state_order else self.cells:
- //         get_initial_state_fn = getattr(cell, 'get_initial_state', None)
- //         if get_initial_state_fn:
- //             initial_states.append(get_initial_state_fn(
- //                 inputs=inputs, batch_size=batch_size, dtype=dtype))
- //         else:
- //             initial_states.append(_generate_zero_filled_state_for_cell(
- //                 cell, inputs, batch_size, dtype))
- //     return tuple(initial_states)
+ var cells = reverse_state_order ? Cells.Reverse() : Cells;
+ Tensors initial_states = new Tensors();
+ foreach (var cell in cells)
+ {
+     var get_initial_state_fn = cell.GetType().GetMethod("get_initial_state");
+     if (get_initial_state_fn != null)
+     {
+         var result = (Tensors)get_initial_state_fn.Invoke(cell, new object[] { inputs, batch_size, dtype });
+         initial_states.Add(result);
+     }
+     else
+     {
+         initial_states.Add(RnnUtils.generate_zero_filled_state_for_cell(cell, inputs, batch_size.Value, dtype.Value));
+     }
+ }
+ return initial_states;
}

- public object call()
+ protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
- throw new NotImplementedException();
- // def call(self, inputs, states, constants= None, training= None, ** kwargs):
- //     # Recover per-cell states.
- //     state_size = (self.state_size[::- 1]
- //         if self.reverse_state_order else self.state_size)
- //     nested_states = nest.pack_sequence_as(state_size, nest.flatten(states))
- //     # Call the cells in order and store the returned states.
- //     new_nested_states = []
- //     for cell, states in zip(self.cells, nested_states) :
- //         states = states if nest.is_nested(states) else [states]
- //         # TF cell does not wrap the state into list when there is only one state.
- //         is_tf_rnn_cell = getattr(cell, '_is_tf_rnn_cell', None) is not None
- //         states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
- //         if generic_utils.has_arg(cell.call, 'training'):
- //             kwargs['training'] = training
- //         else:
- //             kwargs.pop('training', None)
- //         # Use the __call__ function for callable objects, eg layers, so that it
- //         # will have the proper name scopes for the ops, etc.
- //         cell_call_fn = cell.__call__ if callable(cell) else cell.call
- //         if generic_utils.has_arg(cell.call, 'constants'):
- //             inputs, states = cell_call_fn(inputs, states,
- //                 constants= constants, ** kwargs)
- //         else:
- //             inputs, states = cell_call_fn(inputs, states, ** kwargs)
- //         new_nested_states.append(states)
- //     return inputs, nest.pack_sequence_as(state_size,
- //         nest.flatten(new_nested_states))
+ // Recover per-cell states.
+ var state_size = reverse_state_order ? StateSize.Reverse() : StateSize;
+ var nested_states = reverse_state_order ? state.Flatten().Reverse() : state.Flatten();
+
+ var new_nest_states = new Tensors();
+ // Call the cells in order and store the returned states.
+ foreach (var (cell, states) in zip(Cells, nested_states))
+ {
+     // states = states if tf.nest.is_nested(states) else [states]
+     var type = cell.GetType();
+     bool IsTFRnnCell = type.GetProperty("IsTFRnnCell") != null;
+     state = len(state) == 1 && IsTFRnnCell ? state.FirstOrDefault() : state;
+
+     RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs;
+     Tensors? constants = rnn_optional_args?.Constants;
+
+     Tensors new_states;
+     (inputs, new_states) = cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants });
+
+     new_nest_states.Add(new_states);
+ }
+ new_nest_states = reverse_state_order ? new_nest_states.Reverse().ToArray() : new_nest_states.ToArray();
+ return new Nest<Tensor>(new List<Nest<Tensor>> {
+     new Nest<Tensor>(new List<Nest<Tensor>> { new Nest<Tensor>(inputs.Single()) }), new Nest<Tensor>(new_nest_states) })
+     .ToTensors();
}

public void build()
{
- throw new NotImplementedException();
+ built = true;
// @tf_utils.shape_type_conversion
// def build(self, input_shape) :
// if isinstance(input_shape, list) :
@@ -168,9 +203,9 @@ namespace Tensorflow.Keras.Layers.Rnn
{
throw new NotImplementedException();
}
public GeneralizedTensorShape StateSize => throw new NotImplementedException();
public GeneralizedTensorShape OutputSize => throw new NotImplementedException();
- public bool IsTFRnnCell => throw new NotImplementedException();
+ public bool IsTFRnnCell => true;
public bool SupportOptionalArgs => throw new NotImplementedException();
}
}
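And the stacked counterpart collects one state per cell, in cell order unless reverse_state_order is set. A sketch under the same assumption that the concrete StackedRNNCells type is in scope:

var cells = new IRnnCell[] { tf.keras.layers.SimpleRNNCell(4), tf.keras.layers.SimpleRNNCell(5) };
var stacked = (StackedRNNCells)tf.keras.layers.StackedRNNCells(cells);
var states = stacked.get_initial_state(batch_size: 32, dtype: tf.float32);
// expected: one zero-filled state per cell, (32, 4) then (32, 5)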

test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs (+18 -7)

@@ -2,6 +2,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Collections.Generic;
using Tensorflow.Keras.Callbacks;
using Tensorflow.Keras.Engine;
using Tensorflow.NumPy;
using static Tensorflow.KerasApi;


@@ -18,7 +19,7 @@ namespace Tensorflow.Keras.UnitTest.Callbacks
var layers = keras.layers;
var model = keras.Sequential(new List<ILayer>
{
- layers.Rescaling(1.0f / 255, input_shape: (32, 32, 3)),
+ layers.Rescaling(1.0f / 255, input_shape: (28, 28, 1)),
layers.Conv2D(32, 3, padding: "same", activation: keras.activations.Relu),
layers.MaxPooling2D(),
layers.Flatten(),
@@ -36,8 +37,20 @@ namespace Tensorflow.Keras.UnitTest.Callbacks
var num_epochs = 3;
var batch_size = 8;

- var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data();
- x_train = x_train / 255.0f;
+ var data_loader = new MnistModelLoader();
+
+ var dataset = data_loader.LoadAsync(new ModelLoadSetting
+ {
+     TrainDir = "mnist",
+     OneHot = false,
+     ValidationSize = 59900,
+ }).Result;
+
+ NDArray x1 = np.reshape(dataset.Train.Data, (dataset.Train.Data.shape[0], 28, 28, 1));
+ NDArray x2 = x1;
+
+ var x = new NDArray[] { x1, x2 };

// define a CallbackParams first; the parameters you pass must at least contain Model and Epochs.
CallbackParams callback_parameters = new CallbackParams
{
@@ -47,10 +60,8 @@ namespace Tensorflow.Keras.UnitTest.Callbacks
// define your earlystop
ICallback earlystop = new EarlyStopping(callback_parameters, "accuracy");
// define a callback list, then add the earlystopping to it.
- var callbacks = new List<ICallback>();
- callbacks.add(earlystop);
- model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)], batch_size, num_epochs, callbacks: callbacks);
+ var callbacks = new List<ICallback> { earlystop };
+ model.fit(x, dataset.Train.Labels, batch_size, num_epochs, callbacks: callbacks);
}

}


test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs (+94 -8)

@@ -4,25 +4,111 @@ using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Tensorflow.Common.Types;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers.Rnn;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;
using Tensorflow.Train;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.UnitTest.Layers
{
[TestClass]
public class Rnn
{
[TestMethod]
public void SimpleRNNCell()
{
//var cell = tf.keras.layers.SimpleRNNCell(64, dropout: 0.5f, recurrent_dropout: 0.5f);
//var h0 = new Tensors { tf.zeros(new Shape(4, 64)) };
//var x = tf.random.normal((4, 100));
//var (y, h1) = cell.Apply(inputs: x, states: h0);
//var h2 = h1;
//Assert.AreEqual((4, 64), y.shape);
//Assert.AreEqual((4, 64), h2[0].shape);

//var model = keras.Sequential(new List<ILayer>
//{
// keras.layers.InputLayer(input_shape: (4,100)),
// keras.layers.SimpleRNNCell(64)
//});
//model.summary();

var cell = tf.keras.layers.SimpleRNNCell(64, dropout: 0.5f, recurrent_dropout: 0.5f);
var h0 = new Tensors { tf.zeros(new Shape(4, 64)) };
var x = tf.random.normal((4, 100));
var (y, h1) = cell.Apply(inputs: x, states: h0);
var h2 = h1;
Assert.AreEqual((4, 64), y.shape);
Assert.AreEqual((4, 64), h2[0].shape);
}

[TestMethod]
public void StackedRNNCell()
{
var inputs = tf.ones((32, 10));
var states = new Tensors { tf.zeros((32, 4)), tf.zeros((32, 5)) };
var cells = new IRnnCell[] { tf.keras.layers.SimpleRNNCell(4), tf.keras.layers.SimpleRNNCell(5) };
var stackedRNNCell = tf.keras.layers.StackedRNNCells(cells);
var (output, state) = stackedRNNCell.Apply(inputs, states);
Console.WriteLine(output);
Console.WriteLine(state.shape);
Assert.AreEqual((32, 5), output.shape);
Assert.AreEqual((32, 4), state[0].shape);
}

[TestMethod]
public void SimpleRNN()
{
- var inputs = np.arange(6 * 10 * 8).reshape((6, 10, 8)).astype(np.float32);
- /*var simple_rnn = keras.layers.SimpleRNN(4);
- var output = simple_rnn.Apply(inputs);
- Assert.AreEqual((32, 4), output.shape);*/
- var simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences: true, return_state: true);
- var (whole_sequence_output, final_state) = simple_rnn.Apply(inputs);
- Console.WriteLine(whole_sequence_output);
- Console.WriteLine(final_state);
+ //var inputs = np.arange(6 * 10 * 8).reshape((6, 10, 8)).astype(np.float32);
+ ///*var simple_rnn = keras.layers.SimpleRNN(4);
+ //var output = simple_rnn.Apply(inputs);
+ //Assert.AreEqual((32, 4), output.shape);*/
+
+ //var simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences: true, return_state: true);
+ //var (whole_sequence_output, final_state) = simple_rnn.Apply(inputs);
+ //Assert.AreEqual((6, 10, 4), whole_sequence_output.shape);
+ //Assert.AreEqual((6, 4), final_state.shape);
+
+ var inputs = keras.Input(shape: (10, 8));
+ var x = keras.layers.SimpleRNN(4).Apply(inputs);
+ var output = keras.layers.Dense(10).Apply(x);
+ var model = keras.Model(inputs, output);
+ model.summary();
}
[TestMethod]
public void RNNForSimpleRNNCell()
{
var inputs = tf.random.normal((32, 10, 8));
var cell = tf.keras.layers.SimpleRNNCell(10, dropout: 0.5f, recurrent_dropout: 0.5f);
var rnn = tf.keras.layers.RNN(cell: cell);
var output = rnn.Apply(inputs);
Assert.AreEqual((32, 10), output.shape);

}
[TestMethod]
public void RNNForStackedRNNCell()
{
var inputs = tf.random.normal((32, 10, 8));
var cells = new IRnnCell[] { tf.keras.layers.SimpleRNNCell(4), tf.keras.layers.SimpleRNNCell(5) };
var stackedRNNCell = tf.keras.layers.StackedRNNCells(cells);
var rnn = tf.keras.layers.RNN(cell: stackedRNNCell);
var output = rnn.Apply(inputs);
Assert.AreEqual((32, 5), output.shape);
}

[TestMethod]
public void WlzTest()
{
long[] b = { 1, 2, 3 };
Shape a = new Shape(Unknown).concatenate(b);
Console.WriteLine(a);

}


}
}
