Browse Source

feat: support simple RNN.

pull/1106/head
Yaohui Liu Wanglongzhi2001 2 years ago
parent
commit
3de5c18b76
9 changed files with 507 additions and 414 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs
  2. +1
    -0
      src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
  3. +98
    -19
      src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs
  4. +372
    -349
      src/TensorFlowNET.Keras/BackendImpl.cs
  5. +13
    -11
      src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
  6. +1
    -1
      src/TensorFlowNET.Keras/Layers/Rnn/RnnCellBase.cs
  7. +18
    -32
      src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
  8. +1
    -0
      src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs
  9. +2
    -1
      src/TensorflowNET.Hub/KerasLayer.cs

+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs View File

@@ -9,11 +9,11 @@ namespace Tensorflow.Keras.Layers.Rnn
{
GeneralizedTensorShape StateSize { get; }
GeneralizedTensorShape OutputSize { get; }
bool IsTFRnnCell { get; }
/// <summary>
/// Whether the optional RNN args are supported when appying the layer.
/// In other words, whether `Apply` is overwrited with process of `RnnOptionalArgs`.
/// </summary>
bool SupportOptionalArgs { get; }
(Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null);
}
}

+ 1
- 0
src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs View File

@@ -183,6 +183,7 @@ namespace Tensorflow
}
public GeneralizedTensorShape StateSize => throw new NotImplementedException();
public GeneralizedTensorShape OutputSize => throw new NotImplementedException();
public bool IsTFRnnCell => throw new NotImplementedException();
public bool SupportOptionalArgs => throw new NotImplementedException();
}
}

+ 98
- 19
src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs View File

@@ -17,6 +17,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Eager;
using Tensorflow.Framework;
using static Tensorflow.Binding;

@@ -48,6 +49,7 @@ namespace Tensorflow.Operations
public override Tensor flow => _flow;
bool _clear_after_read;
List<Tensor> _tensor_array;
List<int> _previous_read_indices;

public _EagerTensorArray(TF_DataType dtype, Tensor size, bool dynamic_size = false,
bool clear_after_read = true, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
@@ -61,16 +63,20 @@ namespace Tensorflow.Operations
_dtype = dtype.as_base_dtype();
_dynamic_size = dynamic_size;
_clear_after_read = clear_after_read;
_tensor_array = new List<Tensor>();
_tensor_array = Enumerable.Repeat<Tensor>(null, size.numpy()).ToList();
_previous_read_indices = new();
}

public override TensorArray unstack(Tensor value, string name = null)
{
return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _handle, value }), delegate
var tensors = array_ops.unstack(value, name: name);
if(tensors.Length > _tensor_array.Count && !_dynamic_size)
{
var num_elements = array_ops.shape(value)[0];
return scatter(indices: math_ops.range(0, num_elements), value: value, name: name);
});
throw new ValueError($"Cannot unstack {tensors.Length} tensors into a TensorArray of static size {_tensor_array.Count}");
}
_tensor_array = tensors.ToList();
// TODO(Rinne): revise the implementation. Here we should return `parent()`.
return this;
}

public TensorArray scatter(Tensor indices, Tensor value, string name = null)
@@ -116,9 +122,19 @@ namespace Tensorflow.Operations
_colocate_with.Add(value);
}

private Tensor _maybe_zero(int ix)
{
var val = _tensor_array[ix];
if(val is null)
{
val = _tensor_array[ix] = array_ops.zeros(_element_shape, _dtype);
}
return val;
}

public override Tensor read<T>(T index, string name = null)
{
int index_int = -1;
int index_int;
if (index is int int_index)
index_int = int_index;
else if (index is Tensor tensor_index)
@@ -126,27 +142,75 @@ namespace Tensorflow.Operations
else
throw new ValueError("");

if(index_int >= _tensor_array.Count)
{
throw new OutOfRangeError($"Tried to read from index {index_int} but array size is: {_tensor_array.Count} ");
}

var res = _tensor_array[index_int];
if(res is null)
{
if (_previous_read_indices.Contains(index_int))
{
throw new InvalidArgumentError($"Could not read index {index_int} twice because it was cleared after " +
$"a previous read (perhaps try setting clear_after_read = false?)");
}
else
{
res = _maybe_zero(index_int);
}
}

if (_clear_after_read)
{
_tensor_array[index_int] = null;
_previous_read_indices.Add(index_int);
}

return _tensor_array[index_int];
return res;
}

public override TensorArray write(Tensor index, Tensor value, string name = null)
{
if (_infer_shape)
_element_shape = _element_shape.merge_with(value.shape);
_tensor_array.add(value);
return this;
int index_int;
if(index is EagerTensor eager)
{
return write<Tensor>(eager.numpy(), value, name);
}
throw new InvalidArgumentError("The index is supposed to be an EagerTensor");
}

public override TensorArray write<T>(int index, T value, string name = null)
{
var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
var index_tensor = ops.convert_to_tensor(index, name: "index");
return write(index_tensor, value_tensor, name: name);
int size = _tensor_array.Count;
if(index >= size)
{
if (!_dynamic_size)
{
throw new OutOfRangeError($"Tried to write to index {index} but array is not resizeable and size " +
$"is: {size} ");
}
_tensor_array.AddRange(Enumerable.Repeat<Tensor>(null, index - size + 1));
}

Tensor tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
if(_dtype != tensor.dtype)
{
throw new InvalidArgumentError($"TensorArray dtype is {_dtype.as_python_name()} but Op is " +
$"trying to write dtype {tensor.dtype.as_python_name()} ");
}

if (!_element_shape.is_compatible_with(tensor.shape))
{
throw new ValueError($"Incompatible shape for value ({tensor.shape}), expected ({_element_shape})");
}

if (_infer_shape)
{
_element_shape = _element_shape.merge_with(tensor.shape);
}
_tensor_array[index] = tensor;
return this;
}

private Tensor size(string name = null)
@@ -156,11 +220,26 @@ namespace Tensorflow.Operations

public override Tensor stack(string name = null)
{
ops.colocate_with(_handle);
return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate
if(_tensor_array.Count > 0)
{
return gather(math_ops.range(0, size()), name: name);
});
for(int i = 0; i < _tensor_array.Count; i++)
{
_maybe_zero(i);
}
}
if(_tensor_array.Count == 0 && _element_shape.IsFullyDefined)
{
return ops.convert_to_tensor(new Shape(new long[] { 0 }.Concat(_element_shape.dims).ToArray()), name: name, dtype: _dtype);
}
else
{
return ops.convert_to_tensor(_tensor_array, name: name, dtype: _dtype);
}
//ops.colocate_with(_handle);
//return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate
//{
// return gather(math_ops.range(0, size()), name: name);
//});
}

public override Tensor gather(Tensor indices, string name = null)


+ 372
- 349
src/TensorFlowNET.Keras/BackendImpl.cs View File

@@ -20,9 +20,11 @@ using System.Linq;
using System.Collections.Generic;
using Tensorflow.Functions;
using Tensorflow.Graphs;
using Tensorflow.Common.Extensions;
using static Tensorflow.Binding;
using static Tensorflow.Graphs.SubGraphUtility;
using Tensorflow.Util;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras
{
@@ -452,7 +454,7 @@ namespace Tensorflow.Keras
return x;
}

public static (Tensors, Tensors, Tensors) rnn(
public (Tensors, Tensors, Tensors) rnn(
Func<Tensors, Tensors, (Tensors, Tensors)> step_function, // args:inputs, states, return:output, new_states
Tensors inputs, // inputs is a tuple of tensors (one per input sequence)
Tensors initial_states,
@@ -466,7 +468,7 @@ namespace Tensorflow.Keras
bool return_all_outputs = true)
{

Tensors swap_batch_timestep(Tensors input_t)
Tensor swap_batch_timestep(Tensor input_t)
{
var axes = Enumerable.Range(0, input_t.rank).ToArray();
axes[0] = 1;
@@ -476,13 +478,14 @@ namespace Tensorflow.Keras

if (!time_major)
{
inputs = nest.map_structure(swap_batch_timestep, inputs);
inputs = Nest.MapStructure(swap_batch_timestep, inputs).ToTensors();
}

var flatted_inptus = nest.flatten(inputs);
var time_steps = flatted_inptus[0].shape[0];
var batch = flatted_inptus[0].shape[1];
var time_step_t = tf.shape(flatted_inptus[0])[0];
var flatted_inptus = Nest.Flatten(inputs).ToList();
var first_flatted_input = flatted_inptus[0];
var time_steps = first_flatted_input.shape[0];
var batch = first_flatted_input.shape[1];
var time_steps_t = (int)first_flatted_input.shape[0];

foreach (var input_ in flatted_inptus)
{
@@ -508,11 +511,6 @@ namespace Tensorflow.Keras

}

if (constants == null)
{
constants = new List<Tensor>();
}

// tf.where needs its condition tensor to be the same shape as its two
// result tensors, but in our case the condition (mask) tensor is
// (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
@@ -522,12 +520,12 @@ namespace Tensorflow.Keras

Tensors _expand_mask(Tensors mask_t, Tensors input_t, int fixed_dim = 1)
{
if (nest.is_nested(mask_t))
if (!mask_t.IsSingle())
{
throw new ValueError($"mask_t is expected to be tensor, but got {mask_t}");
}

if (nest.is_nested(input_t))
if (!input_t.IsSingle())
{
throw new ValueError($"input_t is expected to be tensor, but got {input_t}");
}
@@ -575,21 +573,21 @@ namespace Tensorflow.Keras



Tensors _process_single_input_t(Tensors input_t)
Tensors _process_single_input_t(Tensor input_t)
{
input_t = tf.unstack(input_t); // unstack for time_step dim
var unstaked_input_t = array_ops.unstack(input_t); // unstack for time_step dim
if (go_backwards)
{
input_t.Reverse();
unstaked_input_t = unstaked_input_t.Reverse().ToArray();
}
return input_t;
return unstaked_input_t;
}

// TODO(Wanglongzhi2001)
Tensors processed_input;
if (nest.is_nested(inputs))
if (!inputs.IsSingle())
{
processed_input = nest.map_structure(_process_single_input_t, inputs);
processed_input = inputs.MapStructure(_process_single_input_t).ReduceTo<Tensors, Tensor>().ToTensors();
}
else
{
@@ -603,334 +601,339 @@ namespace Tensorflow.Keras
{
inp.Add(t_[time]);
}
return nest.pack_sequence_as(inputs, inp);
return Nest.PackSequenceAs(inputs, inp);
}

//if (mask != null)
//{
// var mask_list = tf.unstack(mask);
// if (go_backwards)
// {
// mask_list.Reverse();
// }

// for (int i = 0; i < time_steps; i++)
// {
// // TODO(Wanglongzhi2001),deal with _get_input_tensor
// var inp = _get_input_tensor(i);
// var mask_t = mask_list[i];
// // TODO
// var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants });

// var tiled_mask_t = _expand_mask(mask_t, output);

// Tensors prev_output;
// if (successive_outputs == null)
// {
// prev_output = tf.zeros_like(output);
// }
// else
// {
// prev_output = successive_outputs[successive_outputs.Length - 1];
// }

// output = tf.where(tiled_mask_t, output, prev_output);

// //var flat_states = nest.flatten(states);
// //var flat_new_states = nest.flatten(newStates);
// var flat_states = states.ToList();
// var flat_new_states = newStates.ToList();

// var tiledMaskT = flat_states
// .Select(s => _expand_mask(mask_t, s))
// .ToArray();
// var tuple = Tuple.Create(tiledMaskT);

// List<Tensor> flat_final_states = new List<Tensor>();
// foreach (var (m, s, ps) in Enumerable.Zip(tiled_mask_t, flat_new_states, flat_states))
// {
// flat_final_states.Add(tf.where(m, s, ps));
// }

// states = (Tensors)nest.pack_sequence_as(states, flat_final_states);
// if (return_all_outputs)
// {
// successive_outputs.Add(output);
// successive_states.Add(states);
// }
// else
// {
// successive_outputs = new Tensors { output };
// successive_states = new Tensors { states };
// }

// }
// last_output = successive_outputs[successive_outputs.Length - 1];
// new_states = successive_states[successive_states.Length - 1];
// outputs = tf.stack(successive_outputs);

// if (zero_output_for_mask)
// {
// last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output));
// outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs));
// }
// else // mask is null
// {
// for (int i = 0; i < time_steps; i++)
// {
// var inp = _get_input_tensor(i);
// var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants });
// states = newStates;

// if (return_all_outputs)
// {
// successive_outputs.Add(output);
// successive_states.Add(newStates);
// }
// else
// {
// successive_outputs = new Tensors { output };
// successive_states = new Tensors { newStates };
// }
// }
// last_output = successive_outputs[successive_outputs.Length - 1];
// new_states = successive_states[successive_states.Length - 1];
// outputs = tf.stack(successive_outputs);
// }
//}
if (mask != null)
{
var mask_list = tf.unstack(mask);
if (go_backwards)
{
mask_list.Reverse();
}

for (int i = 0; i < time_steps; i++)
{
// TODO(Wanglongzhi2001),deal with _get_input_tensor
var inp = _get_input_tensor(i);
var mask_t = mask_list[i];
// TODO
var (output, newStates) = step_function((Tensors)inp, states.MergeWith(constants));

var tiled_mask_t = _expand_mask(mask_t, output);

Tensors prev_output;
if (successive_outputs == null)
{
prev_output = tf.zeros_like(output);
}
else
{
prev_output = successive_outputs[successive_outputs.Length - 1];
}

output = tf.where(tiled_mask_t, output, prev_output);

var flat_states = Nest.Flatten(states).ToList();
var flat_new_states = Nest.Flatten(newStates).ToList();

var tiledMaskT = flat_states
.Select(s => _expand_mask(mask_t, s))
.ToArray();
var tuple = Tuple.Create(tiledMaskT);

List<Tensor> flat_final_states = new List<Tensor>();
foreach (var (m, s, ps) in zip(tiled_mask_t.ToList(), flat_new_states, flat_states))
{
flat_final_states.Add(tf.where(m, s, ps));
}

states = Nest.PackSequenceAs(states, flat_final_states).ToTensors();
if (return_all_outputs)
{
successive_outputs.Add(output);
successive_states.Add(states);
}
else
{
successive_outputs = new Tensors { output };
successive_states = new Tensors { states };
}

}
last_output = successive_outputs[successive_outputs.Length - 1];
new_states = successive_states[successive_states.Length - 1];
outputs = tf.stack(successive_outputs);

if (zero_output_for_mask)
{
last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output));
outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs));
}
else // mask is null
{
for (int i = 0; i < time_steps; i++)
{
var inp = _get_input_tensor(i);
var (output, newStates) = step_function((Tensors)inp, states.MergeWith(constants));
states = newStates;

if (return_all_outputs)
{
successive_outputs.Add(output);
successive_states.Add(newStates);
}
else
{
successive_outputs = new Tensors { output };
successive_states = new Tensors { newStates };
}
}
last_output = successive_outputs[successive_outputs.Length - 1];
new_states = successive_states[successive_states.Length - 1];
outputs = tf.stack(successive_outputs);
}
}
}
else // unroll == false
{
var states = initial_states;
// Create input tensor array, if the inputs is nested tensors, then it
// will be flattened first, and tensor array will be created one per
// flattened tensor.
var input_ta = new List<TensorArray>();
for (int i = 0; i < flatted_inptus.Count; i++)
{
input_ta.Add(tf.TensorArray(dtype: flatted_inptus[i].dtype, size: time_steps_t));
}

foreach(var (ta, input_) in zip(input_ta, flatted_inptus))
{
if (!go_backwards)
{
ta.unstack(input_);
}
else
{
ta.unstack(reverse(input_, 0));
}
}

// Get the time(0) input and compute the output for that, the output will
// be used to determine the dtype of output tensor array. Don't read from
// input_ta due to TensorArray clear_after_read default to True.
var inps = new Tensors();
foreach (var inp in flatted_inptus)
{
inps.Add(inp[0]);
}
var input_time_zero = Nest.PackSequenceAs(inputs, inps).ToTensors();

// output_time_zero is used to determine the cell output shape and its
// dtype. the value is discarded.
(output_time_zero, _) = step_function((Tensor)input_time_zero,
constants is null ? initial_states : initial_states.MergeWith(constants));

int output_ta_size = return_all_outputs ? time_steps_t : 1;
var output_ta = new List<TensorArray>();
for (int i = 0; i < output_time_zero.ToList().Count; i++)
{
var Out = output_time_zero.ToList()[i];
output_ta.Add(tf.TensorArray(dtype: Out.dtype, size: output_ta_size, element_shape: Out.shape));
}

var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time");



Func<Tensor, Tensor>? masking_fn;
Func<Tensors, Tensors, Tensors, Tensors>? compute_masked_output = null;
if (mask != null)
{
if (go_backwards)
{
mask = tf.reverse(mask, axis: new[] { 0 });
}
var mask_ta = tf.TensorArray(dtype: TF_DataType.TF_BOOL, size: time_steps_t);
mask_ta = mask_ta.unstack(mask);

masking_fn = (time) =>
{
return mask_ta.read(time);
};

compute_masked_output = (mask_t, flat_out, flat_mask) =>
{
var tiled_mask_t = new Tensors();
foreach (var o in flat_out)
{
tiled_mask_t.Add(_expand_mask(mask_t, o, fixed_dim: mask_t.rank));
}

Tensors res = new Tensors();
foreach (var (m, o, fm) in zip(tiled_mask_t.ToList(), flat_out.ToList(), flat_mask.ToList()))
{
res.Add(tf.where(m, o, fm));
}
return res;
};
}
// TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor)?
else if (input_length is Tensor)
{
if (go_backwards)
{
var max_len = tf.reduce_max(input_length, axis: 0);
var rev_input_length = tf.subtract(max_len - 1, input_length);

masking_fn = (time) =>
{
return tf.less(rev_input_length, time);
};
}
else
{
masking_fn = (time) =>
{
return tf.greater(input_length, time);
};
}

compute_masked_output = (mask_t, flat_out, flat_mask) =>
{
var res = new List<Tensor>();
foreach (var (o, zo) in zip(flat_out, flat_mask))
{
res.Add(tf.where(mask_t, o, zo));
}
return res;
};
}
else
{
masking_fn = null;
}

Func<Tensor, Tensor> cond = (time) => (time < time_steps_t);
int parallel_iterations = 32;
if (masking_fn != null)
{
// Mask for the T output will be base on the output of T - 1. In the
// case T = 0, a zero filled tensor will be used.
var flat_zero_output = new Tensors();
foreach (var o in Nest.Flatten(output_time_zero))
{
flat_zero_output.Add(tf.zeros_like(o));
}

var prev_output = flat_zero_output;
var output_ta_t = output_ta;
Tensor _step(Tensor time)
{
/*
RNN step function.
Args:
time: Current timestep value.
output_ta_t: TensorArray.
prev_output: tuple of outputs from time - 1.
*states: List of states.
Returns:
Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)`
*/

var flat_current_input = input_ta.Select(x => x.read(time)).ToList();
// maybe set shape
// TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type
var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors();
var mask_t = masking_fn(time);
var (output, new_states_internal) = step_function(current_input, states.MergeWith(constants));
// mask output
var flat_output = Nest.Flatten(output).ToList();

var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.ToList();

// TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type
var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output);

// mask states
var flat_state = states.ToList();
var flat_new_state = new_states_internal.ToList();

foreach (var (state, new_state) in zip(flat_state, flat_new_state))
{
if (new_state is Tensor)
{
new_state.shape = state.shape;
}
}

var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state);
new_states_internal = Nest.PackSequenceAs(new_states, flat_final_state).ToTensors();

var ta_index_to_write = return_all_outputs ? time : tf.constant(0);
// TODO(Wanglongzhi2001),deal with zip output_ta_t
foreach (var (ta, Out) in zip(output_ta_t, flat_new_output))
{
output_ta_t.Add(ta.write(ta_index_to_write, Out));
}

new_states_internal = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors();

output_ta = output_ta_t;
new_states = new_states_internal;
return time + 1;

}
var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations);
}
else
{
var output_ta_t = output_ta;
new_states = states;
Tensor _step(Tensor time)
{
var flat_current_input = input_ta.Select(x => x.read(time)).ToList();
// maybe set shape
// TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type
var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors();
var (output, new_states_internal) = step_function(current_input, new_states.MergeWith(constants));
var flat_state = new_states.Flatten().ToList();
var flat_new_state = new_states_internal.Flatten().ToList();
foreach (var (state, new_state) in zip(flat_state, flat_new_state))
{
if (new_state is Tensor)
{
new_state.shape = state.shape;
}
}
var flat_output = Nest.Flatten(output);
var ta_index_to_write = return_all_outputs ? time : tf.constant(0);
output_ta_t = zip(output_ta_t, flat_output).Select(item =>
{
var (ta, out_) = item;
return ta.write(ta_index_to_write, out_);
}).ToList();

new_states_internal = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors();
output_ta = output_ta_t;
new_states = new_states_internal;
return time + 1;
}
var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations);
}
//Tensors outputs = new Tensors();
foreach (var o in output_ta)
{
outputs.Add(o.stack());
}
foreach (var o in outputs)
{
last_output.Add(o[-1]);
}
outputs = Nest.PackSequenceAs(output_time_zero, outputs).ToTensors();
last_output = Nest.PackSequenceAs(output_time_zero, last_output).ToTensors();

}
//else // unroll == false
//{
// var states = initial_states;
// // Create input tensor array, if the inputs is nested tensors, then it
// // will be flattened first, and tensor array will be created one per
// // flattened tensor.
// var input_ta = new List<TensorArray>();
// for (int i = 0; i < flatted_inptus.Count; i++)
// {
// input_ta.Add(tf.TensorArray(dtype: flatted_inptus[i].dtype, size: time_step_t));
// }

// // Get the time(0) input and compute the output for that, the output will
// // be used to determine the dtype of output tensor array. Don't read from
// // input_ta due to TensorArray clear_after_read default to True.
// var inps = new Tensors();
// foreach (var inp in flatted_inptus)
// {
// inps.Add(inp[0]);
// }
// var input_time_zero = nest.pack_sequence_as(inputs, inps);

// // output_time_zero is used to determine the cell output shape and its
// // dtype. the value is discarded.
// (output_time_zero, _) = step_function((Tensor)input_time_zero, new Tensors { initial_states, constants });

// var output_ta_size = return_all_outputs ? time_step_t : tf.constant(1);
// var output_ta = new List<TensorArray>();
// for (int i = 0; i < output_time_zero.ToList().Count; i++)
// {
// var Out = output_time_zero.ToList()[i];
// output_ta.Add(tf.TensorArray(dtype: Out.dtype, size: output_ta_size, element_shape: Out.shape));
// }

// var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time");



// Func<Tensor, Tensor>? masking_fn;
// Func<Tensors, Tensors, Tensors, Tensors>? compute_masked_output = null;
// if (mask != null)
// {
// if (go_backwards)
// {
// mask = tf.reverse(mask, axis: new[] { 0 });
// }
// var mask_ta = tf.TensorArray(dtype: TF_DataType.TF_BOOL, size: time_step_t);
// mask_ta = mask_ta.unstack(mask);

// masking_fn = (time) =>
// {
// return mask_ta.read(time);
// };

// compute_masked_output = (mask_t, flat_out, flat_mask) =>
// {
// var tiled_mask_t = new Tensors();
// foreach (var o in flat_out)
// {
// tiled_mask_t.Add(_expand_mask(mask_t, o, fixed_dim: mask_t.rank));
// }

// Tensors res = new Tensors();
// foreach (var (m, o, fm) in Enumerable.Zip(tiled_mask_t, flat_out, flat_mask))
// {
// res.Add(tf.where(m, o, fm));
// }
// return res;
// };
// }
// // TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor)?
// else if (input_length is Tensor)
// {
// if (go_backwards)
// {
// var max_len = tf.reduce_max(input_length, axis: 0);
// var rev_input_length = tf.subtract(max_len - 1, input_length);

// masking_fn = (time) =>
// {
// return tf.less(rev_input_length, time);
// };
// }
// else
// {
// masking_fn = (time) =>
// {
// return tf.greater(input_length, time);
// };
// }

// compute_masked_output = (mask_t, flat_out, flat_mask) =>
// {
// var res = new List<Tensor>();
// foreach (var (o, zo) in zip(flat_out, flat_mask))
// {
// res.Add(tf.where(mask_t, o, zo));
// }
// return res;
// };
// }
// else
// {
// masking_fn = null;
// }


// if (masking_fn != null)
// {
// // Mask for the T output will be base on the output of T - 1. In the
// // case T = 0, a zero filled tensor will be used.
// var flat_zero_output = new Tensors();
// foreach (var o in nest.flatten(output_time_zero))
// {
// flat_zero_output.Add(tf.zeros_like(o));
// }


// (Tensor, List<TensorArray>, Tensors, Tensors) _step(Tensor time, List<TensorArray> output_ta_t, Tensors prev_output, Tensors states)
// {
// /*
// RNN step function.
// Args:
// time: Current timestep value.
// output_ta_t: TensorArray.
// prev_output: tuple of outputs from time - 1.
// *states: List of states.
// Returns:
// Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)`
// */

// var current_input = input_ta.Select(x => x.read(time)).ToList();
// // maybe set shape
// // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type
// current_input = (List<Tensor>)nest.pack_sequence_as(inputs, current_input);
// var mask_t = masking_fn(time);
// var (output, new_states) = step_function(current_input, new Tensors { states, constants });
// // mask output
// //var flat_output = nest.flatten(output);
// var flat_output = output.ToList();

// var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.ToList();

// // TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type
// var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output);

// // mask states
// var flat_state = states.ToList();
// var flat_new_state = new_states.ToList();

// foreach (var (state, new_state) in zip(flat_state, flat_new_state))
// {
// if (new_state is Tensor)
// {
// new_state.set_shape(state.shape);
// }
// }

// var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state);
// new_states = (Tensors)nest.pack_sequence_as(new_states, flat_final_state);

// var ta_index_to_write = return_all_outputs ? time : tf.constant(0);
// var Output_ta_t = new List<TensorArray>();
// // TODO(Wanglongzhi2001),deal with zip output_ta_t
// foreach (var (ta, Out) in zip(output_ta_t, flat_new_output))
// {
// Output_ta_t.Add(ta.write(ta_index_to_write, Out));
// }



// //new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state);


// return (time + 1, Output_ta_t, flat_new_output, new_states);

// }
// Func<Tensor, Tensor> cond = (time) => (time < time_step_t);

// var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, flat_zero_output, states));
// new_states = final_outputs.Item4;
// output_ta = final_outputs.Item2;

// }
// else
// {
// (Tensor, List<TensorArray>, Tensors) _step(Tensor time, List<TensorArray> output_ta_t, Tensors states)
// {
// var current_input = input_ta.Select(x => x.read(time)).ToList();
// // maybe set shape
// // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type
// current_input = (List<Tensor>)nest.pack_sequence_as(inputs, current_input);
// var (output, new_states) = step_function(current_input, new Tensors { states, constants });
// var flat_state = states.ToList();
// var flat_new_state = new_states.ToList();
// foreach (var (state, new_state) in zip(flat_state, flat_new_state))
// {
// if (new_state is Tensor)
// {
// new_state.set_shape(state.shape);
// }
// }
// var flat_output = output.ToList();
// var ta_index_to_write = return_all_outputs ? time : tf.constant(0);
// var Output_ta_t = new List<TensorArray>();
// foreach (var (ta, out_) in zip(output_ta_t, flat_output))
// {
// Output_ta_t.Add(ta.write(ta_index_to_write, out_));
// }

// new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state);
// return (time + 1, Output_ta_t, new_states);
// }
// Func<Tensor, Tensor> cond = (time) => (time < time_step_t);
// var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, states));
// new_states = final_outputs.Item3;
// output_ta = final_outputs.Item2;

// }
// //Tensors outputs = new Tensors();
// foreach (var o in output_ta)
// {
// outputs.Add(o.stack());
// }
// foreach (var o in outputs)
// {
// last_output.Add(o[-1]);
// }
// outputs = (Tensors)nest.pack_sequence_as(output_time_zero, outputs);
// last_output = (Tensors)nest.pack_sequence_as(output_time_zero, last_output);

//}

Func<Tensor, Tensor> set_shape;
set_shape = (output_) =>
@@ -947,18 +950,38 @@ namespace Tensorflow.Keras
shape[0] = 1;
}
shape[1] = (int)batch;
output_.set_shape(new Tensor(shape));
output_.shape = shape;
}
return output_;
};

var Outputs = (Tensors)nest.map_structure(set_shape, outputs);
outputs = Nest.MapStructure(set_shape, outputs).ToTensors();
if (!time_major)
{
Outputs = nest.map_structure(swap_batch_timestep, outputs);
outputs = Nest.MapStructure(swap_batch_timestep, outputs).ToTensors();
}
return (last_output, outputs, new_states);

}

public Tensor reverse(Tensor input, int axis)
{
return reverse(input, new int[] { axis });
}

public Tensor reverse(Tensor input, int[] axes)
{
return tf.reverse(input, axes);
}

public Tensor maybe_convert_to_ragged(bool is_ragged_output, Tensor output, int nested_row_lengths, bool go_backwards = false)
{
if (!is_ragged_output)
{
return output;
}
return (last_output, Outputs, new_states);

throw new NotImplementedException("Not implemented currently, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
}
}
}

+ 13
- 11
src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs View File

@@ -55,8 +55,8 @@ namespace Tensorflow.Keras.Layers.Rnn
if (_states == null)
{
// CHECK(Rinne): check if this is correct.
var state = nest.map_structure(x => null, _cell.StateSize);
return new Tensors { state };
var nested = _cell.StateSize.MapStructure<Tensor?>(x => null);
_states = nested.AsNest().ToTensors();
}
return _states;
}
@@ -230,7 +230,7 @@ namespace Tensorflow.Keras.Layers.Rnn
Tensors? mask = rnn_optional_args?.Mask;
//var (inputs_padded, row_length) = BackendImpl.convert_inputs_if_ragged(inputs);
// 暂时先不接受ragged tensor
int? row_length = null;
int row_length = 0; // TODO(Rinne): support this param.
bool is_ragged_input = false;
_validate_args_if_ragged(is_ragged_input, mask);

@@ -249,16 +249,16 @@ namespace Tensorflow.Keras.Layers.Rnn
if (mask != null)
{
// Time step masks must be the same for each input.
mask = nest.flatten(mask)[0];
mask = mask.Flatten().First();
}

Shape input_shape;
if (nest.is_nested(inputs))
if (!inputs.IsSingle())
{
// In the case of nested input, use the first element for shape check
// input_shape = nest.flatten(inputs)[0].shape;
// TODO(Wanglongzhi2001)
input_shape = nest.flatten(inputs)[0].shape;
input_shape = inputs.Flatten().First().shape;
}
else
{
@@ -286,6 +286,7 @@ namespace Tensorflow.Keras.Layers.Rnn

// cell_call_fn = (self.cell.__call__ if callable(self.cell) else self.cell.call)
Func<Tensors, Tensors, (Tensors, Tensors)> step;
bool is_tf_rnn_cell = _cell.IsTFRnnCell;
if (constants is not null)
{
if (!_cell.SupportOptionalArgs)
@@ -299,7 +300,8 @@ namespace Tensorflow.Keras.Layers.Rnn
{
constants = new Tensors(states.TakeLast(_num_constants));
states = new Tensors(states.SkipLast(_num_constants));
var(output, new_states) = _cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants });
states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states;
var (output, new_states) = _cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants });
// TODO(Wanglongzhi2001),should cell_call_fn's return value be Tensors, Tensors?
return (output, new_states.Single);
};
@@ -308,13 +310,13 @@ namespace Tensorflow.Keras.Layers.Rnn
{
step = (inputs, states) =>
{
// states = (states[0] if len(states) == 1 and is_tf_rnn_cell else states)
states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states;
var (output, new_states) = _cell.Apply(inputs, states);
return (output, new_states.Single);
};
}

var (last_output, outputs, states) = BackendImpl.rnn(step,
var (last_output, outputs, states) = keras.backend.rnn(step,
inputs,
initial_state,
constants: constants,
@@ -334,8 +336,8 @@ namespace Tensorflow.Keras.Layers.Rnn
Tensors output = new Tensors();
if (_args.ReturnSequences)
{
throw new NotImplementedException("this argument havn't been developed.");
// TODO(Rinne): add go_backwards parameter and revise the `row_length` param
output = keras.backend.maybe_convert_to_ragged(is_ragged_input, outputs, row_length, false);
}
else
{


+ 1
- 1
src/TensorFlowNET.Keras/Layers/Rnn/RnnCellBase.cs View File

@@ -14,8 +14,8 @@ namespace Tensorflow.Keras.Layers.Rnn
public RnnCellBase(LayerArgs args) : base(args) { }
public abstract GeneralizedTensorShape StateSize { get; }
public abstract GeneralizedTensorShape OutputSize { get; }
public abstract bool IsTFRnnCell { get; }
public abstract bool SupportOptionalArgs { get; }
public abstract (Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null);
public virtual Tensors GetInitialState(Tensors inputs, long batch_size, TF_DataType dtype)
{
return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size, dtype);


+ 18
- 32
src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs View File

@@ -5,6 +5,7 @@ using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Common.Types;
using Tensorflow.Common.Extensions;

namespace Tensorflow.Keras.Layers.Rnn
{
@@ -26,6 +27,7 @@ namespace Tensorflow.Keras.Layers.Rnn

public override GeneralizedTensorShape StateSize => _state_size;
public override GeneralizedTensorShape OutputSize => _output_size;
public override bool IsTFRnnCell => true;
public override bool SupportOptionalArgs => false;

public SimpleRNNCell(SimpleRNNCellArgs args) : base(args)
@@ -66,37 +68,22 @@ namespace Tensorflow.Keras.Layers.Rnn
built = true;
}

public override (Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null)
// TODO(Rinne): revise the trining param (with refactoring of the framework)
protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null)
{
// TODO(Rinne): check if it will have multiple tensors when not nested.
Tensor prev_output = states[0];
Tensors prev_output = Nest.IsNested(states) ? new Tensors(states[0]) : states;
var dp_mask = get_dropout_maskcell_for_cell(inputs, training.Value);
var rec_dp_mask = get_recurrent_dropout_maskcell_for_cell(prev_output, training.Value);

Tensor h;
var ranks = inputs.rank;
if (dp_mask != null)
{
if (ranks > 2)
{
// 因为multiply函数会自动添加第一个维度,所以加上下标0
h = tf.linalg.tensordot(math_ops.multiply(inputs, dp_mask)[0], _kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } });
}
else
{
h = math_ops.matmul(math_ops.multiply(inputs, dp_mask)[0], _kernel.AsTensor());
}
h = math_ops.matmul(math_ops.multiply(inputs.Single, dp_mask.Single), _kernel.AsTensor());
}
else
{
if (ranks > 2)
{
h = tf.linalg.tensordot(inputs, _kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } });
}
else
{
h = math_ops.matmul(inputs, _kernel.AsTensor());
}
h = math_ops.matmul(inputs, _kernel.AsTensor());
}

if (_bias != null)
@@ -106,26 +93,25 @@ namespace Tensorflow.Keras.Layers.Rnn

if (rec_dp_mask != null)
{
prev_output = math_ops.multiply(prev_output, rec_dp_mask)[0];
prev_output = math_ops.multiply(prev_output, rec_dp_mask);
}

ranks = prev_output.rank;
Tensor output;
if (ranks > 2)
Tensor output = h + math_ops.matmul(prev_output, _recurrent_kernel.AsTensor());
if (_args.Activation != null)
{
output = h + tf.linalg.tensordot(prev_output[0], _recurrent_kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } });
output = _args.Activation.Apply(output);
}
else
if (Nest.IsNested(states))
{
output = h + math_ops.matmul(prev_output, _recurrent_kernel.AsTensor());
return new Nest<Tensor>(new List<Nest<Tensor>> {
new Nest<Tensor>(new List<Nest<Tensor>> { new Nest<Tensor>(output) }), new Nest<Tensor>(output) })
.ToTensors();
}
Console.WriteLine($"shape of output: {output.shape}");

if (_args.Activation != null)
else
{
output = _args.Activation.Apply(output);
return new Tensors(output, output);
}
return (output, new Tensors { output });
}
}
}

+ 1
- 0
src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs View File

@@ -170,6 +170,7 @@ namespace Tensorflow.Keras.Layers.Rnn
}
public GeneralizedTensorShape StateSize => throw new NotImplementedException();
public GeneralizedTensorShape OutputSize => throw new NotImplementedException();
public bool IsTFRnnCell => throw new NotImplementedException();
public bool SupportOptionalArgs => throw new NotImplementedException();
}
}

+ 2
- 1
src/TensorflowNET.Hub/KerasLayer.cs View File

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Common.Types;
using Tensorflow.Keras.Engine;
using Tensorflow.Train;
using Tensorflow.Training;
@@ -89,7 +90,7 @@ namespace Tensorflow.Hub
}
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optionalArgs = null)
{
_check_trainability();



Loading…
Cancel
Save