
update draft pr for RNN

pull/1090/head
Wanglongzhi2001 2 years ago
parent commit bad4e81605
20 changed files with 883 additions and 557 deletions
  1. +4 -1 src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs
  2. +6 -0 src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
  3. +8 -0 src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
  4. +7 -0 src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
  5. +46 -22 src/TensorFlowNET.Core/Util/nest.py.cs
  6. +346 -324 src/TensorFlowNET.Keras/BackendImpl.cs
  7. +1 -1 src/TensorFlowNET.Keras/Engine/Functional.cs
  8. +1 -1 src/TensorFlowNET.Keras/Engine/Layer.Apply.cs
  9. +56 -47 src/TensorFlowNET.Keras/Engine/Layer.cs
  10. +3 -3 src/TensorFlowNET.Keras/Engine/Sequential.cs
  11. +1 -1 src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs
  12. +17 -0 src/TensorFlowNET.Keras/Layers/LayersApi.cs
  13. +80 -0 src/TensorFlowNET.Keras/Layers/Rnn/DropOutRNNCellMixin.cs
  14. +142 -94 src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
  15. +59 -0 src/TensorFlowNET.Keras/Layers/Rnn/RNNUtils.cs
  16. +90 -2 src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
  17. +1 -1 src/TensorflowNET.Hub/KerasLayer.cs
  18. +0 -60 test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs
  19. +11 -0 test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
  20. +4 -0 test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj

+4 -1 src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs

@@ -2,6 +2,9 @@
{
public class SimpleRNNArgs : RNNArgs
{

public float Dropout = 0f;
public float RecurrentDropout = 0f;
public int state_size;
public int output_size;
}
}

+6 -0 src/TensorFlowNET.Core/Keras/Layers/ILayer.cs

@@ -27,5 +27,11 @@ namespace Tensorflow.Keras
TF_DataType DType { get; }
int count_params();
void adapt(Tensor data, int? batch_size = null, int? steps = null);

Tensors Call(Tensors inputs, Tensor? mask = null, bool? training = null, Tensors? initial_state = null, Tensors? constants = null);

StateSizeWrapper state_size { get; }

int output_size { get; }
}
}

+8 -0 src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs

@@ -200,6 +200,14 @@ namespace Tensorflow.Keras.Layers
bool return_sequences = false,
bool return_state = false);

public ILayer SimpleRNNCell(
int units,
string activation = "tanh",
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string recurrent_initializer = "orthogonal",
string bias_initializer = "zeros");

public ILayer Subtract();
}
}
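
For reference, the new SimpleRNNCell factory is exercised by the unit test added at the bottom of this diff; a minimal usage sketch along those lines (shapes taken from that test):

var cell = keras.layers.SimpleRNNCell(64);
var x = tf.random.normal(new Shape(4, 100));
var h0 = new Tensors { tf.zeros(new Shape(4, 64)) };
var (y, h1) = cell.Apply(inputs: x, state: h0);  // y: (4, 64), h1[0]: (4, 64)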

+7 -0 src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs

@@ -89,6 +89,8 @@ namespace Tensorflow
protected bool built = false;
public bool Built => built;

StateSizeWrapper ILayer.state_size => throw new NotImplementedException();

public RnnCell(bool trainable = true,
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid,
@@ -174,5 +176,10 @@ namespace Tensorflow
{
throw new NotImplementedException();
}

public Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
throw new NotImplementedException();
}
}
}

+46 -22 src/TensorFlowNET.Core/Util/nest.py.cs

@@ -19,6 +19,7 @@ using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;

namespace Tensorflow.Util
{
@@ -213,6 +214,17 @@ namespace Tensorflow.Util

public static bool is_nested(object obj)
{
// Refer to https://www.tensorflow.org/api_docs/python/tf/nest
//if (obj is IList || obj is IDictionary || obj is ITuple)
// return true;
if (obj is IList || obj is IDictionary)
return true;

if (obj is NDArray || obj is Tensor || obj is string || obj.GetType().IsGenericType
|| obj is ISet<int> || obj is ISet<float> || obj is ISet<double>)
return false;

if (obj.GetType().IsNested) return true;
// Check if the object is an IEnumerable
if (obj is IEnumerable)
{
@@ -244,7 +256,13 @@ namespace Tensorflow.Util
_flatten_recursive(structure, list);
return list;
}

// TODO(Wanglongzhi2001): ITuple requires .NET Standard 2.1, but the project currently targets 2.0.
// If you want to flatten a nested tuple, please specify the type of the tuple
//public static List<T> flatten<T>(ITuple structure)
//{
// var list = FlattenTuple<T>(structure).ToList();
// return list;
//}
public static List<T> flatten<T>(IEnumerable<T> structure)
{
var list = new List<T>();
@@ -272,9 +290,13 @@ namespace Tensorflow.Util
case String str:
list.Add(obj);
break;
case NDArray nd:
// This case can hold both Tensor and NDArray
case Tensor tensor:
list.Add(obj);
break;
//case NDArray nd:
// list.Add(obj);
// break;
case IEnumerable structure:
foreach (var child in structure)
_flatten_recursive((T)child, list);
@@ -285,28 +307,26 @@ namespace Tensorflow.Util
}
}

public static List<T> FlattenTupple<T>(object tuple)
{
List<T> items = new List<T>();
var type = tuple.GetType();

if (type.GetInterface("ITuple") == null)
throw new ArgumentException("This is not a tuple!");

foreach (var property in type.GetProperties())
{
var value = property.GetValue(tuple);
if (property.PropertyType.GetInterface("ITuple") != null)
{
var subItems = FlattenTupple<T>(value);
items.AddRange(subItems);
}
else
{
items.Add((T)value);
}
}
return items;
}
private static IEnumerable<T> FlattenTuple<T>(object tuple)
{
//if (tuple is ITuple t)
//{
// for (int i = 0; i < t.Length; i++)
// {
// foreach (var item in FlattenTuple<T>(t[i]))
// {
// yield return item;
// }
// }
//}
if(false)
{

}
else
{
yield return (T)tuple;
}
}
//# See the swig file (util.i) for documentation.
//_same_namedtuples = _pywrap_tensorflow.SameNamedtuples
@@ -494,8 +514,12 @@ namespace Tensorflow.Util
throw new ArgumentException("flat_sequence must not be null");
// if not is_sequence(flat_sequence):
// raise TypeError("flat_sequence must be a sequence")

if (!is_sequence(structure))
if (!is_nested(flat_sequence))
{
throw new ArrayTypeMismatchException($"Attempted to pack value:\n {flat_sequence}\ninto a structure, " +
$"but found incompatible type `{flat_sequence.GetType()}` instead.");
}
if (!is_nested(structure))
{
if (len(flat) != 1)
throw new ValueError($"Structure is a scalar but len(flat_sequence) == {len(flat)} > 1");
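
As a quick illustration of the tightened semantics above (a sketch, not code from this commit): Tensor and NDArray are now treated as leaves, lists and dictionaries as structures, and the ITuple path stays disabled until the project targets .NET Standard 2.1.

var leaf = tf.constant(1.0f);
nest.is_nested(leaf);                        // false: a Tensor is a leaf
nest.is_nested(new List<Tensor> { leaf });   // true: an IList is a structure
var flat = nest.flatten(new List<Tensor> { leaf, leaf });  // the leaves, in order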


+346 -324 src/TensorFlowNET.Keras/BackendImpl.cs

@@ -614,331 +614,331 @@ namespace Tensorflow.Keras
return nest.pack_sequence_as(inputs, inp);
}

if (mask != null)
{
var mask_list = tf.unstack(mask);
if (go_backwards)
{
mask_list.Reverse();
}

for (int i = 0; i < time_steps; i++)
{
// TODO(Wanglongzhi2001),deal with _get_input_tensor
var inp = _get_input_tensor(i);
var mask_t = mask_list[i];
// TODO
var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants });

var tiled_mask_t = _expand_mask(mask_t, output);

Tensors prev_output;
if (successive_outputs == null)
{
prev_output = tf.zeros_like(output);
}
else
{
prev_output = successive_outputs[successive_outputs.Length - 1];
}

output = tf.where(tiled_mask_t, output, prev_output);

//var flat_states = nest.flatten(states);
//var flat_new_states = nest.flatten(newStates);
var flat_states = states.ToList();
var flat_new_states = newStates.ToList();

var tiledMaskT = flat_states
.Select(s => _expand_mask(mask_t, s))
.ToArray();
var tuple = Tuple.Create(tiledMaskT);

List<Tensor> flat_final_states = new List<Tensor>();
foreach (var (m, s, ps) in Enumerable.Zip(tiled_mask_t, flat_new_states, flat_states))
{
flat_final_states.Add(tf.where(m, s, ps));
}

states = (Tensors)nest.pack_sequence_as(states, flat_final_states);
if (return_all_outputs)
{
successive_outputs.Add(output);
successive_states.Add(states);
}
else
{
successive_outputs = new Tensors { output };
successive_states = new Tensors { states };
}

}
last_output = successive_outputs[successive_outputs.Length - 1];
new_states = successive_states[successive_states.Length - 1];
outputs = tf.stack(successive_outputs);

if (zero_output_for_mask)
{
last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output));
outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs));
}
else // mask is null
{
for (int i = 0; i < time_steps; i++)
{
var inp = _get_input_tensor(i);
var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants });
states = newStates;

if (return_all_outputs)
{
successive_outputs.Add(output);
successive_states.Add(newStates);
}
else
{
successive_outputs = new Tensors { output };
successive_states = new Tensors { newStates };
}
}
last_output = successive_outputs[successive_outputs.Length - 1];
new_states = successive_states[successive_states.Length - 1];
outputs = tf.stack(successive_outputs);
}
}
}
else // unroll == false
{
var states = initial_states;
// Create input tensor arrays. If the inputs are nested tensors, they
// will be flattened first, and one tensor array will be created per
// flattened tensor.
var input_ta = new List<TensorArray>();
for (int i = 0; i < flatted_inptus.Count; i++)
{
input_ta.Add(tf.TensorArray(dtype: flatted_inptus[i].dtype, size: time_step_t));
}

// Get the time(0) input and compute the output for it; the output will
// be used to determine the dtype of the output tensor array. Don't read from
// input_ta since TensorArray's clear_after_read defaults to True.
var inps = new Tensors();
foreach (var inp in flatted_inptus)
{
inps.Add(inp[0]);
}
var input_time_zero = nest.pack_sequence_as(inputs, inps);

// output_time_zero is used to determine the cell output shape and its
// dtype. the value is discarded.
(output_time_zero, _) = step_function((Tensor)input_time_zero, new Tensors { initial_states, constants });

var output_ta_size = return_all_outputs ? time_step_t : tf.constant(1);
var output_ta = new List<TensorArray>();
for (int i = 0; i < output_time_zero.ToList().Count; i++)
{
var Out = output_time_zero.ToList()[i];
output_ta.Add(tf.TensorArray(dtype: Out.dtype, size: output_ta_size, element_shape: Out.shape));
}

var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time");



Func<Tensor, Tensor>? masking_fn;
Func<Tensors, Tensors, Tensors, Tensors>? compute_masked_output = null;
if (mask != null)
{
if (go_backwards)
{
mask = tf.reverse(mask, axis: new[] { 0 });
}
var mask_ta = tf.TensorArray(dtype: TF_DataType.TF_BOOL, size: time_step_t);
mask_ta = mask_ta.unstack(mask);

masking_fn = (time) =>
{
return mask_ta.read(time);
};

compute_masked_output = (mask_t, flat_out, flat_mask) =>
{
var tiled_mask_t = new Tensors();
foreach (var o in flat_out)
{
tiled_mask_t.Add(_expand_mask(mask_t, o, fixed_dim: mask_t.rank));
}

Tensors res = new Tensors();
foreach (var (m, o, fm) in Enumerable.Zip(tiled_mask_t, flat_out, flat_mask))
{
res.Add(tf.where(m, o, fm));
}
return res;
};
}
// TODO(Wanglongzhi2001): what should input_length's type be (an integer or a single tensor)?
else if (input_length is Tensor)
{
if (go_backwards)
{
var max_len = tf.reduce_max(input_length, axis: 0);
var rev_input_length = tf.subtract(max_len - 1, input_length);

masking_fn = (time) =>
{
return tf.less(rev_input_length, time);
};
}
else
{
masking_fn = (time) =>
{
return tf.greater(input_length, time);
};
}

compute_masked_output = (mask_t, flat_out, flat_mask) =>
{
var res = new List<Tensor>();
foreach (var (o, zo) in zip(flat_out, flat_mask))
{
res.Add(tf.where(mask_t, o, zo));
}
return res;
};
}
else
{
masking_fn = null;
}


if (masking_fn != null)
{
// Mask for the T output will be based on the output of T - 1. In the
// case T = 0, a zero-filled tensor will be used.
var flat_zero_output = new Tensors();
foreach (var o in nest.flatten(output_time_zero))
{
flat_zero_output.Add(tf.zeros_like(o));
}


(Tensor, List<TensorArray>, Tensors, Tensors) _step(Tensor time, List<TensorArray> output_ta_t, Tensors prev_output, Tensors states)
{
/*
RNN step function.
Args:
time: Current timestep value.
output_ta_t: TensorArray.
prev_output: tuple of outputs from time - 1.
*states: List of states.
Returns:
Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)`
*/

var current_input = input_ta.Select(x => x.read(time)).ToList();
// maybe set shape
// TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type
current_input = (List<Tensor>)nest.pack_sequence_as(inputs, current_input);
var mask_t = masking_fn(time);
var (output, new_states) = step_function(current_input, new Tensors { states, constants });
// mask output
//var flat_output = nest.flatten(output);
var flat_output = output.ToList();

var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.ToList();

// TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type
var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output);

// mask states
var flat_state = states.ToList();
var flat_new_state = new_states.ToList();

foreach (var (state, new_state) in zip(flat_state, flat_new_state))
{
if (new_state is Tensor)
{
new_state.set_shape(state.shape);
}
}

var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state);
new_states = (Tensors)nest.pack_sequence_as(new_states, flat_final_state);

var ta_index_to_write = return_all_outputs ? time : tf.constant(0);
var Output_ta_t = new List<TensorArray>();
// TODO(Wanglongzhi2001),deal with zip output_ta_t
foreach (var (ta, Out) in zip(output_ta_t, flat_new_output))
{
Output_ta_t.Add(ta.write(ta_index_to_write, Out));
}



//new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state);


return (time + 1, Output_ta_t, flat_new_output, new_states);

}
Func<Tensor, Tensor> cond = (time) => (time < time_step_t);

var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, flat_zero_output, states));
new_states = final_outputs.Item4;
output_ta = final_outputs.Item2;

}
else
{
(Tensor, List<TensorArray>, Tensors) _step(Tensor time, List<TensorArray> output_ta_t, Tensors states)
{
var current_input = input_ta.Select(x => x.read(time)).ToList();
// maybe set shape
// TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type
current_input = (List<Tensor>)nest.pack_sequence_as(inputs, current_input);
var (output, new_states) = step_function(current_input, new Tensors { states, constants });
var flat_state = states.ToList();
var flat_new_state = new_states.ToList();
foreach (var (state, new_state) in zip(flat_state, flat_new_state))
{
if (new_state is Tensor)
{
new_state.set_shape(state.shape);
}
}
var flat_output = output.ToList();
var ta_index_to_write = return_all_outputs ? time : tf.constant(0);
var Output_ta_t = new List<TensorArray>();
foreach (var (ta, out_) in zip(output_ta_t, flat_output))
{
Output_ta_t.Add(ta.write(ta_index_to_write, out_));
}

new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state);
return (time + 1, Output_ta_t, new_states);
}
Func<Tensor, Tensor> cond = (time) => (time < time_step_t);
var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, states));
new_states = final_outputs.Item3;
output_ta = final_outputs.Item2;

}
//Tensors outputs = new Tensors();
foreach (var o in output_ta)
{
outputs.Add(o.stack());
}
foreach (var o in outputs)
{
last_output.Add(o[-1]);
}
outputs = (Tensors)nest.pack_sequence_as(output_time_zero, outputs);
last_output = (Tensors)nest.pack_sequence_as(output_time_zero, last_output);

//if (mask != null)
//{
// var mask_list = tf.unstack(mask);
// if (go_backwards)
// {
// mask_list.Reverse();
// }

// for (int i = 0; i < time_steps; i++)
// {
// // TODO(Wanglongzhi2001),deal with _get_input_tensor
// var inp = _get_input_tensor(i);
// var mask_t = mask_list[i];
// // TODO
// var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants });

// var tiled_mask_t = _expand_mask(mask_t, output);

// Tensors prev_output;
// if (successive_outputs == null)
// {
// prev_output = tf.zeros_like(output);
// }
// else
// {
// prev_output = successive_outputs[successive_outputs.Length - 1];
// }

// output = tf.where(tiled_mask_t, output, prev_output);

// //var flat_states = nest.flatten(states);
// //var flat_new_states = nest.flatten(newStates);
// var flat_states = states.ToList();
// var flat_new_states = newStates.ToList();

// var tiledMaskT = flat_states
// .Select(s => _expand_mask(mask_t, s))
// .ToArray();
// var tuple = Tuple.Create(tiledMaskT);

// List<Tensor> flat_final_states = new List<Tensor>();
// foreach (var (m, s, ps) in Enumerable.Zip(tiled_mask_t, flat_new_states, flat_states))
// {
// flat_final_states.Add(tf.where(m, s, ps));
// }

// states = (Tensors)nest.pack_sequence_as(states, flat_final_states);
// if (return_all_outputs)
// {
// successive_outputs.Add(output);
// successive_states.Add(states);
// }
// else
// {
// successive_outputs = new Tensors { output };
// successive_states = new Tensors { states };
// }

// }
// last_output = successive_outputs[successive_outputs.Length - 1];
// new_states = successive_states[successive_states.Length - 1];
// outputs = tf.stack(successive_outputs);

// if (zero_output_for_mask)
// {
// last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output));
// outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs));
// }
// else // mask is null
// {
// for (int i = 0; i < time_steps; i++)
// {
// var inp = _get_input_tensor(i);
// var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants });
// states = newStates;

// if (return_all_outputs)
// {
// successive_outputs.Add(output);
// successive_states.Add(newStates);
// }
// else
// {
// successive_outputs = new Tensors { output };
// successive_states = new Tensors { newStates };
// }
// }
// last_output = successive_outputs[successive_outputs.Length - 1];
// new_states = successive_states[successive_states.Length - 1];
// outputs = tf.stack(successive_outputs);
// }
//}
}
//else // unroll == false
//{
// var states = initial_states;
// // Create input tensor array, if the inputs is nested tensors, then it
// // will be flattened first, and tensor array will be created one per
// // flattened tensor.
// var input_ta = new List<TensorArray>();
// for (int i = 0; i < flatted_inptus.Count; i++)
// {
// input_ta.Add(tf.TensorArray(dtype: flatted_inptus[i].dtype, size: time_step_t));
// }

// // Get the time(0) input and compute the output for that, the output will
// // be used to determine the dtype of output tensor array. Don't read from
// // input_ta due to TensorArray clear_after_read default to True.
// var inps = new Tensors();
// foreach (var inp in flatted_inptus)
// {
// inps.Add(inp[0]);
// }
// var input_time_zero = nest.pack_sequence_as(inputs, inps);

// // output_time_zero is used to determine the cell output shape and its
// // dtype. the value is discarded.
// (output_time_zero, _) = step_function((Tensor)input_time_zero, new Tensors { initial_states, constants });

// var output_ta_size = return_all_outputs ? time_step_t : tf.constant(1);
// var output_ta = new List<TensorArray>();
// for (int i = 0; i < output_time_zero.ToList().Count; i++)
// {
// var Out = output_time_zero.ToList()[i];
// output_ta.Add(tf.TensorArray(dtype: Out.dtype, size: output_ta_size, element_shape: Out.shape));
// }

// var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time");



// Func<Tensor, Tensor>? masking_fn;
// Func<Tensors, Tensors, Tensors, Tensors>? compute_masked_output = null;
// if (mask != null)
// {
// if (go_backwards)
// {
// mask = tf.reverse(mask, axis: new[] { 0 });
// }
// var mask_ta = tf.TensorArray(dtype: TF_DataType.TF_BOOL, size: time_step_t);
// mask_ta = mask_ta.unstack(mask);

// masking_fn = (time) =>
// {
// return mask_ta.read(time);
// };

// compute_masked_output = (mask_t, flat_out, flat_mask) =>
// {
// var tiled_mask_t = new Tensors();
// foreach (var o in flat_out)
// {
// tiled_mask_t.Add(_expand_mask(mask_t, o, fixed_dim: mask_t.rank));
// }

// Tensors res = new Tensors();
// foreach (var (m, o, fm) in Enumerable.Zip(tiled_mask_t, flat_out, flat_mask))
// {
// res.Add(tf.where(m, o, fm));
// }
// return res;
// };
// }
// // TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor)?
// else if (input_length is Tensor)
// {
// if (go_backwards)
// {
// var max_len = tf.reduce_max(input_length, axis: 0);
// var rev_input_length = tf.subtract(max_len - 1, input_length);

// masking_fn = (time) =>
// {
// return tf.less(rev_input_length, time);
// };
// }
// else
// {
// masking_fn = (time) =>
// {
// return tf.greater(input_length, time);
// };
// }

// compute_masked_output = (mask_t, flat_out, flat_mask) =>
// {
// var res = new List<Tensor>();
// foreach (var (o, zo) in zip(flat_out, flat_mask))
// {
// res.Add(tf.where(mask_t, o, zo));
// }
// return res;
// };
// }
// else
// {
// masking_fn = null;
// }


// if (masking_fn != null)
// {
// // Mask for the T output will be base on the output of T - 1. In the
// // case T = 0, a zero filled tensor will be used.
// var flat_zero_output = new Tensors();
// foreach (var o in nest.flatten(output_time_zero))
// {
// flat_zero_output.Add(tf.zeros_like(o));
// }


// (Tensor, List<TensorArray>, Tensors, Tensors) _step(Tensor time, List<TensorArray> output_ta_t, Tensors prev_output, Tensors states)
// {
// /*
// RNN step function.
// Args:
// time: Current timestep value.
// output_ta_t: TensorArray.
// prev_output: tuple of outputs from time - 1.
// *states: List of states.
// Returns:
// Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)`
// */

// var current_input = input_ta.Select(x => x.read(time)).ToList();
// // maybe set shape
// // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type
// current_input = (List<Tensor>)nest.pack_sequence_as(inputs, current_input);
// var mask_t = masking_fn(time);
// var (output, new_states) = step_function(current_input, new Tensors { states, constants });
// // mask output
// //var flat_output = nest.flatten(output);
// var flat_output = output.ToList();

// var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.ToList();

// // TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type
// var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output);

// // mask states
// var flat_state = states.ToList();
// var flat_new_state = new_states.ToList();

// foreach (var (state, new_state) in zip(flat_state, flat_new_state))
// {
// if (new_state is Tensor)
// {
// new_state.set_shape(state.shape);
// }
// }

// var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state);
// new_states = (Tensors)nest.pack_sequence_as(new_states, flat_final_state);

// var ta_index_to_write = return_all_outputs ? time : tf.constant(0);
// var Output_ta_t = new List<TensorArray>();
// // TODO(Wanglongzhi2001),deal with zip output_ta_t
// foreach (var (ta, Out) in zip(output_ta_t, flat_new_output))
// {
// Output_ta_t.Add(ta.write(ta_index_to_write, Out));
// }



// //new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state);


// return (time + 1, Output_ta_t, flat_new_output, new_states);

// }
// Func<Tensor, Tensor> cond = (time) => (time < time_step_t);

// var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, flat_zero_output, states));
// new_states = final_outputs.Item4;
// output_ta = final_outputs.Item2;

// }
// else
// {
// (Tensor, List<TensorArray>, Tensors) _step(Tensor time, List<TensorArray> output_ta_t, Tensors states)
// {
// var current_input = input_ta.Select(x => x.read(time)).ToList();
// // maybe set shape
// // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type
// current_input = (List<Tensor>)nest.pack_sequence_as(inputs, current_input);
// var (output, new_states) = step_function(current_input, new Tensors { states, constants });
// var flat_state = states.ToList();
// var flat_new_state = new_states.ToList();
// foreach (var (state, new_state) in zip(flat_state, flat_new_state))
// {
// if (new_state is Tensor)
// {
// new_state.set_shape(state.shape);
// }
// }
// var flat_output = output.ToList();
// var ta_index_to_write = return_all_outputs ? time : tf.constant(0);
// var Output_ta_t = new List<TensorArray>();
// foreach (var (ta, out_) in zip(output_ta_t, flat_output))
// {
// Output_ta_t.Add(ta.write(ta_index_to_write, out_));
// }

// new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state);
// return (time + 1, Output_ta_t, new_states);
// }
// Func<Tensor, Tensor> cond = (time) => (time < time_step_t);
// var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, states));
// new_states = final_outputs.Item3;
// output_ta = final_outputs.Item2;

// }
// //Tensors outputs = new Tensors();
// foreach (var o in output_ta)
// {
// outputs.Add(o.stack());
// }
// foreach (var o in outputs)
// {
// last_output.Add(o[-1]);
// }
// outputs = (Tensors)nest.pack_sequence_as(output_time_zero, outputs);
// last_output = (Tensors)nest.pack_sequence_as(output_time_zero, last_output);

//}

Func<Tensor, Tensor> set_shape;
set_shape = (output_) =>
@@ -968,5 +968,27 @@ namespace Tensorflow.Keras
return (last_output, Outputs, new_states);

}

// Multiplies 2 tensors (and/or variables) and returns a tensor.
// This operation corresponds to `numpy.dot(a, b, out=None)`.
public Tensor Dot(Tensor x, Tensor y)
{
//if (x.ndim != 1 && (x.ndim > 2 || y.ndim > 2))
//{
// var x_shape = new List<int>();
// foreach (var (i,s) in zip(x.shape.as_int_list(), tf.unstack(tf.shape(x))))
// {
// if (i != 0)
// {
// x_shape.append(i);
// }
// else
// {
// x_shape.append(s);
// }
// }
//}
throw new NotImplementedException();
}
}
}
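
Both the unrolled branch and the while_loop branch above apply the same masking rule: where a timestep's mask is off, the previous output (or state) is carried forward. A standalone sketch of that rule, with hypothetical shapes:

// mask_t: (batch,) booleans; output, prev_output: (batch, units).
// _expand_mask in this file tiles the mask up to the output's rank; broadcasting is shown here for brevity.
var tiled_mask = tf.expand_dims(mask_t, axis: 1);
var masked = tf.where(tiled_mask, output, prev_output);  // carry the previous value where masked out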

+1 -1 src/TensorFlowNET.Keras/Engine/Functional.cs

@@ -325,7 +325,7 @@ namespace Tensorflow.Keras.Engine
nodes_in_decreasing_depth.append(node);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
var tensor_dict = new Dictionary<long, Queue<Tensor>>();
// map input values


+1 -1 src/TensorFlowNET.Keras/Engine/Layer.Apply.cs

@@ -30,7 +30,7 @@ namespace Tensorflow.Keras.Engine
if (!built)
MaybeBuild(inputs);

var outputs = Call(inputs, state: state, training: training);
var outputs = Call(inputs, initial_state: state, training: training);

// memory leak
// _set_connectivity_metadata_(inputs, outputs);


+56 -47 src/TensorFlowNET.Keras/Engine/Layer.cs

@@ -254,6 +254,10 @@ namespace Tensorflow.Keras.Engine
/// </summary>
public Func<Tensors, Tensors>? ReplacedCall { get; set; } = null;

public StateSizeWrapper state_size => throw new NotImplementedException();

public int output_size => throw new NotImplementedException();

public Layer(LayerArgs args)
{
Initialize(args);
@@ -434,56 +438,61 @@ namespace Tensorflow.Keras.Engine

public override void SetAttr(string name, object value)
{
// TODO(Rinne): deal with "_self_setattr_tracking".
//// TODO(Rinne): deal with "_self_setattr_tracking".

value = TrackableDataStructure.sticky_attribute_assignment(this, name, value);
//value = TrackableDataStructure.sticky_attribute_assignment(this, name, value);
foreach(var val in nest.flatten(value))
{
if(val is Metric)
{
// TODO(Rinne): deal with metrics.
}
}

// TODO(Rinne): deal with "_auto_track_sub_layers".

foreach(var val in nest.flatten(value))
{
if(val is not IVariableV1 variable)
{
continue;
}
if (variable.Trainable)
{
if (_trainable_weights.Contains(variable))
{
continue;
}
_trainable_weights.Add(variable);
}
else
{
if (_non_trainable_weights.Contains(variable))
{
continue;
}
_non_trainable_weights.Add(variable);
}
keras.backend.track_variable(variable);
}
//foreach(var val in nest.flatten(value))
//{
// if(val is Metric)
// {
// // TODO(Rinne): deal with metrics.
// }
//}

//// TODO(Rinne): deal with "_auto_track_sub_layers".

//foreach(var val in nest.flatten(value))
//{
// if(val is not IVariableV1 variable)
// {
// continue;
// }
// if (variable.Trainable)
// {
// if (_trainable_weights.Contains(variable))
// {
// continue;
// }
// _trainable_weights.Add(variable);
// }
// else
// {
// if (_non_trainable_weights.Contains(variable))
// {
// continue;
// }
// _non_trainable_weights.Add(variable);
// }
// keras.backend.track_variable(variable);
//}

// Directly use the implementation of `Trackable`.
var t = this.GetType();
var field_info = t.GetField(name);
if (field_info is not null)
{
field_info.SetValue(this, value);
}
else
{
CustomizedFields[name] = value;
}
//// Directly use the implementation of `Trackable`.
//var t = this.GetType();
//var field_info = t.GetField(name);
//if (field_info is not null)
//{
// field_info.SetValue(this, value);
//}
//else
//{
// CustomizedFields[name] = value;
//}
}

Tensors ILayer.Call(Tensors inputs, Tensor mask, bool? training, Tensors initial_state, Tensors constants)
{
throw new NotImplementedException();
}
}
}

+3 -3 src/TensorFlowNET.Keras/Engine/Sequential.cs

@@ -143,7 +143,7 @@ namespace Tensorflow.Keras.Engine
}
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
if (!_has_explicit_input_shape)
{
@@ -154,10 +154,10 @@ namespace Tensorflow.Keras.Engine
{
if (!built)
_init_graph_network(this.inputs, outputs);
return base.Call(inputs, state, training);
return base.Call(inputs, initial_state, training);
}

return base.Call(inputs, state, training);
return base.Call(inputs, initial_state, training);
}

void _build_graph_network_for_inferred_shape(Shape input_shape, TF_DataType input_dtype)


+1 -1 src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs

@@ -83,7 +83,7 @@ namespace Tensorflow.Keras.Layers
_buildInputShape = input_shape;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
var inputs_shape = array_ops.shape(inputs);
var batch_size = inputs_shape[0];


+17 -0 src/TensorFlowNET.Keras/Layers/LayersApi.cs

@@ -709,6 +709,23 @@ namespace Tensorflow.Keras.Layers
ReturnState = return_state
});

public ILayer SimpleRNNCell(
int units,
string activation = "tanh",
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string recurrent_initializer = "orthogonal",
string bias_initializer = "zeros")
=> new SimpleRNNCell(new SimpleRNNArgs
{
Units = units,
Activation = keras.activations.GetActivationFromName(activation),
UseBias = use_bias,
KernelInitializer = GetInitializerByName(kernel_initializer),
RecurrentInitializer = GetInitializerByName(recurrent_initializer),
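// Note: bias_initializer is accepted by this overload but not yet forwarded into SimpleRNNArgs.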
}
);

/// <summary>
/// Long Short-Term Memory layer - Hochreiter 1997.
/// </summary>


+80 -0 src/TensorFlowNET.Keras/Layers/Rnn/DropOutRNNCellMixin.cs

@@ -0,0 +1,80 @@
using System;
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;



namespace Tensorflow.Keras.Layers.Rnn
{
public class DropoutRNNCellMixin
{
public float dropout;
public float recurrent_dropout;
// Get the dropout mask for RNN cell's input.
public Tensors get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
{

return _generate_dropout_mask(
tf.ones_like(input),
dropout,
training,
count);
}

// Get the recurrent dropout mask for RNN cell.
public Tensors get_recurrent_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
{
return _generate_dropout_mask(
tf.ones_like(input),
recurrent_dropout,
training,
count);
}

public Tensors _create_dropout_mask(Tensors input, bool training, int count = 1)
{
return _generate_dropout_mask(
tf.ones_like(input),
dropout,
training,
count);
}

public Tensors _create_recurrent_dropout_mask(Tensors input, bool training, int count = 1)
{
return _generate_dropout_mask(
tf.ones_like(input),
recurrent_dropout,
training,
count);
}

public Tensors _generate_dropout_mask(Tensor ones, float rate, bool training, int count = 1)
{
Tensors dropped_inputs()
{
DropoutArgs args = new DropoutArgs();
args.Rate = rate;
var DropoutLayer = new Dropout(args);
var mask = DropoutLayer.Apply(ones, training: training);
return mask;
}

if (count > 1)
{
Tensors results = new Tensors();
for (int i = 0; i < count; i++)
{
results.Add(dropped_inputs());
}
return results;
}

return dropped_inputs();
}
}


}
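
A sketch of how this mixin is wired up by SimpleRNNCell further down (the dropout values are the clamped args; inputs and prev_output are placeholders):

var mixin = new DropoutRNNCellMixin { dropout = 0.2f, recurrent_dropout = 0.1f };
var dp_mask = mixin.get_dropout_maskcell_for_cell(inputs, training: true);             // mask for the cell input
var rec_mask = mixin.get_recurrent_dropout_maskcell_for_cell(prev_output, training: true);  // mask for the recurrent state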

+142 -94 src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs

@@ -19,7 +19,7 @@ namespace Tensorflow.Keras.Layers.Rnn
private RNNArgs args;
private object input_spec = null; // or NoneValue??
private object state_spec = null;
private object _states = null;
private Tensors _states = null;
private object constants_spec = null;
private int _num_constants = 0;
protected IVariableV1 kernel;
@@ -44,19 +44,15 @@ namespace Tensorflow.Keras.Layers.Rnn
cell = args.Cell.AsT1;
}





Type type = cell.GetType();
MethodInfo methodInfo = type.GetMethod("Call");
if (methodInfo == null)
MethodInfo callMethodInfo = type.GetMethod("Call");
if (callMethodInfo == null)
{
throw new ValueError(@"Argument `cell` or `cells`should have a `call` method. ");
}

PropertyInfo propertyInfo = type.GetProperty("state_size");
if (propertyInfo == null)
PropertyInfo state_size_info = type.GetProperty("state_size");
if (state_size_info == null)
{
throw new ValueError(@"The RNN cell should have a `state_size` attribute");
}
@@ -80,7 +76,7 @@ namespace Tensorflow.Keras.Layers.Rnn

// States is a tuple consisting of the cells' state sizes, like (cell1.state_size, cell2.state_size, ...)
// state_size can be a single integer, a list/tuple of integers, or a TensorShape / list/tuple of TensorShapes
public object States
public Tensors States
{
get
{
@@ -106,7 +102,6 @@ namespace Tensorflow.Keras.Layers.Rnn
// state_size is a array of ints or a positive integer
var state_size = cell.state_size;


// TODO(wanglongzhi2001): what type should flat_output_size be, Shape or Tensor?
Func<Shape, Shape> _get_output_shape;
_get_output_shape = (flat_output_size) =>
@@ -132,8 +127,10 @@ namespace Tensorflow.Keras.Layers.Rnn
return output_shape;
};

Type type = cell.GetType();
PropertyInfo output_size_info = type.GetProperty("output_size");
Shape output_shape;
if (cell.output_size != 0)
if (output_size_info != null)
{
output_shape = nest.map_structure(_get_output_shape, cell.output_size);
// TODO(wanglongzhi2001): should output_shape simply be a tuple, or a Shape?
@@ -160,6 +157,7 @@ namespace Tensorflow.Keras.Layers.Rnn
{
return output_shape;
}

}

private Tensors compute_mask(Tensors inputs, Tensors mask)
@@ -184,8 +182,6 @@ namespace Tensorflow.Keras.Layers.Rnn
{
return output_mask;
}


}

public override void build(KerasShapesWrapper input_shape)
@@ -247,14 +243,18 @@ namespace Tensorflow.Keras.Layers.Rnn
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
//var (inputs_padded, row_length) = BackendImpl.convert_inputs_if_ragged(inputs);
//bool is_ragged_input = row_length != null;
//_validate_args_if_ragged(is_ragged_input, mask);
var (inputs_processed, initial_state_processed, constants_processed) = _process_inputs(inputs, initial_state, constants);
// Ragged tensors are not accepted for now
int? row_length = null;
bool is_ragged_input = false;
_validate_args_if_ragged(is_ragged_input, mask);

(inputs, initial_state, constants) = _process_inputs(inputs, initial_state, constants);

_maybe_reset_cell_dropout_mask(cell);
if (cell is StackedRNNCells)
{
foreach (var cell in ((StackedRNNCells)cell).Cells)
var stack_cell = cell as StackedRNNCells;
foreach (var cell in stack_cell.Cells)
{
_maybe_reset_cell_dropout_mask(cell);
}
@@ -263,17 +263,16 @@ namespace Tensorflow.Keras.Layers.Rnn
if (mask != null)
{
// Time step masks must be the same for each input.
//mask = nest.flatten(mask)[0];
mask = mask[0];
mask = nest.flatten(mask)[0];
}


Shape input_shape;
if (nest.is_nested(initial_state_processed))
if (nest.is_nested(inputs))
{
// In the case of nested input, use the first element for shape check
// input_shape = nest.flatten(inputs)[0].shape;
input_shape = inputs[0].shape;
// TODO(Wanglongzhi2001)
input_shape = nest.flatten(inputs)[0].shape;
}
else
{
@@ -322,6 +321,7 @@ namespace Tensorflow.Keras.Layers.Rnn
// states = (states[0] if len(states) == 1 and is_tf_rnn_cell else states)
states = states.Length == 1 ? states[0] : states;
var (output, new_states) = cell_call_fn(inputs, null, null, states, constants);
// TODO(Wanglongzhi2001): should cell_call_fn's return value be (Tensors, Tensors)?
if (!nest.is_nested(new_states))
{
return (output, new Tensors { new_states });
@@ -351,7 +351,7 @@ namespace Tensorflow.Keras.Layers.Rnn
go_backwards: args.GoBackwards,
mask: mask,
unroll: args.Unroll,
input_length: row_length != null ? row_length : new Tensor(timesteps),
input_length: row_length != null ? new Tensor(row_length) : new Tensor(timesteps),
time_major: args.TimeMajor,
zero_output_for_mask: args.ZeroOutputForMask,
return_all_outputs: args.ReturnSequences);
@@ -387,24 +387,9 @@ namespace Tensorflow.Keras.Layers.Rnn
}
}

private (Tensors, Tensors, Tensors) _process_inputs(Tensor inputs, Tensors initial_state, Tensors constants)
private (Tensors inputs, Tensors initial_state, Tensors constants) _process_inputs(Tensor inputs, Tensors initial_state, Tensors constants)
{
bool IsSequence(object obj)
{
// Check if the object is an IEnumerable
if (obj is IEnumerable)
{
// If it is, check if it is a tuple
if (!(obj is Tuple))
{
return true;
}
}
// If it is not, return false
return false;
}

if (IsSequence(input))
if (nest.is_sequence(input))
{
if (_num_constants != 0)
{
@@ -413,6 +398,7 @@ namespace Tensorflow.Keras.Layers.Rnn
else
{
initial_state = inputs[new Slice(1, len(inputs) - _num_constants)];
constants = inputs[new Slice(len(inputs) - _num_constants, len(inputs))];
}
if (len(initial_state) == 0)
initial_state = null;
@@ -421,32 +407,63 @@ namespace Tensorflow.Keras.Layers.Rnn

if (args.Stateful)
{
throw new NotImplementedException("argument stateful has not been implemented!");
if (initial_state != null)
{
var tmp = new Tensor[] { };
foreach (var s in nest.flatten(States))
{
tmp.add(tf.math.count_nonzero((Tensor)s));
}
var non_zero_count = tf.add_n(tmp);
//initial_state = tf.cond(non_zero_count > 0, () => States, () => initial_state);
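// Note: reading numpy() below assumes eager execution; the tf.cond above would be needed in graph mode.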
if((int)non_zero_count.numpy() > 0)
{
initial_state = States;
}
}
else
{
initial_state = States;
}

}
else if(initial_state != null)
{
initial_state = get_initial_state(inputs);
}

return (inputs, initial_state, constants);
if (initial_state.Length != States.Length)
{
throw new ValueError(
$"Layer {this} expects {States.Length} state(s), " +
$"but it received {initial_state.Length} " +
$"initial state(s). Input received: {inputs}");
}

return (inputs, initial_state, constants);
}

private void _validate_args_if_ragged(bool is_ragged_input, Tensors mask)
{
if (is_ragged_input)
if (!is_ragged_input)
{
if (args.Unroll)
{
throw new ValueError("The input received contains RaggedTensors and does " +
"not support unrolling. Disable unrolling by passing " +
"`unroll=False` in the RNN Layer constructor.");
}
if (mask != null)
{
throw new ValueError($"The mask that was passed in was {mask}, which " +
"cannot be applied to RaggedTensor inputs. Please " +
"make sure that there is no mask injected by upstream " +
"layers.");
}
return;
}

if (args.Unroll)
{
throw new ValueError("The input received contains RaggedTensors and does " +
"not support unrolling. Disable unrolling by passing " +
"`unroll=False` in the RNN Layer constructor.");
}
if (mask != null)
{
throw new ValueError($"The mask that was passed in was {mask}, which " +
"cannot be applied to RaggedTensor inputs. Please " +
"make sure that there is no mask injected by upstream " +
"layers.");
}

}

void _maybe_reset_cell_dropout_mask(ILayer cell)
@@ -489,46 +506,77 @@ namespace Tensorflow.Keras.Layers.Rnn
{
throw new NotImplementedException();
}
public RNN New(LayerRnnCell cell,
bool return_sequences = false,
bool return_state = false,
bool go_backwards = false,
bool stateful = false,
bool unroll = false,
bool time_major = false)
=> new RNN(new RNNArgs
{
Cell = cell,
ReturnSequences = return_sequences,
ReturnState = return_state,
GoBackwards = go_backwards,
Stateful = stateful,
Unroll = unroll,
TimeMajor = time_major
});

public RNN New(IList<IRnnArgCell> cell,
bool return_sequences = false,
bool return_state = false,
bool go_backwards = false,
bool stateful = false,
bool unroll = false,
bool time_major = false)
=> new RNN(new RNNArgs
{
Cell = new StackedRNNCells(new StackedRNNCellsArgs { Cells = cell }),
ReturnSequences = return_sequences,
ReturnState = return_state,
GoBackwards = go_backwards,
Stateful = stateful,
Unroll = unroll,
TimeMajor = time_major
});
// It seems the cell cannot be passed as an interface type
//public RNN New(IRnnArgCell cell,
// bool return_sequences = false,
// bool return_state = false,
// bool go_backwards = false,
// bool stateful = false,
// bool unroll = false,
// bool time_major = false)
// => new RNN(new RNNArgs
// {
// Cell = cell,
// ReturnSequences = return_sequences,
// ReturnState = return_state,
// GoBackwards = go_backwards,
// Stateful = stateful,
// Unroll = unroll,
// TimeMajor = time_major
// });

//public RNN New(List<IRnnArgCell> cell,
// bool return_sequences = false,
// bool return_state = false,
// bool go_backwards = false,
// bool stateful = false,
// bool unroll = false,
// bool time_major = false)
// => new RNN(new RNNArgs
// {
// Cell = cell,
// ReturnSequences = return_sequences,
// ReturnState = return_state,
// GoBackwards = go_backwards,
// Stateful = stateful,
// Unroll = unroll,
// TimeMajor = time_major
// });


protected Tensor get_initial_state(Tensor inputs)
{
return _generate_zero_filled_state_for_cell(null, null);
}
protected Tensors get_initial_state(Tensor inputs)
{
Type type = cell.GetType();
MethodInfo MethodInfo = type.GetMethod("get_initial_state");

if (nest.is_nested(inputs))
{
// The inputs are nested sequences. Use the first element in the sequence
// to get the batch size and dtype.
inputs = nest.flatten(inputs)[0];
}

var input_shape = tf.shape(inputs);
var batch_size = args.TimeMajor ? input_shape[1] : input_shape[0];
var dtype = inputs.dtype;
Tensor init_state;
if (MethodInfo != null)
{
init_state = (Tensor)MethodInfo.Invoke(cell, new object[] { null, batch_size, dtype });
}
else
{
init_state = RNNUtils.generate_zero_filled_state(batch_size, cell.state_size, dtype);
}

//if (!nest.is_nested(init_state))
//{
// init_state = new List<Tensor> { init_state};
//}
return new List<Tensor> { init_state };

//return _generate_zero_filled_state_for_cell(null, null);
}

Tensor _generate_zero_filled_state_for_cell(LSTMCell cell, Tensor batch_size)


+59 -0 src/TensorFlowNET.Keras/Layers/Rnn/RNNUtils.cs

@@ -0,0 +1,59 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Util;
using OneOf;
using Tensorflow.NumPy;

namespace Tensorflow.Keras.Layers.Rnn
{
public class RNNUtils
{
public static Tensor generate_zero_filled_state(Tensor batch_size_tensor, StateSizeWrapper state_size, TF_DataType dtype = TF_DataType.TF_FLOAT)
{
if (batch_size_tensor == null || dtype == null)
{
throw new ValueError(
"batch_size and dtype cannot be None while constructing initial " +
$"state. Received: batch_size={batch_size_tensor}, dtype={dtype}");
}

Func<StateSizeWrapper, Tensor> create_zeros;
create_zeros = (StateSizeWrapper unnested_state_size) =>
{
var flat_dims = unnested_state_size.state_size;
//if (unnested_state_size is int[])
//{
// flat_dims = new Shape(unnested_state_size.AsT0).as_int_list();
//}
//else if (unnested_state_size.IsT1)
//{
// flat_dims = new Shape(unnested_state_size.AsT1).as_int_list();
//}
var init_state_size = batch_size_tensor.ToArray<int>().concat(flat_dims);
return tf.zeros(init_state_size, dtype: dtype);
};
//if (nest.is_nested(state_size))
//{
// return nest.map_structure(create_zeros, state_size);
//}
//else
//{
// return create_zeros(state_size);
//}
return create_zeros(state_size);
}

public static Tensor generate_zero_filled_state_for_cell(SimpleRNNCell cell, Tensors inputs, Tensor batch_size, TF_DataType dtype)
{
if (inputs != null)
{
batch_size = tf.shape(inputs)[0];
dtype = inputs.dtype;
}
return generate_zero_filled_state(batch_size, cell.state_size, dtype);
}
}
}
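
Usage sketch (hypothetical batch size; this is the fallback RNN.get_initial_state above takes when the cell has no get_initial_state of its own):

var batch_size = tf.constant(new[] { 4 });  // batch size as a 1-D tensor, per the signature
var h0 = RNNUtils.generate_zero_filled_state(batch_size, cell.state_size, TF_DataType.TF_FLOAT);  // zeros of shape (4, state_size)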

+90 -2 src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs

@@ -4,6 +4,7 @@ using System.Text;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Util;

namespace Tensorflow.Keras.Layers.Rnn
{
@@ -13,10 +14,23 @@ namespace Tensorflow.Keras.Layers.Rnn
IVariableV1 kernel;
IVariableV1 recurrent_kernel;
IVariableV1 bias;
DropoutRNNCellMixin DRCMixin;
public SimpleRNNCell(SimpleRNNArgs args) : base(args)
{
this.args = args;
if (args.Units <= 0)
{
throw new ValueError(
$"units must be a positive integer, got {args.Units}");
}
this.args.Dropout = Math.Min(1f, Math.Max(0f, this.args.Dropout));
this.args.RecurrentDropout = Math.Min(1f, Math.Max(0f, this.args.RecurrentDropout));
this.args.state_size = this.args.Units;
this.args.output_size = this.args.Units;

DRCMixin = new DropoutRNNCellMixin();
DRCMixin.dropout = this.args.Dropout;
DRCMixin.recurrent_dropout = this.args.RecurrentDropout;
}

public override void build(KerasShapesWrapper input_shape)
@@ -44,7 +58,81 @@ namespace Tensorflow.Keras.Layers.Rnn

protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
return base.Call(inputs, initial_state, training);
Console.WriteLine($"shape of input: {inputs.shape}");
Tensor states = initial_state[0];
Console.WriteLine($"shape of initial_state: {states.shape}");

var prev_output = nest.is_nested(states) ? states[0] : states;
var dp_mask = DRCMixin.get_dropout_maskcell_for_cell(inputs, training.Value);
var rec_dp_mask = DRCMixin.get_recurrent_dropout_maskcell_for_cell(prev_output, training.Value);

Tensor h;
var ranks = inputs.rank;
//if (dp_mask != null)
if(false)
{
if (ranks > 2)
{
h = tf.linalg.tensordot(tf.multiply(inputs, dp_mask), kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } });
}
else
{
h = math_ops.matmul(tf.multiply(inputs, dp_mask), kernel.AsTensor());
}
}
else
{
if (ranks > 2)
{
h = tf.linalg.tensordot(inputs, kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } });
}
else
{
h = math_ops.matmul(inputs, kernel.AsTensor());
}
}

if (bias != null)
{
h = tf.nn.bias_add(h, bias);
}

if (rec_dp_mask != null)
{
prev_output = tf.multiply(prev_output, rec_dp_mask);
}

ranks = prev_output.rank;
Console.WriteLine($"shape of h: {h.shape}");

Tensor output;
if (ranks > 2)
{
var tmp = tf.linalg.tensordot(prev_output, recurrent_kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } });
output = h + tf.linalg.tensordot(prev_output, recurrent_kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } })[0];
}
else
{
output = h + math_ops.matmul(prev_output, recurrent_kernel.AsTensor())[0];

}
Console.WriteLine($"shape of output: {output.shape}");

if (args.Activation != null)
{
output = args.Activation.Apply(output);
}
if (nest.is_nested(states))
{
return (output, new Tensors { output });
}
return (output, output);
}

public Tensor get_initial_state(Tensors inputs, Tensor batch_size, TF_DataType dtype)
{
return RNNUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size, dtype);
}
}
}
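
For reference, the step that Call above is building toward is the standard simple-RNN recurrence (bias and dropout masks optional):

output_t = activation(inputs_t · kernel + bias + prev_output · recurrent_kernel)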

+1 -1 src/TensorflowNET.Hub/KerasLayer.cs

@@ -89,7 +89,7 @@ namespace Tensorflow.Hub
}
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
_check_trainability();



+0 -60 test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs

@@ -1,60 +0,0 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Collections.Generic;
using Tensorflow.Keras.Callbacks;
using Tensorflow.Keras.Engine;
using static Tensorflow.KerasApi;


namespace Tensorflow.Keras.UnitTest.Callbacks
{
[TestClass]
public class EarlystoppingTest
{
[TestMethod]
// Because loading the weight variable into the model has not yet been implemented,
// so you'd better not set patience too large, because the weights will equal to the last epoch's weights.
public void Earlystopping()
{
var layers = keras.layers;
var model = keras.Sequential(new List<ILayer>
{
layers.Rescaling(1.0f / 255, input_shape: (32, 32, 3)),
layers.Conv2D(32, 3, padding: "same", activation: keras.activations.Relu),
layers.MaxPooling2D(),
layers.Flatten(),
layers.Dense(128, activation: keras.activations.Relu),
layers.Dense(10)
});


model.summary();

model.compile(optimizer: keras.optimizers.RMSprop(1e-3f),
loss: keras.losses.SparseCategoricalCrossentropy(from_logits: true),
metrics: new[] { "acc" });

var num_epochs = 3;
var batch_size = 8;

var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data();
x_train = x_train / 255.0f;
// define a CallbackParams first, the parameters you pass al least contain Model and Epochs.
CallbackParams callback_parameters = new CallbackParams
{
Model = model,
Epochs = num_epochs,
};
// define your earlystop
ICallback earlystop = new EarlyStopping(callback_parameters, "accuracy");
// define a callbcaklist, then add the earlystopping to it.
var callbacks = new List<ICallback>();
callbacks.add(earlystop);

model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)], batch_size, num_epochs, callbacks: callbacks);
}

}


}


+11 -0 test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs

@@ -144,6 +144,17 @@ namespace Tensorflow.Keras.UnitTest.Layers
Assert.AreEqual(expected_output, actual_output);
}

[TestMethod]
public void SimpleRNNCell()
{
var h0 = new Tensors { tf.zeros(new Shape(4, 64)) };
var x = tf.random.normal(new Shape(4, 100));
var cell = keras.layers.SimpleRNNCell(64);
var (y, h1) = cell.Apply(inputs:x, state:h0);
Assert.AreEqual((4, 64), y.shape);
Assert.AreEqual((4, 64), h1[0].shape);
}

[TestMethod, Ignore("WIP")]
public void SimpleRNN()
{


+4 -0 test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj

@@ -67,4 +67,8 @@
</None>
</ItemGroup>

<ItemGroup>
<Folder Include="Callbacks\" />
</ItemGroup>

</Project>
