diff --git a/src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs b/src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs
index df6222cd..d12ed1ad 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs
@@ -9,11 +9,11 @@ namespace Tensorflow.Keras.Layers.Rnn
{
GeneralizedTensorShape StateSize { get; }
GeneralizedTensorShape OutputSize { get; }
+ bool IsTFRnnCell { get; }
///
/// Whether the optional RNN args are supported when appying the layer.
/// In other words, whether `Apply` is overwrited with process of `RnnOptionalArgs`.
///
bool SupportOptionalArgs { get; }
- (Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null);
}
}
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
index 71fdc301..26646b76 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
@@ -183,6 +183,7 @@ namespace Tensorflow
}
public GeneralizedTensorShape StateSize => throw new NotImplementedException();
public GeneralizedTensorShape OutputSize => throw new NotImplementedException();
+ public bool IsTFRnnCell => throw new NotImplementedException();
public bool SupportOptionalArgs => throw new NotImplementedException();
}
}
diff --git a/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs b/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs
index cf1b50af..ed65a08d 100644
--- a/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs
+++ b/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs
@@ -17,6 +17,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
+using Tensorflow.Eager;
using Tensorflow.Framework;
using static Tensorflow.Binding;
@@ -48,6 +49,7 @@ namespace Tensorflow.Operations
public override Tensor flow => _flow;
bool _clear_after_read;
List _tensor_array;
+ List _previous_read_indices;
public _EagerTensorArray(TF_DataType dtype, Tensor size, bool dynamic_size = false,
bool clear_after_read = true, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
@@ -61,16 +63,20 @@ namespace Tensorflow.Operations
_dtype = dtype.as_base_dtype();
_dynamic_size = dynamic_size;
_clear_after_read = clear_after_read;
- _tensor_array = new List();
+ _tensor_array = Enumerable.Repeat(null, size.numpy()).ToList();
+ _previous_read_indices = new();
}
public override TensorArray unstack(Tensor value, string name = null)
{
- return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _handle, value }), delegate
+ var tensors = array_ops.unstack(value, name: name);
+ if(tensors.Length > _tensor_array.Count && !_dynamic_size)
{
- var num_elements = array_ops.shape(value)[0];
- return scatter(indices: math_ops.range(0, num_elements), value: value, name: name);
- });
+ throw new ValueError($"Cannot unstack {tensors.Length} tensors into a TensorArray of static size {_tensor_array.Count}");
+ }
+ _tensor_array = tensors.ToList();
+ // TODO(Rinne): revise the implementation. Here we should return `parent()`.
+ return this;
}
public TensorArray scatter(Tensor indices, Tensor value, string name = null)
@@ -116,9 +122,19 @@ namespace Tensorflow.Operations
_colocate_with.Add(value);
}
+ private Tensor _maybe_zero(int ix)
+ {
+ var val = _tensor_array[ix];
+ if(val is null)
+ {
+ val = _tensor_array[ix] = array_ops.zeros(_element_shape, _dtype);
+ }
+ return val;
+ }
+
public override Tensor read(T index, string name = null)
{
- int index_int = -1;
+ int index_int;
if (index is int int_index)
index_int = int_index;
else if (index is Tensor tensor_index)
@@ -126,27 +142,75 @@ namespace Tensorflow.Operations
else
throw new ValueError("");
+ if(index_int >= _tensor_array.Count)
+ {
+ throw new OutOfRangeError($"Tried to read from index {index_int} but array size is: {_tensor_array.Count} ");
+ }
+
+ var res = _tensor_array[index_int];
+ if(res is null)
+ {
+ if (_previous_read_indices.Contains(index_int))
+ {
+ throw new InvalidArgumentError($"Could not read index {index_int} twice because it was cleared after " +
+ $"a previous read (perhaps try setting clear_after_read = false?)");
+ }
+ else
+ {
+ res = _maybe_zero(index_int);
+ }
+ }
+
if (_clear_after_read)
{
_tensor_array[index_int] = null;
+ _previous_read_indices.Add(index_int);
}
-
- return _tensor_array[index_int];
+ return res;
}
public override TensorArray write(Tensor index, Tensor value, string name = null)
{
- if (_infer_shape)
- _element_shape = _element_shape.merge_with(value.shape);
- _tensor_array.add(value);
- return this;
+ int index_int;
+ if(index is EagerTensor eager)
+ {
+ return write(eager.numpy(), value, name);
+ }
+ throw new InvalidArgumentError("The index is supposed to be an EagerTensor");
}
public override TensorArray write(int index, T value, string name = null)
{
- var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
- var index_tensor = ops.convert_to_tensor(index, name: "index");
- return write(index_tensor, value_tensor, name: name);
+ int size = _tensor_array.Count;
+ if(index >= size)
+ {
+ if (!_dynamic_size)
+ {
+ throw new OutOfRangeError($"Tried to write to index {index} but array is not resizeable and size " +
+ $"is: {size} ");
+ }
+ _tensor_array.AddRange(Enumerable.Repeat(null, index - size + 1));
+ }
+
+ Tensor tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
+
+ if(_dtype != tensor.dtype)
+ {
+ throw new InvalidArgumentError($"TensorArray dtype is {_dtype.as_python_name()} but Op is " +
+ $"trying to write dtype {tensor.dtype.as_python_name()} ");
+ }
+
+ if (!_element_shape.is_compatible_with(tensor.shape))
+ {
+ throw new ValueError($"Incompatible shape for value ({tensor.shape}), expected ({_element_shape})");
+ }
+
+ if (_infer_shape)
+ {
+ _element_shape = _element_shape.merge_with(tensor.shape);
+ }
+ _tensor_array[index] = tensor;
+ return this;
}
private Tensor size(string name = null)
@@ -156,11 +220,26 @@ namespace Tensorflow.Operations
public override Tensor stack(string name = null)
{
- ops.colocate_with(_handle);
- return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate
+ if(_tensor_array.Count > 0)
{
- return gather(math_ops.range(0, size()), name: name);
- });
+ for(int i = 0; i < _tensor_array.Count; i++)
+ {
+ _maybe_zero(i);
+ }
+ }
+ if(_tensor_array.Count == 0 && _element_shape.IsFullyDefined)
+ {
+ return ops.convert_to_tensor(new Shape(new long[] { 0 }.Concat(_element_shape.dims).ToArray()), name: name, dtype: _dtype);
+ }
+ else
+ {
+ return ops.convert_to_tensor(_tensor_array, name: name, dtype: _dtype);
+ }
+ //ops.colocate_with(_handle);
+ //return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate
+ //{
+ // return gather(math_ops.range(0, size()), name: name);
+ //});
}
public override Tensor gather(Tensor indices, string name = null)
diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs
index a7c1bcad..30b73e82 100644
--- a/src/TensorFlowNET.Keras/BackendImpl.cs
+++ b/src/TensorFlowNET.Keras/BackendImpl.cs
@@ -20,9 +20,11 @@ using System.Linq;
using System.Collections.Generic;
using Tensorflow.Functions;
using Tensorflow.Graphs;
+using Tensorflow.Common.Extensions;
using static Tensorflow.Binding;
using static Tensorflow.Graphs.SubGraphUtility;
using Tensorflow.Util;
+using Tensorflow.Common.Types;
namespace Tensorflow.Keras
{
@@ -452,7 +454,7 @@ namespace Tensorflow.Keras
return x;
}
- public static (Tensors, Tensors, Tensors) rnn(
+ public (Tensors, Tensors, Tensors) rnn(
Func step_function, // args:inputs, states, return:output, new_states
Tensors inputs, // inputs is a tuple of tensors (one per input sequence)
Tensors initial_states,
@@ -466,7 +468,7 @@ namespace Tensorflow.Keras
bool return_all_outputs = true)
{
- Tensors swap_batch_timestep(Tensors input_t)
+ Tensor swap_batch_timestep(Tensor input_t)
{
var axes = Enumerable.Range(0, input_t.rank).ToArray();
axes[0] = 1;
@@ -476,13 +478,14 @@ namespace Tensorflow.Keras
if (!time_major)
{
- inputs = nest.map_structure(swap_batch_timestep, inputs);
+ inputs = Nest.MapStructure(swap_batch_timestep, inputs).ToTensors();
}
- var flatted_inptus = nest.flatten(inputs);
- var time_steps = flatted_inptus[0].shape[0];
- var batch = flatted_inptus[0].shape[1];
- var time_step_t = tf.shape(flatted_inptus[0])[0];
+ var flatted_inptus = Nest.Flatten(inputs).ToList();
+ var first_flatted_input = flatted_inptus[0];
+ var time_steps = first_flatted_input.shape[0];
+ var batch = first_flatted_input.shape[1];
+ var time_steps_t = (int)first_flatted_input.shape[0];
foreach (var input_ in flatted_inptus)
{
@@ -508,11 +511,6 @@ namespace Tensorflow.Keras
}
- if (constants == null)
- {
- constants = new List();
- }
-
// tf.where needs its condition tensor to be the same shape as its two
// result tensors, but in our case the condition (mask) tensor is
// (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
@@ -522,12 +520,12 @@ namespace Tensorflow.Keras
Tensors _expand_mask(Tensors mask_t, Tensors input_t, int fixed_dim = 1)
{
- if (nest.is_nested(mask_t))
+ if (!mask_t.IsSingle())
{
throw new ValueError($"mask_t is expected to be tensor, but got {mask_t}");
}
- if (nest.is_nested(input_t))
+ if (!input_t.IsSingle())
{
throw new ValueError($"input_t is expected to be tensor, but got {input_t}");
}
@@ -575,21 +573,21 @@ namespace Tensorflow.Keras
- Tensors _process_single_input_t(Tensors input_t)
+ Tensors _process_single_input_t(Tensor input_t)
{
- input_t = tf.unstack(input_t); // unstack for time_step dim
+ var unstaked_input_t = array_ops.unstack(input_t); // unstack for time_step dim
if (go_backwards)
{
- input_t.Reverse();
+ unstaked_input_t = unstaked_input_t.Reverse().ToArray();
}
- return input_t;
+ return unstaked_input_t;
}
// TODO(Wanglongzhi2001)
Tensors processed_input;
- if (nest.is_nested(inputs))
+ if (!inputs.IsSingle())
{
- processed_input = nest.map_structure(_process_single_input_t, inputs);
+ processed_input = inputs.MapStructure(_process_single_input_t).ReduceTo().ToTensors();
}
else
{
@@ -603,334 +601,339 @@ namespace Tensorflow.Keras
{
inp.Add(t_[time]);
}
- return nest.pack_sequence_as(inputs, inp);
+ return Nest.PackSequenceAs(inputs, inp);
}
- //if (mask != null)
- //{
- // var mask_list = tf.unstack(mask);
- // if (go_backwards)
- // {
- // mask_list.Reverse();
- // }
-
- // for (int i = 0; i < time_steps; i++)
- // {
- // // TODO(Wanglongzhi2001),deal with _get_input_tensor
- // var inp = _get_input_tensor(i);
- // var mask_t = mask_list[i];
- // // TODO
- // var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants });
-
- // var tiled_mask_t = _expand_mask(mask_t, output);
-
- // Tensors prev_output;
- // if (successive_outputs == null)
- // {
- // prev_output = tf.zeros_like(output);
- // }
- // else
- // {
- // prev_output = successive_outputs[successive_outputs.Length - 1];
- // }
-
- // output = tf.where(tiled_mask_t, output, prev_output);
-
- // //var flat_states = nest.flatten(states);
- // //var flat_new_states = nest.flatten(newStates);
- // var flat_states = states.ToList();
- // var flat_new_states = newStates.ToList();
-
- // var tiledMaskT = flat_states
- // .Select(s => _expand_mask(mask_t, s))
- // .ToArray();
- // var tuple = Tuple.Create(tiledMaskT);
-
- // List flat_final_states = new List();
- // foreach (var (m, s, ps) in Enumerable.Zip(tiled_mask_t, flat_new_states, flat_states))
- // {
- // flat_final_states.Add(tf.where(m, s, ps));
- // }
-
- // states = (Tensors)nest.pack_sequence_as(states, flat_final_states);
- // if (return_all_outputs)
- // {
- // successive_outputs.Add(output);
- // successive_states.Add(states);
- // }
- // else
- // {
- // successive_outputs = new Tensors { output };
- // successive_states = new Tensors { states };
- // }
-
- // }
- // last_output = successive_outputs[successive_outputs.Length - 1];
- // new_states = successive_states[successive_states.Length - 1];
- // outputs = tf.stack(successive_outputs);
-
- // if (zero_output_for_mask)
- // {
- // last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output));
- // outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs));
- // }
- // else // mask is null
- // {
- // for (int i = 0; i < time_steps; i++)
- // {
- // var inp = _get_input_tensor(i);
- // var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants });
- // states = newStates;
-
- // if (return_all_outputs)
- // {
- // successive_outputs.Add(output);
- // successive_states.Add(newStates);
- // }
- // else
- // {
- // successive_outputs = new Tensors { output };
- // successive_states = new Tensors { newStates };
- // }
- // }
- // last_output = successive_outputs[successive_outputs.Length - 1];
- // new_states = successive_states[successive_states.Length - 1];
- // outputs = tf.stack(successive_outputs);
- // }
- //}
+ if (mask != null)
+ {
+ var mask_list = tf.unstack(mask);
+ if (go_backwards)
+ {
+ mask_list.Reverse();
+ }
+
+ for (int i = 0; i < time_steps; i++)
+ {
+ // TODO(Wanglongzhi2001),deal with _get_input_tensor
+ var inp = _get_input_tensor(i);
+ var mask_t = mask_list[i];
+ // TODO
+ var (output, newStates) = step_function((Tensors)inp, states.MergeWith(constants));
+
+ var tiled_mask_t = _expand_mask(mask_t, output);
+
+ Tensors prev_output;
+ if (successive_outputs == null)
+ {
+ prev_output = tf.zeros_like(output);
+ }
+ else
+ {
+ prev_output = successive_outputs[successive_outputs.Length - 1];
+ }
+
+ output = tf.where(tiled_mask_t, output, prev_output);
+
+ var flat_states = Nest.Flatten(states).ToList();
+ var flat_new_states = Nest.Flatten(newStates).ToList();
+
+ var tiledMaskT = flat_states
+ .Select(s => _expand_mask(mask_t, s))
+ .ToArray();
+ var tuple = Tuple.Create(tiledMaskT);
+
+ List flat_final_states = new List();
+ foreach (var (m, s, ps) in zip(tiled_mask_t.ToList(), flat_new_states, flat_states))
+ {
+ flat_final_states.Add(tf.where(m, s, ps));
+ }
+
+ states = Nest.PackSequenceAs(states, flat_final_states).ToTensors();
+ if (return_all_outputs)
+ {
+ successive_outputs.Add(output);
+ successive_states.Add(states);
+ }
+ else
+ {
+ successive_outputs = new Tensors { output };
+ successive_states = new Tensors { states };
+ }
+
+ }
+ last_output = successive_outputs[successive_outputs.Length - 1];
+ new_states = successive_states[successive_states.Length - 1];
+ outputs = tf.stack(successive_outputs);
+
+ if (zero_output_for_mask)
+ {
+ last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output));
+ outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs));
+ }
+ else // mask is null
+ {
+ for (int i = 0; i < time_steps; i++)
+ {
+ var inp = _get_input_tensor(i);
+ var (output, newStates) = step_function((Tensors)inp, states.MergeWith(constants));
+ states = newStates;
+
+ if (return_all_outputs)
+ {
+ successive_outputs.Add(output);
+ successive_states.Add(newStates);
+ }
+ else
+ {
+ successive_outputs = new Tensors { output };
+ successive_states = new Tensors { newStates };
+ }
+ }
+ last_output = successive_outputs[successive_outputs.Length - 1];
+ new_states = successive_states[successive_states.Length - 1];
+ outputs = tf.stack(successive_outputs);
+ }
+ }
+ }
+ else // unroll == false
+ {
+ var states = initial_states;
+ // Create input tensor array, if the inputs is nested tensors, then it
+ // will be flattened first, and tensor array will be created one per
+ // flattened tensor.
+ var input_ta = new List();
+ for (int i = 0; i < flatted_inptus.Count; i++)
+ {
+ input_ta.Add(tf.TensorArray(dtype: flatted_inptus[i].dtype, size: time_steps_t));
+ }
+
+ foreach(var (ta, input_) in zip(input_ta, flatted_inptus))
+ {
+ if (!go_backwards)
+ {
+ ta.unstack(input_);
+ }
+ else
+ {
+ ta.unstack(reverse(input_, 0));
+ }
+ }
+
+ // Get the time(0) input and compute the output for that, the output will
+ // be used to determine the dtype of output tensor array. Don't read from
+ // input_ta due to TensorArray clear_after_read default to True.
+ var inps = new Tensors();
+ foreach (var inp in flatted_inptus)
+ {
+ inps.Add(inp[0]);
+ }
+ var input_time_zero = Nest.PackSequenceAs(inputs, inps).ToTensors();
+
+ // output_time_zero is used to determine the cell output shape and its
+ // dtype. the value is discarded.
+ (output_time_zero, _) = step_function((Tensor)input_time_zero,
+ constants is null ? initial_states : initial_states.MergeWith(constants));
+
+ int output_ta_size = return_all_outputs ? time_steps_t : 1;
+ var output_ta = new List();
+ for (int i = 0; i < output_time_zero.ToList().Count; i++)
+ {
+ var Out = output_time_zero.ToList()[i];
+ output_ta.Add(tf.TensorArray(dtype: Out.dtype, size: output_ta_size, element_shape: Out.shape));
+ }
+
+ var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time");
+
+
+
+ Func? masking_fn;
+ Func? compute_masked_output = null;
+ if (mask != null)
+ {
+ if (go_backwards)
+ {
+ mask = tf.reverse(mask, axis: new[] { 0 });
+ }
+ var mask_ta = tf.TensorArray(dtype: TF_DataType.TF_BOOL, size: time_steps_t);
+ mask_ta = mask_ta.unstack(mask);
+
+ masking_fn = (time) =>
+ {
+ return mask_ta.read(time);
+ };
+
+ compute_masked_output = (mask_t, flat_out, flat_mask) =>
+ {
+ var tiled_mask_t = new Tensors();
+ foreach (var o in flat_out)
+ {
+ tiled_mask_t.Add(_expand_mask(mask_t, o, fixed_dim: mask_t.rank));
+ }
+
+ Tensors res = new Tensors();
+ foreach (var (m, o, fm) in zip(tiled_mask_t.ToList(), flat_out.ToList(), flat_mask.ToList()))
+ {
+ res.Add(tf.where(m, o, fm));
+ }
+ return res;
+ };
+ }
+ // TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor)?
+ else if (input_length is Tensor)
+ {
+ if (go_backwards)
+ {
+ var max_len = tf.reduce_max(input_length, axis: 0);
+ var rev_input_length = tf.subtract(max_len - 1, input_length);
+
+ masking_fn = (time) =>
+ {
+ return tf.less(rev_input_length, time);
+ };
+ }
+ else
+ {
+ masking_fn = (time) =>
+ {
+ return tf.greater(input_length, time);
+ };
+ }
+
+ compute_masked_output = (mask_t, flat_out, flat_mask) =>
+ {
+ var res = new List();
+ foreach (var (o, zo) in zip(flat_out, flat_mask))
+ {
+ res.Add(tf.where(mask_t, o, zo));
+ }
+ return res;
+ };
+ }
+ else
+ {
+ masking_fn = null;
+ }
+
+ Func cond = (time) => (time < time_steps_t);
+ int parallel_iterations = 32;
+ if (masking_fn != null)
+ {
+ // Mask for the T output will be base on the output of T - 1. In the
+ // case T = 0, a zero filled tensor will be used.
+ var flat_zero_output = new Tensors();
+ foreach (var o in Nest.Flatten(output_time_zero))
+ {
+ flat_zero_output.Add(tf.zeros_like(o));
+ }
+
+ var prev_output = flat_zero_output;
+ var output_ta_t = output_ta;
+ Tensor _step(Tensor time)
+ {
+ /*
+ RNN step function.
+ Args:
+ time: Current timestep value.
+ output_ta_t: TensorArray.
+ prev_output: tuple of outputs from time - 1.
+ *states: List of states.
+ Returns:
+ Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)`
+ */
+
+ var flat_current_input = input_ta.Select(x => x.read(time)).ToList();
+ // maybe set shape
+ // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type
+ var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors();
+ var mask_t = masking_fn(time);
+ var (output, new_states_internal) = step_function(current_input, states.MergeWith(constants));
+ // mask output
+ var flat_output = Nest.Flatten(output).ToList();
+
+ var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.ToList();
+
+ // TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type
+ var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output);
+
+ // mask states
+ var flat_state = states.ToList();
+ var flat_new_state = new_states_internal.ToList();
+
+ foreach (var (state, new_state) in zip(flat_state, flat_new_state))
+ {
+ if (new_state is Tensor)
+ {
+ new_state.shape = state.shape;
+ }
+ }
+
+ var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state);
+ new_states_internal = Nest.PackSequenceAs(new_states, flat_final_state).ToTensors();
+
+ var ta_index_to_write = return_all_outputs ? time : tf.constant(0);
+ // TODO(Wanglongzhi2001),deal with zip output_ta_t
+ foreach (var (ta, Out) in zip(output_ta_t, flat_new_output))
+ {
+ output_ta_t.Add(ta.write(ta_index_to_write, Out));
+ }
+
+ new_states_internal = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors();
+
+ output_ta = output_ta_t;
+ new_states = new_states_internal;
+ return time + 1;
+
+ }
+ var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations);
+ }
+ else
+ {
+ var output_ta_t = output_ta;
+ new_states = states;
+ Tensor _step(Tensor time)
+ {
+ var flat_current_input = input_ta.Select(x => x.read(time)).ToList();
+ // maybe set shape
+ // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type
+ var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors();
+ var (output, new_states_internal) = step_function(current_input, new_states.MergeWith(constants));
+ var flat_state = new_states.Flatten().ToList();
+ var flat_new_state = new_states_internal.Flatten().ToList();
+ foreach (var (state, new_state) in zip(flat_state, flat_new_state))
+ {
+ if (new_state is Tensor)
+ {
+ new_state.shape = state.shape;
+ }
+ }
+ var flat_output = Nest.Flatten(output);
+ var ta_index_to_write = return_all_outputs ? time : tf.constant(0);
+ output_ta_t = zip(output_ta_t, flat_output).Select(item =>
+ {
+ var (ta, out_) = item;
+ return ta.write(ta_index_to_write, out_);
+ }).ToList();
+
+ new_states_internal = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors();
+ output_ta = output_ta_t;
+ new_states = new_states_internal;
+ return time + 1;
+ }
+ var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations);
+ }
+ //Tensors outputs = new Tensors();
+ foreach (var o in output_ta)
+ {
+ outputs.Add(o.stack());
+ }
+ foreach (var o in outputs)
+ {
+ last_output.Add(o[-1]);
+ }
+ outputs = Nest.PackSequenceAs(output_time_zero, outputs).ToTensors();
+ last_output = Nest.PackSequenceAs(output_time_zero, last_output).ToTensors();
+
}
- //else // unroll == false
- //{
- // var states = initial_states;
- // // Create input tensor array, if the inputs is nested tensors, then it
- // // will be flattened first, and tensor array will be created one per
- // // flattened tensor.
- // var input_ta = new List();
- // for (int i = 0; i < flatted_inptus.Count; i++)
- // {
- // input_ta.Add(tf.TensorArray(dtype: flatted_inptus[i].dtype, size: time_step_t));
- // }
-
- // // Get the time(0) input and compute the output for that, the output will
- // // be used to determine the dtype of output tensor array. Don't read from
- // // input_ta due to TensorArray clear_after_read default to True.
- // var inps = new Tensors();
- // foreach (var inp in flatted_inptus)
- // {
- // inps.Add(inp[0]);
- // }
- // var input_time_zero = nest.pack_sequence_as(inputs, inps);
-
- // // output_time_zero is used to determine the cell output shape and its
- // // dtype. the value is discarded.
- // (output_time_zero, _) = step_function((Tensor)input_time_zero, new Tensors { initial_states, constants });
-
- // var output_ta_size = return_all_outputs ? time_step_t : tf.constant(1);
- // var output_ta = new List();
- // for (int i = 0; i < output_time_zero.ToList().Count; i++)
- // {
- // var Out = output_time_zero.ToList()[i];
- // output_ta.Add(tf.TensorArray(dtype: Out.dtype, size: output_ta_size, element_shape: Out.shape));
- // }
-
- // var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time");
-
-
-
- // Func? masking_fn;
- // Func? compute_masked_output = null;
- // if (mask != null)
- // {
- // if (go_backwards)
- // {
- // mask = tf.reverse(mask, axis: new[] { 0 });
- // }
- // var mask_ta = tf.TensorArray(dtype: TF_DataType.TF_BOOL, size: time_step_t);
- // mask_ta = mask_ta.unstack(mask);
-
- // masking_fn = (time) =>
- // {
- // return mask_ta.read(time);
- // };
-
- // compute_masked_output = (mask_t, flat_out, flat_mask) =>
- // {
- // var tiled_mask_t = new Tensors();
- // foreach (var o in flat_out)
- // {
- // tiled_mask_t.Add(_expand_mask(mask_t, o, fixed_dim: mask_t.rank));
- // }
-
- // Tensors res = new Tensors();
- // foreach (var (m, o, fm) in Enumerable.Zip(tiled_mask_t, flat_out, flat_mask))
- // {
- // res.Add(tf.where(m, o, fm));
- // }
- // return res;
- // };
- // }
- // // TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor)?
- // else if (input_length is Tensor)
- // {
- // if (go_backwards)
- // {
- // var max_len = tf.reduce_max(input_length, axis: 0);
- // var rev_input_length = tf.subtract(max_len - 1, input_length);
-
- // masking_fn = (time) =>
- // {
- // return tf.less(rev_input_length, time);
- // };
- // }
- // else
- // {
- // masking_fn = (time) =>
- // {
- // return tf.greater(input_length, time);
- // };
- // }
-
- // compute_masked_output = (mask_t, flat_out, flat_mask) =>
- // {
- // var res = new List();
- // foreach (var (o, zo) in zip(flat_out, flat_mask))
- // {
- // res.Add(tf.where(mask_t, o, zo));
- // }
- // return res;
- // };
- // }
- // else
- // {
- // masking_fn = null;
- // }
-
-
- // if (masking_fn != null)
- // {
- // // Mask for the T output will be base on the output of T - 1. In the
- // // case T = 0, a zero filled tensor will be used.
- // var flat_zero_output = new Tensors();
- // foreach (var o in nest.flatten(output_time_zero))
- // {
- // flat_zero_output.Add(tf.zeros_like(o));
- // }
-
-
- // (Tensor, List, Tensors, Tensors) _step(Tensor time, List output_ta_t, Tensors prev_output, Tensors states)
- // {
- // /*
- // RNN step function.
- // Args:
- // time: Current timestep value.
- // output_ta_t: TensorArray.
- // prev_output: tuple of outputs from time - 1.
- // *states: List of states.
- // Returns:
- // Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)`
- // */
-
- // var current_input = input_ta.Select(x => x.read(time)).ToList();
- // // maybe set shape
- // // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type
- // current_input = (List)nest.pack_sequence_as(inputs, current_input);
- // var mask_t = masking_fn(time);
- // var (output, new_states) = step_function(current_input, new Tensors { states, constants });
- // // mask output
- // //var flat_output = nest.flatten(output);
- // var flat_output = output.ToList();
-
- // var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.ToList();
-
- // // TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type
- // var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output);
-
- // // mask states
- // var flat_state = states.ToList();
- // var flat_new_state = new_states.ToList();
-
- // foreach (var (state, new_state) in zip(flat_state, flat_new_state))
- // {
- // if (new_state is Tensor)
- // {
- // new_state.set_shape(state.shape);
- // }
- // }
-
- // var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state);
- // new_states = (Tensors)nest.pack_sequence_as(new_states, flat_final_state);
-
- // var ta_index_to_write = return_all_outputs ? time : tf.constant(0);
- // var Output_ta_t = new List();
- // // TODO(Wanglongzhi2001),deal with zip output_ta_t
- // foreach (var (ta, Out) in zip(output_ta_t, flat_new_output))
- // {
- // Output_ta_t.Add(ta.write(ta_index_to_write, Out));
- // }
-
-
-
- // //new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state);
-
-
- // return (time + 1, Output_ta_t, flat_new_output, new_states);
-
- // }
- // Func cond = (time) => (time < time_step_t);
-
- // var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, flat_zero_output, states));
- // new_states = final_outputs.Item4;
- // output_ta = final_outputs.Item2;
-
- // }
- // else
- // {
- // (Tensor, List, Tensors) _step(Tensor time, List output_ta_t, Tensors states)
- // {
- // var current_input = input_ta.Select(x => x.read(time)).ToList();
- // // maybe set shape
- // // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type
- // current_input = (List)nest.pack_sequence_as(inputs, current_input);
- // var (output, new_states) = step_function(current_input, new Tensors { states, constants });
- // var flat_state = states.ToList();
- // var flat_new_state = new_states.ToList();
- // foreach (var (state, new_state) in zip(flat_state, flat_new_state))
- // {
- // if (new_state is Tensor)
- // {
- // new_state.set_shape(state.shape);
- // }
- // }
- // var flat_output = output.ToList();
- // var ta_index_to_write = return_all_outputs ? time : tf.constant(0);
- // var Output_ta_t = new List();
- // foreach (var (ta, out_) in zip(output_ta_t, flat_output))
- // {
- // Output_ta_t.Add(ta.write(ta_index_to_write, out_));
- // }
-
- // new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state);
- // return (time + 1, Output_ta_t, new_states);
- // }
- // Func cond = (time) => (time < time_step_t);
- // var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, states));
- // new_states = final_outputs.Item3;
- // output_ta = final_outputs.Item2;
-
- // }
- // //Tensors outputs = new Tensors();
- // foreach (var o in output_ta)
- // {
- // outputs.Add(o.stack());
- // }
- // foreach (var o in outputs)
- // {
- // last_output.Add(o[-1]);
- // }
- // outputs = (Tensors)nest.pack_sequence_as(output_time_zero, outputs);
- // last_output = (Tensors)nest.pack_sequence_as(output_time_zero, last_output);
-
- //}
Func set_shape;
set_shape = (output_) =>
@@ -947,18 +950,38 @@ namespace Tensorflow.Keras
shape[0] = 1;
}
shape[1] = (int)batch;
- output_.set_shape(new Tensor(shape));
+ output_.shape = shape;
}
return output_;
};
- var Outputs = (Tensors)nest.map_structure(set_shape, outputs);
+ outputs = Nest.MapStructure(set_shape, outputs).ToTensors();
if (!time_major)
{
- Outputs = nest.map_structure(swap_batch_timestep, outputs);
+ outputs = Nest.MapStructure(swap_batch_timestep, outputs).ToTensors();
+ }
+ return (last_output, outputs, new_states);
+
+ }
+
+ public Tensor reverse(Tensor input, int axis)
+ {
+ return reverse(input, new int[] { axis });
+ }
+
+ public Tensor reverse(Tensor input, int[] axes)
+ {
+ return tf.reverse(input, axes);
+ }
+
+ public Tensor maybe_convert_to_ragged(bool is_ragged_output, Tensor output, int nested_row_lengths, bool go_backwards = false)
+ {
+ if (!is_ragged_output)
+ {
+ return output;
}
- return (last_output, Outputs, new_states);
+ throw new NotImplementedException("Not implemented currently, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
}
}
}
diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
index b014737f..ab4cef12 100644
--- a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
@@ -55,8 +55,8 @@ namespace Tensorflow.Keras.Layers.Rnn
if (_states == null)
{
// CHECK(Rinne): check if this is correct.
- var state = nest.map_structure(x => null, _cell.StateSize);
- return new Tensors { state };
+ var nested = _cell.StateSize.MapStructure(x => null);
+ _states = nested.AsNest().ToTensors();
}
return _states;
}
@@ -230,7 +230,7 @@ namespace Tensorflow.Keras.Layers.Rnn
Tensors? mask = rnn_optional_args?.Mask;
//var (inputs_padded, row_length) = BackendImpl.convert_inputs_if_ragged(inputs);
// 暂时先不接受ragged tensor
- int? row_length = null;
+ int row_length = 0; // TODO(Rinne): support this param.
bool is_ragged_input = false;
_validate_args_if_ragged(is_ragged_input, mask);
@@ -249,16 +249,16 @@ namespace Tensorflow.Keras.Layers.Rnn
if (mask != null)
{
// Time step masks must be the same for each input.
- mask = nest.flatten(mask)[0];
+ mask = mask.Flatten().First();
}
Shape input_shape;
- if (nest.is_nested(inputs))
+ if (!inputs.IsSingle())
{
// In the case of nested input, use the first element for shape check
// input_shape = nest.flatten(inputs)[0].shape;
// TODO(Wanglongzhi2001)
- input_shape = nest.flatten(inputs)[0].shape;
+ input_shape = inputs.Flatten().First().shape;
}
else
{
@@ -286,6 +286,7 @@ namespace Tensorflow.Keras.Layers.Rnn
// cell_call_fn = (self.cell.__call__ if callable(self.cell) else self.cell.call)
Func step;
+ bool is_tf_rnn_cell = _cell.IsTFRnnCell;
if (constants is not null)
{
if (!_cell.SupportOptionalArgs)
@@ -299,7 +300,8 @@ namespace Tensorflow.Keras.Layers.Rnn
{
constants = new Tensors(states.TakeLast(_num_constants));
states = new Tensors(states.SkipLast(_num_constants));
- var(output, new_states) = _cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants });
+ states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states;
+ var (output, new_states) = _cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants });
// TODO(Wanglongzhi2001),should cell_call_fn's return value be Tensors, Tensors?
return (output, new_states.Single);
};
@@ -308,13 +310,13 @@ namespace Tensorflow.Keras.Layers.Rnn
{
step = (inputs, states) =>
{
- // states = (states[0] if len(states) == 1 and is_tf_rnn_cell else states)
+ states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states;
var (output, new_states) = _cell.Apply(inputs, states);
return (output, new_states.Single);
};
}
- var (last_output, outputs, states) = BackendImpl.rnn(step,
+ var (last_output, outputs, states) = keras.backend.rnn(step,
inputs,
initial_state,
constants: constants,
@@ -334,8 +336,8 @@ namespace Tensorflow.Keras.Layers.Rnn
Tensors output = new Tensors();
if (_args.ReturnSequences)
{
- throw new NotImplementedException("this argument havn't been developed.");
-
+ // TODO(Rinne): add go_backwards parameter and revise the `row_length` param
+ output = keras.backend.maybe_convert_to_ragged(is_ragged_input, outputs, row_length, false);
}
else
{
diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RnnCellBase.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RnnCellBase.cs
index fcb5d1eb..751312e5 100644
--- a/src/TensorFlowNET.Keras/Layers/Rnn/RnnCellBase.cs
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/RnnCellBase.cs
@@ -14,8 +14,8 @@ namespace Tensorflow.Keras.Layers.Rnn
public RnnCellBase(LayerArgs args) : base(args) { }
public abstract GeneralizedTensorShape StateSize { get; }
public abstract GeneralizedTensorShape OutputSize { get; }
+ public abstract bool IsTFRnnCell { get; }
public abstract bool SupportOptionalArgs { get; }
- public abstract (Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null);
public virtual Tensors GetInitialState(Tensors inputs, long batch_size, TF_DataType dtype)
{
return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size, dtype);
diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
index abb57d8a..f0b2ed4d 100644
--- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
@@ -5,6 +5,7 @@ using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Common.Types;
+using Tensorflow.Common.Extensions;
namespace Tensorflow.Keras.Layers.Rnn
{
@@ -26,6 +27,7 @@ namespace Tensorflow.Keras.Layers.Rnn
public override GeneralizedTensorShape StateSize => _state_size;
public override GeneralizedTensorShape OutputSize => _output_size;
+ public override bool IsTFRnnCell => true;
public override bool SupportOptionalArgs => false;
public SimpleRNNCell(SimpleRNNCellArgs args) : base(args)
@@ -66,37 +68,22 @@ namespace Tensorflow.Keras.Layers.Rnn
built = true;
}
- public override (Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null)
+ // TODO(Rinne): revise the trining param (with refactoring of the framework)
+ protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null)
{
// TODO(Rinne): check if it will have multiple tensors when not nested.
- Tensor prev_output = states[0];
+ Tensors prev_output = Nest.IsNested(states) ? new Tensors(states[0]) : states;
var dp_mask = get_dropout_maskcell_for_cell(inputs, training.Value);
var rec_dp_mask = get_recurrent_dropout_maskcell_for_cell(prev_output, training.Value);
Tensor h;
- var ranks = inputs.rank;
if (dp_mask != null)
{
- if (ranks > 2)
- {
- // 因为multiply函数会自动添加第一个维度,所以加上下标0
- h = tf.linalg.tensordot(math_ops.multiply(inputs, dp_mask)[0], _kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } });
- }
- else
- {
- h = math_ops.matmul(math_ops.multiply(inputs, dp_mask)[0], _kernel.AsTensor());
- }
+ h = math_ops.matmul(math_ops.multiply(inputs.Single, dp_mask.Single), _kernel.AsTensor());
}
else
{
- if (ranks > 2)
- {
- h = tf.linalg.tensordot(inputs, _kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } });
- }
- else
- {
- h = math_ops.matmul(inputs, _kernel.AsTensor());
- }
+ h = math_ops.matmul(inputs, _kernel.AsTensor());
}
if (_bias != null)
@@ -106,26 +93,25 @@ namespace Tensorflow.Keras.Layers.Rnn
if (rec_dp_mask != null)
{
- prev_output = math_ops.multiply(prev_output, rec_dp_mask)[0];
+ prev_output = math_ops.multiply(prev_output, rec_dp_mask);
}
- ranks = prev_output.rank;
- Tensor output;
- if (ranks > 2)
+ Tensor output = h + math_ops.matmul(prev_output, _recurrent_kernel.AsTensor());
+
+ if (_args.Activation != null)
{
- output = h + tf.linalg.tensordot(prev_output[0], _recurrent_kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } });
+ output = _args.Activation.Apply(output);
}
- else
+ if (Nest.IsNested(states))
{
- output = h + math_ops.matmul(prev_output, _recurrent_kernel.AsTensor());
+ return new Nest(new List> {
+ new Nest(new List> { new Nest(output) }), new Nest(output) })
+ .ToTensors();
}
- Console.WriteLine($"shape of output: {output.shape}");
-
- if (_args.Activation != null)
+ else
{
- output = _args.Activation.Apply(output);
+ return new Tensors(output, output);
}
- return (output, new Tensors { output });
}
}
}
diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs
index 7923192f..0b92fd3c 100644
--- a/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs
@@ -170,6 +170,7 @@ namespace Tensorflow.Keras.Layers.Rnn
}
public GeneralizedTensorShape StateSize => throw new NotImplementedException();
public GeneralizedTensorShape OutputSize => throw new NotImplementedException();
+ public bool IsTFRnnCell => throw new NotImplementedException();
public bool SupportOptionalArgs => throw new NotImplementedException();
}
}
diff --git a/src/TensorflowNET.Hub/KerasLayer.cs b/src/TensorflowNET.Hub/KerasLayer.cs
index b9ca949b..20d9851b 100644
--- a/src/TensorflowNET.Hub/KerasLayer.cs
+++ b/src/TensorflowNET.Hub/KerasLayer.cs
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
+using Tensorflow.Common.Types;
using Tensorflow.Keras.Engine;
using Tensorflow.Train;
using Tensorflow.Training;
@@ -89,7 +90,7 @@ namespace Tensorflow.Hub
}
}
- protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
+ protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optionalArgs = null)
{
_check_trainability();