
Draft PR for RNN

pull/1090/head
Wanglongzhi2001 committed 46b86e845f · 2 years ago
52 changed files with 1187 additions and 70 deletions
1. +15 -0 src/TensorFlowNET.Core/APIs/tf.control_flow.cs
2. +10 -2 src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs
3. +2 -1 src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs
4. +63 -0 src/TensorFlowNET.Core/NumPy/StateSizeWrapper.cs
5. +2 -1 src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
6. +47 -0 src/TensorFlowNET.Core/Operations/control_flow_ops.cs
7. +44 -0 src/TensorFlowNET.Core/Util/nest.py.cs
8. +529 -11 src/TensorFlowNET.Keras/BackendImpl.cs
9. +2 -2 src/TensorFlowNET.Keras/Engine/Layer.cs
10. +1 -1 src/TensorFlowNET.Keras/Layers/Activation/ELU.cs
11. +1 -1 src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs
12. +2 -1 src/TensorFlowNET.Keras/Layers/Activation/HardSigmoid.cs
13. +1 -1 src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs
14. +2 -1 src/TensorFlowNET.Keras/Layers/Activation/SELU.cs
15. +2 -1 src/TensorFlowNET.Keras/Layers/Activation/Softmax.cs
16. +2 -1 src/TensorFlowNET.Keras/Layers/Activation/Softplus.cs
17. +2 -1 src/TensorFlowNET.Keras/Layers/Activation/Softsign.cs
18. +2 -1 src/TensorFlowNET.Keras/Layers/Activation/Swish.cs
19. +1 -1 src/TensorFlowNET.Keras/Layers/Activation/Tanh.cs
20. +1 -1 src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs
21. +1 -1 src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs
22. +1 -1 src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs
23. +1 -1 src/TensorFlowNET.Keras/Layers/Core/Dense.cs
24. +1 -1 src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs
25. +1 -1 src/TensorFlowNET.Keras/Layers/Core/Embedding.cs
26. +1 -1 src/TensorFlowNET.Keras/Layers/Merging/Merge.cs
27. +1 -1 src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs
28. +1 -1 src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs
29. +1 -1 src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs
30. +1 -1 src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs
31. +1 -1 src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs
32. +1 -1 src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs
33. +1 -1 src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs
34. +1 -1 src/TensorFlowNET.Keras/Layers/Pooling/Pooling1D.cs
35. +1 -1 src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs
36. +1 -1 src/TensorFlowNET.Keras/Layers/Preprocessing/CategoryEncoding.cs
37. +1 -1 src/TensorFlowNET.Keras/Layers/Preprocessing/Rescaling.cs
38. +1 -1 src/TensorFlowNET.Keras/Layers/Preprocessing/Resizing.cs
39. +1 -1 src/TensorFlowNET.Keras/Layers/Regularization/Dropout.cs
40. +1 -1 src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs
41. +1 -1 src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs
42. +1 -1 src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs
43. +1 -1 src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs
44. +1 -1 src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs
45. +1 -1 src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs
46. +1 -1 src/TensorFlowNET.Keras/Layers/Reshaping/UpSampling2D.cs
47. +1 -1 src/TensorFlowNET.Keras/Layers/Reshaping/ZeroPadding2D.cs
48. +2 -2 src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs
49. +417 -5 src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
50. +2 -2 src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
51. +7 -5 src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs
52. +1 -1 src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs

+15 -0 src/TensorFlowNET.Core/APIs/tf.control_flow.cs

@@ -57,6 +57,21 @@ namespace Tensorflow
new[] { loop_vars });
return results[0];
}
public (Tensor, List<TensorArray>, Tensors, Tensors) while_loop(Func<Tensor, Tensor> cond,
Func<Tensor, List<TensorArray>, Tensors, Tensors, (Tensor, List<TensorArray>, Tensors, Tensors)> body,
(Tensor, List<TensorArray>, Tensors, Tensors) loop_vars,
int parallel_iterations = 10)
=> control_flow_ops.while_loop(cond,
body,
loop_vars);

public (Tensor, List<TensorArray>, Tensors) while_loop(Func<Tensor, Tensor> cond,
Func<Tensor, List<TensorArray>, Tensors, (Tensor, List<TensorArray>, Tensors)> body,
(Tensor, List<TensorArray>, Tensors) loop_vars,
int parallel_iterations = 10)
=> control_flow_ops.while_loop(cond,
body,
loop_vars);

public Tensor[] while_loop(Func<Tensor[], Tensor> cond,
Func<Tensor[], Tensor[]> body,


+10 -2 src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs

@@ -1,5 +1,9 @@
using Newtonsoft.Json;
using OneOf;
using System.Collections.Generic;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.NumPy;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
@@ -7,11 +11,14 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
public interface IRnnArgCell : ILayer
{
object state_size { get; }
public Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null);
public StateSizeWrapper state_size { get; set; }
public int output_size { get; set; }
}
[JsonProperty("cell")]
// TODO: the cell should be serialized with `serialize_keras_object`.
public IRnnArgCell Cell { get; set; } = null;
public OneOf<IList<IRnnArgCell>, IRnnArgCell> Cell { get; set; }

[JsonProperty("return_sequences")]
public bool ReturnSequences { get; set; } = false;
[JsonProperty("return_state")]
@@ -25,6 +32,7 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
[JsonProperty("time_major")]
public bool TimeMajor { get; set; } = false;
// TODO: Add `num_constants` and `zero_output_for_mask`.
public bool ZeroOutputForMask { get; set; } = false;
public Dictionary<string, object> Kwargs { get; set; } = null;

public int Units { get; set; }
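
For context, a minimal sketch of how the OneOf-typed Cell property is meant to be consumed; the cell variables here are illustrative, not part of this diff:

// Hypothetical usage, assuming `simpleCell`, `c1`, `c2` implement IRnnArgCell:
var single = new RNNArgs { Cell = simpleCell }; // T1: a single cell
var stacked = new RNNArgs { Cell = new List<IRnnArgCell> { c1, c2 } }; // T0: a list of cells
bool isStacked = stacked.Cell.IsT0; // RNN's constructor branches on this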


+2 -1 src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs

@@ -1,10 +1,11 @@
using System.Collections.Generic;
using static Tensorflow.Keras.ArgsDefinition.Rnn.RNNArgs;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
public class StackedRNNCellsArgs : LayerArgs
{
public IList<RnnCell> Cells { get; set; }
public IList<IRnnArgCell> Cells { get; set; }
public Dictionary<string, object> Kwargs { get; set; } = null;
}
}

+63 -0 src/TensorFlowNET.Core/NumPy/StateSizeWrapper.cs

@@ -0,0 +1,63 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Text;


namespace Tensorflow.NumPy
{
// state_size in an RNN is either a single integer or an array of integers, so StateSizeWrapper holds both forms uniformly.
public class StateSizeWrapper : IEnumerable<int>
{
int[] _state_size;
public int[] state_size => _state_size;

public StateSizeWrapper(int state_size)
{
_state_size = new int[] { state_size };
}

public StateSizeWrapper(params int[] state_size)
{
_state_size = state_size;
}
public StateSizeWrapper(IEnumerable<int> state_size)
{
_state_size = state_size.ToArray();
}

public static implicit operator StateSizeWrapper(int[] state_size)
=> new StateSizeWrapper(state_size);

public static implicit operator StateSizeWrapper(int state_size)
=> new StateSizeWrapper(state_size);

public static implicit operator StateSizeWrapper((int, int) state_size)
=> new StateSizeWrapper(state_size.Item1, state_size.Item2);

public static implicit operator StateSizeWrapper(List<int> v)
=> new StateSizeWrapper(v);
public override string ToString()
{
// string.Join prints the actual sizes instead of the array's type name.
return $"[{string.Join(", ", _state_size)}]";
}

public int this[int n]
{
get => n < 0 ? state_size[state_size.Length + n] : state_size[n];
set => state_size[n] = value;
}

public IEnumerator<int> GetEnumerator()
{
return state_size.ToList().GetEnumerator();
}

IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
}
}
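
A short sketch of the intended usage of the wrapper's implicit conversions and negative indexing (values illustrative):

// int, int[], (int, int) and List<int> all convert implicitly:
StateSizeWrapper simple = 32;
StateSizeWrapper lstm = (64, 64);
int last = lstm[-1]; // negative indices read from the end -> 64
int count = simple.Count(); // IEnumerable<int>, so LINQ works -> 1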



+2 -1 src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs

@@ -26,6 +26,7 @@ using Tensorflow.Operations;
using Tensorflow.Train;
using Tensorflow.Util;
using static Tensorflow.Binding;
using static Tensorflow.Keras.ArgsDefinition.Rnn.RNNArgs;

namespace Tensorflow
{
@@ -50,7 +51,7 @@ namespace Tensorflow
/// matching structure of Tensors having shape `[batch_size].concatenate(s)`
/// for each `s` in `self.batch_size`.
/// </summary>
public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell
public abstract class RnnCell : ILayer
{
/// <summary>
/// Attribute that indicates whether the cell is a TF RNN cell, due the slight


+47 -0 src/TensorFlowNET.Core/Operations/control_flow_ops.cs

@@ -698,6 +698,53 @@ namespace Tensorflow
});
}

public static (Tensor, List<TensorArray>, Tensors, Tensors) while_loop(Func<Tensor, Tensor> cond,
Func<Tensor, List<TensorArray>, Tensors, Tensors, (Tensor, List<TensorArray>, Tensors, Tensors)> body,
(Tensor, List<TensorArray>, Tensors, Tensors) loop_vars,
int parallel_iterations = 10,
string name = null)
{
var executing_eagerly = tf.Context.executing_eagerly();
if (!executing_eagerly)
{
throw new NotImplementedException("while_loop over tuple loop_vars is currently only supported in eager mode.");
}

return tf_with(ops.name_scope(name, "while"), delegate
{
while ((bool)cond(loop_vars.Item1))
{
loop_vars = body(loop_vars.Item1, loop_vars.Item2, loop_vars.Item3, loop_vars.Item4);
}

return loop_vars;
});
}

public static (Tensor, List<TensorArray>, Tensors) while_loop(Func<Tensor, Tensor> cond,
Func<Tensor, List<TensorArray>, Tensors, (Tensor, List<TensorArray>, Tensors)> body,
(Tensor, List<TensorArray>, Tensors) loop_vars,
int parallel_iterations = 10,
string name = null)
{
var executing_eagerly = tf.Context.executing_eagerly();
if (!executing_eagerly)
{
throw new NotImplementedException("while_loop over tuple loop_vars is currently only supported in eager mode.");
}

return tf_with(ops.name_scope(name, "while"), delegate
{
while ((bool)cond(loop_vars.Item1))
{
loop_vars = body(loop_vars.Item1, loop_vars.Item2, loop_vars.Item3);
}

return loop_vars;
});
}


/// <summary>
/// Repeat `body` while the condition `cond` is true.
/// </summary>
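
Since the new tuple overloads simply execute the loop eagerly, usage reduces to threading the loop-variable tuple through `body` until `cond` returns false. A minimal sketch with made-up loop variables:

// Illustrative eager call of the 3-tuple overload:
var time = tf.constant(0);
var tas = new List<TensorArray>();
var states = new Tensors();
Func<Tensor, Tensor> cond = t => t < 10;
Func<Tensor, List<TensorArray>, Tensors, (Tensor, List<TensorArray>, Tensors)> body =
(t, ta, st) => (t + 1, ta, st); // advance time, keep the rest unchanged
var (finalTime, finalTas, finalStates) = control_flow_ops.while_loop(cond, body, (time, tas, states));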


+44 -0 src/TensorFlowNET.Core/Util/nest.py.cs

@@ -211,6 +211,28 @@ namespace Tensorflow.Util
=> arg is IEnumerable && !(arg is string) && !(arg is NDArray) &&
!(arg.GetType().IsGenericType && arg.GetType().GetGenericTypeDefinition() == typeof(HashSet<>));

public static bool is_nested(object obj)
{
// Mirrors Python's nest.is_nested: enumerable containers are nested,
// but strings and NDArrays are treated as leaves.
return obj is IEnumerable && !(obj is string) && !(obj is NDArray);
}

public static bool is_mapping(object arg) => arg is IDictionary;

//# See the swig file (util.i) for documentation.
@@ -263,7 +285,29 @@ namespace Tensorflow.Util
}
}

public static List<T> FlattenTupple<T>(object tuple)
{
List<T> items = new List<T>();
var type = tuple.GetType();

if (type.GetInterface("ITuple") == null)
throw new ArgumentException("This is not a tuple!");

// System.Tuple exposes Item1..ItemN as properties,
// while System.ValueTuple exposes them as public fields.
foreach (var property in type.GetProperties())
{
var value = property.GetValue(tuple);
if (property.PropertyType.GetInterface("ITuple") != null)
items.AddRange(FlattenTupple<T>(value));
else
items.Add((T)value);
}
foreach (var field in type.GetFields())
{
var value = field.GetValue(tuple);
if (field.FieldType.GetInterface("ITuple") != null)
items.AddRange(FlattenTupple<T>(value));
else
items.Add((T)value);
}
return items;
}
//# See the swig file (util.i) for documentation.
//_same_namedtuples = _pywrap_tensorflow.SameNamedtuples
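
A quick sketch of what FlattenTupple is expected to produce (depth-first flattening of nested System.Tuple values; illustrative):

var nested = Tuple.Create(1, Tuple.Create(2, 3), 4);
var flat = nest.FlattenTupple<int>(nested); // -> { 1, 2, 3, 4 }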



+529 -11 src/TensorFlowNET.Keras/BackendImpl.cs

@@ -22,6 +22,9 @@ using Tensorflow.Functions;
using Tensorflow.Graphs;
using static Tensorflow.Binding;
using static Tensorflow.Graphs.SubGraphUtility;
using Tensorflow.Util;
using Tensorflow.Operations;
using OneOf;

namespace Tensorflow.Keras
{
@@ -65,7 +68,7 @@ namespace Tensorflow.Keras
return;
}
var graph = v.Graph;
if(graph is null)
if (graph is null)
{
graph = get_graph();
}
@@ -95,7 +98,7 @@ namespace Tensorflow.Keras
{
if (_GRAPH == null)
_GRAPH = new FuncGraph("keras_graph");
return _GRAPH;
}
return ops.get_default_graph();
@@ -105,7 +108,7 @@ namespace Tensorflow.Keras
{
if (_CURRENT_SCRATCH_GRAPH == null)
_CURRENT_SCRATCH_GRAPH = new FuncGraph("keras_scratch_graph");
return _CURRENT_SCRATCH_GRAPH;
}

@@ -230,16 +233,16 @@ namespace Tensorflow.Keras
{
if (outputs[0].op.type == "Const")
return tensor_util.constant_value(outputs);
var source_graph = outputs.graph;
var exec_graph = _scratch_graph();
var global_graph = get_graph();
if (source_graph == global_graph && exec_graph != global_graph)
{
var lifted_map = lift_to_graph(outputs, exec_graph,
new List<Tensor>(),
add_sources: true,
handle_captures: true,
var lifted_map = lift_to_graph(outputs, exec_graph,
new List<Tensor>(),
add_sources: true,
handle_captures: true,
base_graph: source_graph);
}
if (outputs[0].op.type == "Placeholder"
@@ -250,7 +253,7 @@ namespace Tensorflow.Keras
exec_graph.as_default();
exec_graph.Inputs = exec_graph.internal_captures;
exec_graph.Outputs = outputs;
var graph_fn = new ConcreteFunction(exec_graph);

_CURRENT_SCRATCH_GRAPH = null;
@@ -370,7 +373,7 @@ namespace Tensorflow.Keras
/// <param name="data_format"></param>
/// <param name="interpolation"></param>
/// <returns></returns>
public Tensor resize_images(Tensor x, int height_factor, int width_factor,
public Tensor resize_images(Tensor x, int height_factor, int width_factor,
string data_format, string interpolation = "nearest")
{
var (rows, cols) = (0, 0);
@@ -412,7 +415,7 @@ namespace Tensorflow.Keras
/// <returns></returns>
public Tensor concatenate(Tensors tensors, int axis = -1)
{
if(axis < 0)
if (axis < 0)
{
var rank = tensors[0].ndim;
if (rank > -1)
@@ -450,5 +453,520 @@ namespace Tensorflow.Keras

return x;
}

public static (Tensors, Tensors) convert_inputs_if_ragged(OneOf<Tensor, RaggedTensor> inputs)
{
throw new NotImplementedException();
}

public static (Tensors, Tensors, Tensors) rnn(
Func<Tensors, Tensors, (Tensors, Tensors)> step_function, // args:inputs, states, return:output, new_states
Tensors inputs, // inputs is a tuple of tensors (one per input sequence)
Tensors initial_states,
bool go_backwards = false,
Tensor? mask = null,
Tensors? constants = null,
bool unroll = false,
Tensors? input_length = null, // an integer or a 1-D Tensor, depending on whether the time dimension is fixed-length or not
bool time_major = false,
bool zero_output_for_mask = false,
bool return_all_outputs = true)
{

Tensors swap_batch_timestep(Tensors input_t)
{
var axes = Enumerable.Range(0, input_t.rank).ToArray();
axes[0] = 1;
axes[1] = 0;
return tf.transpose(input_t, axes);
}

if (!time_major)
{
inputs = nest.map_structure(swap_batch_timestep, inputs);
}

var flattened_inputs = nest.flatten(inputs);
var time_steps = flattened_inputs[0].shape[0];
var batch = flattened_inputs[0].shape[1];
var time_step_t = tf.shape(flattened_inputs[0])[0];

foreach (var input_ in flattened_inputs)
{
input_.shape.with_rank_at_least(3);
}

if (mask != null)
{
if (mask.dtype != TF_DataType.TF_BOOL)
{
mask = tf.cast(mask, TF_DataType.TF_BOOL);
}

if (mask.rank == 2)
{
mask = tf.expand_dims(mask, -1);
}

if (!time_major)
{
mask = swap_batch_timestep(mask);
}

}

if (constants == null)
{
constants = new List<Tensor>();
}

// tf.where needs its condition tensor to be the same shape as its two
// result tensors, but in our case the condition (mask) tensor is
// (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
// So we need to broadcast the mask to match the shape of inputs.
// That's what the tile call does, it just repeats the mask along its
// second dimension n times.

Tensors _expand_mask(Tensors mask_t, Tensors input_t, int fixed_dim = 1)
{
if (nest.is_nested(mask_t))
{
throw new ValueError($"mask_t is expected to be tensor, but got {mask_t}");
}

if (nest.is_nested(input_t))
{
throw new ValueError($"input_t is expected to be tensor, but got {input_t}");
}

var rank_diff = input_t.rank - mask_t.rank;
for (int i = 0; i < rank_diff; i++)
{
mask_t = tf.expand_dims(mask_t, -1);
}
var multiples = Enumerable.Repeat(1, fixed_dim).ToArray().concat(input_t.shape.as_int_list().ToList().GetRange(fixed_dim, input_t.rank));
return tf.tile(mask_t, multiples);
}
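// Worked example (assumed shapes): with mask_t of shape (nsamples, 1) and
// input_t of shape (nsamples, ndims, k): rank_diff = 3 - 2 = 1, so one
// expand_dims gives (nsamples, 1, 1); multiples becomes [1, ndims, k],
// and tf.tile broadcasts the mask up to (nsamples, ndims, k).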

Tensors outputs = new Tensors();
Tensors output_time_zero = new Tensors();
Tensors last_output = new Tensors();
Tensors new_states = new Tensors();
if (unroll)
{
if (time_steps == 0)
{
throw new ValueError("Unrolling requires a fixed number of timesteps.");
}

// Process the input tensors. The input tensor needs to be split on the
// time_step dim, and reversed if go_backwards is True. In the case of
// nested input, the input is flattened and then transformed
// individually. The result of this will be a tuple of lists; each
// item in the tuple is a list of tensors with shape (batch, feature).


// TODO(Wanglongzhi2001): step_function's second argument should be a List, but a tuple ends up being used.
//var states = Tuple.Create(initial_states);
var states = initial_states;

var successive_states = new Tensors();
var successive_outputs = new Tensors();

Tensors _process_single_input_t(Tensors input_t)
{
input_t = tf.unstack(input_t); // unstack for time_step dim
if (go_backwards)
{
input_t.Reverse();
}
return input_t;
}

// TODO(Wanglongzhi2001)
Tensors processed_input;
if (nest.is_nested(inputs))
{
processed_input = nest.map_structure(_process_single_input_t, inputs);
}
else
{
processed_input = _process_single_input_t(inputs);
}

object _get_input_tensor(int time)
{
List<Tensor> inp = new List<Tensor>();
foreach (var t_ in processed_input)
{
inp.Add(t_[time]);
}
return nest.pack_sequence_as(inputs, inp);
}

if (mask != null)
{
var mask_list = tf.unstack(mask);
if (go_backwards)
{
mask_list.Reverse();
}

for (int i = 0; i < time_steps; i++)
{
// TODO(Wanglongzhi2001),deal with _get_input_tensor
var inp = _get_input_tensor(i);
var mask_t = mask_list[i];
// TODO
var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants });

var tiled_mask_t = _expand_mask(mask_t, output);

Tensors prev_output;
if (successive_outputs == null)
{
prev_output = tf.zeros_like(output);
}
else
{
prev_output = successive_outputs[successive_outputs.Length - 1];
}

output = tf.where(tiled_mask_t, output, prev_output);

//var flat_states = nest.flatten(states);
//var flat_new_states = nest.flatten(newStates);
var flat_states = states.ToList();
var flat_new_states = newStates.ToList();

var tiledMaskT = flat_states
.Select(s => _expand_mask(mask_t, s))
.ToArray();
var tuple = Tuple.Create(tiledMaskT);

List<Tensor> flat_final_states = new List<Tensor>();
foreach (var (m, s, ps) in Enumerable.Zip(tiled_mask_t, flat_new_states, flat_states))
{
flat_final_states.Add(tf.where(m, s, ps));
}

states = (Tensors)nest.pack_sequence_as(states, flat_final_states);
if (return_all_outputs)
{
successive_outputs.Add(output);
successive_states.Add(states);
}
else
{
successive_outputs = new Tensors { output };
successive_states = new Tensors { states };
}

}
last_output = successive_outputs[successive_outputs.Length - 1];
new_states = successive_states[successive_states.Length - 1];
outputs = tf.stack(successive_outputs);

if (zero_output_for_mask)
{
last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output));
outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs));
}
} // close `if (mask != null)` so the else below pairs with it
else // mask is null
{
for (int i = 0; i < time_steps; i++)
{
var inp = _get_input_tensor(i);
var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants });
states = newStates;

if (return_all_outputs)
{
successive_outputs.Add(output);
successive_states.Add(newStates);
}
else
{
successive_outputs = new Tensors { output };
successive_states = new Tensors { newStates };
}
}
last_output = successive_outputs[successive_outputs.Length - 1];
new_states = successive_states[successive_states.Length - 1];
outputs = tf.stack(successive_outputs);
}
}
else // unroll == false
{
var states = initial_states;
// Create input tensor array, if the inputs is nested tensors, then it
// will be flattened first, and tensor array will be created one per
// flattened tensor.
var input_ta = new List<TensorArray>();
for (int i = 0; i < flattened_inputs.Count; i++)
{
input_ta.Add(tf.TensorArray(dtype: flattened_inputs[i].dtype, size: time_step_t));
}

// Get the time(0) input and compute the output for that, the output will
// be used to determine the dtype of output tensor array. Don't read from
// input_ta due to TensorArray clear_after_read default to True.
var inps = new Tensors();
foreach (var inp in flattened_inputs)
{
inps.Add(inp[0]);
}
var input_time_zero = nest.pack_sequence_as(inputs, inps);

// output_time_zero is used to determine the cell output shape and its
// dtype. the value is discarded.
(output_time_zero, _) = step_function((Tensor)input_time_zero, new Tensors { initial_states, constants });

var output_ta_size = return_all_outputs ? time_step_t : tf.constant(1);
var output_ta = new List<TensorArray>();
for (int i = 0; i < output_time_zero.ToList().Count; i++)
{
var Out = output_time_zero.ToList()[i];
output_ta.Add(tf.TensorArray(dtype: Out.dtype, size: output_ta_size, element_shape: Out.shape));
}

var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time");



Func<Tensor, Tensor>? masking_fn;
Func<Tensors, Tensors, Tensors, Tensors>? compute_masked_output = null;
if (mask != null)
{
if (go_backwards)
{
mask = tf.reverse(mask, axis: new[] { 0 });
}
var mask_ta = tf.TensorArray(dtype: TF_DataType.TF_BOOL, size: time_step_t);
mask_ta = mask_ta.unstack(mask);

masking_fn = (time) =>
{
return mask_ta.read(time);
};

compute_masked_output = (mask_t, flat_out, flat_mask) =>
{
var tiled_mask_t = new Tensors();
foreach (var o in flat_out)
{
tiled_mask_t.Add(_expand_mask(mask_t, o, fixed_dim: mask_t.rank));
}

Tensors res = new Tensors();
foreach (var (m, o, fm) in Enumerable.Zip(tiled_mask_t, flat_out, flat_mask))
{
res.Add(tf.where(m, o, fm));
}
return res;
};
}
// TODO(Wanglongzhi2001): should input_length be an integer or a single tensor?
else if (input_length is Tensor)
{
if (go_backwards)
{
var max_len = tf.reduce_max(input_length, axis: 0);
var rev_input_length = tf.subtract(max_len - 1, input_length);

masking_fn = (time) =>
{
return tf.less(rev_input_length, time);
};
}
else
{
masking_fn = (time) =>
{
return tf.greater(input_length, time);
};
}

compute_masked_output = (mask_t, flat_out, flat_mask) =>
{
var res = new List<Tensor>();
foreach (var (o, zo) in zip(flat_out, flat_mask))
{
res.Add(tf.where(mask_t, o, zo));
}
return res;
};
}
else
{
masking_fn = null;
}


if (masking_fn != null)
{
// The mask for the output at time T is based on the output at time T - 1.
// In the case T = 0, a zero-filled tensor is used instead.
var flat_zero_output = new Tensors();
foreach (var o in nest.flatten(output_time_zero))
{
flat_zero_output.Add(tf.zeros_like(o));
}


(Tensor, List<TensorArray>, Tensors, Tensors) _step(Tensor time, List<TensorArray> output_ta_t, Tensors prev_output, Tensors states)
{
/*
RNN step function.
Args:
time: Current timestep value.
output_ta_t: TensorArray.
prev_output: tuple of outputs from time - 1.
*states: List of states.
Returns:
Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)`
*/

var current_input = input_ta.Select(x => x.read(time)).ToList();
// maybe set shape
// TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type
current_input = (List<Tensor>)nest.pack_sequence_as(inputs, current_input);
var mask_t = masking_fn(time);
var (output, new_states) = step_function(current_input, new Tensors { states, constants });
// mask output
//var flat_output = nest.flatten(output);
var flat_output = output.ToList();

var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.ToList();

// TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type
var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output);

// mask states
var flat_state = states.ToList();
var flat_new_state = new_states.ToList();

foreach (var (state, new_state) in zip(flat_state, flat_new_state))
{
if (new_state is Tensor)
{
new_state.set_shape(state.shape);
}
}

var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state);
new_states = (Tensors)nest.pack_sequence_as(new_states, flat_final_state);

var ta_index_to_write = return_all_outputs ? time : tf.constant(0);
var Output_ta_t = new List<TensorArray>();
// TODO(Wanglongzhi2001),deal with zip output_ta_t
foreach (var (ta, Out) in zip(output_ta_t, flat_new_output))
{
Output_ta_t.Add(ta.write(ta_index_to_write, Out));
}



//new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state);


return (time + 1, Output_ta_t, flat_new_output, new_states);

}
Func<Tensor, Tensor> cond = (time) => (time < time_step_t);

var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, flat_zero_output, states));
new_states = final_outputs.Item4;
output_ta = final_outputs.Item2;

}
else
{
(Tensor, List<TensorArray>, Tensors) _step(Tensor time, List<TensorArray> output_ta_t, Tensors states)
{
var current_input = input_ta.Select(x => x.read(time)).ToList();
// maybe set shape
// TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type
current_input = (List<Tensor>)nest.pack_sequence_as(inputs, current_input);
var (output, new_states) = step_function(current_input, new Tensors { states, constants });
var flat_state = states.ToList();
var flat_new_state = new_states.ToList();
foreach (var (state, new_state) in zip(flat_state, flat_new_state))
{
if (new_state is Tensor)
{
new_state.set_shape(state.shape);
}
}
var flat_output = output.ToList();
var ta_index_to_write = return_all_outputs ? time : tf.constant(0);
var Output_ta_t = new List<TensorArray>();
foreach (var (ta, out_) in zip(output_ta_t, flat_output))
{
Output_ta_t.Add(ta.write(ta_index_to_write, out_));
}

new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state);
return (time + 1, Output_ta_t, new_states);
}
Func<Tensor, Tensor> cond = (time) => (time < time_step_t);
var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, states));
new_states = final_outputs.Item3;
output_ta = final_outputs.Item2;

}
//Tensors outputs = new Tensors();
foreach (var o in output_ta)
{
outputs.Add(o.stack());
}
foreach (var o in outputs)
{
last_output.Add(o[-1]);
}
outputs = (Tensors)nest.pack_sequence_as(output_time_zero, outputs);
last_output = (Tensors)nest.pack_sequence_as(output_time_zero, last_output);

}

Func<Tensor, Tensor> set_shape;
set_shape = (output_) =>
{
if (output_ is Tensor)
{
var shape = output_.shape.as_int_list();
if (return_all_outputs)
{
shape[0] = (int)time_steps;
}
else
{
shape[0] = 1;
}
shape[1] = (int)batch;
output_.set_shape(new Shape(shape));
}
return output_;
};

var Outputs = (Tensors)nest.map_structure(set_shape, outputs);
if (!time_major)
{
Outputs = nest.map_structure(swap_batch_timestep, Outputs);
}
return (last_output, Outputs, new_states);

}
}
}
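
To make the step_function contract of BackendImpl.rnn concrete, here is a minimal, layer-free sketch with made-up shapes and names (an identity "cell"; this mirrors the signature above, not any shipped helper):

// Illustrative: output = input, states passed through unchanged.
Func<Tensors, Tensors, (Tensors, Tensors)> step = (stepInputs, stepStates) =>
(stepInputs, stepStates);
var x = tf.zeros(new Shape(8, 5, 16)); // (batch, time, features), time_major: false
var s0 = new Tensors(tf.zeros(new Shape(8, 16))); // one initial state tensor
var (lastOut, allOut, finalStates) = BackendImpl.rnn(
step, x, s0, time_major: false, return_all_outputs: true);
// lastOut holds the final step's output; allOut stacks all 5 steps.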

+2 -2 src/TensorFlowNET.Keras/Engine/Layer.cs

@@ -332,9 +332,9 @@ namespace Tensorflow.Keras.Engine
/// <param name="state"></param>
/// <param name="training"></param>
/// <returns></returns>
protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected virtual Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
if(ReplacedCall is not null)
if (ReplacedCall is not null)
{
return ReplacedCall(inputs);
}


+1 -1 src/TensorFlowNET.Keras/Layers/Activation/ELU.cs

@@ -29,7 +29,7 @@ namespace Tensorflow.Keras.Layers {
base.build(input_shape);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
Tensor output = inputs;
output = tf.where(output > 0f, output,


+1 -1 src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs

@@ -17,7 +17,7 @@ namespace Tensorflow.Keras.Layers {
{
base.build(input_shape);
}
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
Tensor output = inputs;
return tf.exp(output);


+2 -1 src/TensorFlowNET.Keras/Layers/Activation/HardSigmoid.cs

@@ -10,7 +10,8 @@ namespace Tensorflow.Keras.Layers {
public HardSigmoid ( LayerArgs args ) : base(args) {
// hard sigmoid has no arguments
}
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
Tensor x = inputs;
return tf.clip_by_value(
tf.add(tf.multiply(x, 0.2f), 0.5f), 0f, 1f);


+1 -1 src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs

@@ -19,7 +19,7 @@ namespace Tensorflow.Keras.Layers
this.args = args;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
return tf.nn.leaky_relu(inputs, alpha: alpha);
}


+2 -1 src/TensorFlowNET.Keras/Layers/Activation/SELU.cs

@@ -22,7 +22,8 @@ namespace Tensorflow.Keras.Layers {
}
base.build(input_shape);
}
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
Tensor output = inputs;
return tf.where(output > 0f,
tf.multiply(scale, output),


+2 -1 src/TensorFlowNET.Keras/Layers/Activation/Softmax.cs

@@ -11,7 +11,8 @@ namespace Tensorflow.Keras.Layers {
public Softmax ( SoftmaxArgs args ) : base(args) {
axis = args.axis;
}
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
Tensor x = inputs.Length == 2 ? inputs + ((1.0 - tf.cast(inputs[1], inputs.dtype)) * 1e-9)
: inputs;
Tensor e = tf.exp(tf.sub(x, tf.reduce_max(x, axis: this.axis, keepdims: true)));


+2 -1 src/TensorFlowNET.Keras/Layers/Activation/Softplus.cs

@@ -10,7 +10,8 @@ namespace Tensorflow.Keras.Layers {
public Softplus ( LayerArgs args ) : base(args) {
// Softplus has no arguments
}
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
Tensor x = inputs;
return tf.log(
tf.add(tf.exp(x), 1f));


+2 -1 src/TensorFlowNET.Keras/Layers/Activation/Softsign.cs

@@ -10,7 +10,8 @@ namespace Tensorflow.Keras.Layers {
public Softsign ( LayerArgs args ) : base(args) {
// Softsign has no arguments
}
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
Tensor x = inputs;
// x / (abs(x) + 1)
return tf.div(x, tf.add(1f, tf.abs(x)));


+2 -1 src/TensorFlowNET.Keras/Layers/Activation/Swish.cs

@@ -10,7 +10,8 @@ namespace Tensorflow.Keras.Layers {
public Swish ( LayerArgs args ) : base(args) {
// Swish has no arguments
}
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
Tensor x = inputs;

// x / (1 + exp(-x))


+1 -1 src/TensorFlowNET.Keras/Layers/Activation/Tanh.cs

@@ -13,7 +13,7 @@ namespace Tensorflow.Keras.Layers
{
// Tanh has no arguments
}
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
Tensor x = inputs;



+1 -1 src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs

@@ -114,7 +114,7 @@ namespace Tensorflow.Keras.Layers
return (tf.linalg.einsum("bij,bjk->bik", (weights, value)), weights);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
Tensors _inp;
Tensors _mask = null;


+1 -1 src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs

@@ -252,7 +252,7 @@ namespace Tensorflow.Keras.Layers
return (attention_output, attention_scores);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
Tensors _inp;
Tensor _mask = null;


+1 -1 src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs

@@ -103,7 +103,7 @@ namespace Tensorflow.Keras.Layers
_buildInputShape = input_shape;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = false)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
var outputs = _convolution_op.Apply(inputs, kernel.AsTensor());
if (use_bias)


+1 -1 src/TensorFlowNET.Keras/Layers/Core/Dense.cs

@@ -69,7 +69,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
Tensor outputs = null;
var rank = inputs.rank;


+1 -1 src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs

@@ -189,7 +189,7 @@ namespace Tensorflow.Keras.Layers
// return new dict(base_config.items().ToList() + config.items().ToList());
//}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
var ret = tf.linalg.einsum(this.equation, (inputs, this.kernel.AsTensor()));
if (this.bias != null)


+1 -1 src/TensorFlowNET.Keras/Layers/Core/Embedding.cs

@@ -66,7 +66,7 @@ namespace Tensorflow.Keras.Layers
_buildInputShape = input_shape;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
var dtype = inputs.dtype;
if (dtype != tf.int32 && dtype != tf.int64)


+1 -1 src/TensorFlowNET.Keras/Layers/Merging/Merge.cs

@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Layers
_buildInputShape = input_shape;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
return _merge_function(inputs);
}


+1 -1 src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs

@@ -146,7 +146,7 @@ namespace Tensorflow.Keras.Layers
return false;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
Tensor outputs = null;
var training_tensor = training == null


+1 -1 src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs

@@ -101,7 +101,7 @@ namespace Tensorflow.Keras.Layers
return input_shape;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
Tensor outputs = null;
var inputs_dtype = inputs.dtype.as_base_dtype();


+1 -1 src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs

@@ -157,7 +157,7 @@ namespace Tensorflow.Keras.Layers
base.adapt(data, batch_size: batch_size, steps: steps);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
if (_args.Invert)
{


+1 -1 src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs

@@ -12,7 +12,7 @@ namespace Tensorflow.Keras.Layers
{
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
if (data_format == "channels_last")
return math_ops.reduce_mean(inputs, 1, false);


+1 -1 src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs

@@ -12,7 +12,7 @@ namespace Tensorflow.Keras.Layers
{
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
if (data_format == "channels_last")
return math_ops.reduce_mean(inputs, (1, 2), false);


+1 -1 src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs

@@ -12,7 +12,7 @@ namespace Tensorflow.Keras.Layers
{
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
if (data_format == "channels_last")
return math_ops.reduce_max(inputs, 1, false);


+1 -1 src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs

@@ -12,7 +12,7 @@ namespace Tensorflow.Keras.Layers
{
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
if (data_format == "channels_last")
return math_ops.reduce_max(inputs, (1, 2), false);


+1 -1 src/TensorFlowNET.Keras/Layers/Pooling/Pooling1D.cs

@@ -36,7 +36,7 @@ namespace Tensorflow.Keras.Layers
input_spec = new InputSpec(ndim: 3);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
int pad_axis = args.DataFormat == "channels_first" ? 2 : 3;
inputs = tf.expand_dims(inputs, pad_axis);


+1 -1 src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs

@@ -36,7 +36,7 @@ namespace Tensorflow.Keras.Layers
input_spec = new InputSpec(ndim: 4);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
int[] pool_shape;
int[] strides;


+1 -1 src/TensorFlowNET.Keras/Layers/Preprocessing/CategoryEncoding.cs

@@ -15,7 +15,7 @@ namespace Tensorflow.Keras.Layers
this.args = args;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
var depth = args.NumTokens;
var max_value = tf.reduce_max(inputs);


+1 -1 src/TensorFlowNET.Keras/Layers/Preprocessing/Rescaling.cs

@@ -17,7 +17,7 @@ namespace Tensorflow.Keras.Layers
this.args = args;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
scale = constant_op.constant(args.Scale, args.DType);
offset = constant_op.constant(args.Offset, args.DType);


+1 -1 src/TensorFlowNET.Keras/Layers/Preprocessing/Resizing.cs

@@ -19,7 +19,7 @@ namespace Tensorflow.Keras.Layers
this.args = args;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
return image_ops_impl.resize_images_v2(inputs, new[] { args.Height, args.Width }, method: args.Interpolation);
}


+1 -1 src/TensorFlowNET.Keras/Layers/Regularization/Dropout.cs

@@ -15,7 +15,7 @@ namespace Tensorflow.Keras.Layers
this.args = args;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
if (training == null)
training = false;


+1 -1 src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs

@@ -27,7 +27,7 @@ namespace Tensorflow.Keras.Layers.Reshaping
_buildInputShape = input_shape;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
Tensor output = inputs;
if (output.rank != 3)


+1 -1 src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs

@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Layers.Reshaping
built = true;
_buildInputShape = input_shape;
}
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
Tensor output = inputs;
if (output.rank != 4)


+1 -1 src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs

@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Layers.Reshaping
_buildInputShape = input_shape;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
Tensor output = inputs;
if (output.rank != 5)


+1 -1 src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs

@@ -23,7 +23,7 @@ namespace Tensorflow.Keras.Layers
_channels_first = args.DataFormat == "channels_first";
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
if (_channels_first)
{


+1 -1 src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs

@@ -28,7 +28,7 @@ namespace Tensorflow.Keras.Layers {
built = true;
_buildInputShape = input_shape;
}
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
Tensor outputs = inputs;
return tf.transpose(outputs, new Axis(permute));


+1 -1 src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs

@@ -19,7 +19,7 @@ namespace Tensorflow.Keras.Layers
this.args = args;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
var shapes = new List<Tensor>();
shapes.Add(array_ops.shape(inputs)[0]);


+1 -1 src/TensorFlowNET.Keras/Layers/Reshaping/UpSampling2D.cs

@@ -24,7 +24,7 @@ namespace Tensorflow.Keras.Layers
inputSpec = new InputSpec(ndim: 4);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
return keras.backend.resize_images(inputs,
size[0], size[1],


+1 -1 src/TensorFlowNET.Keras/Layers/Reshaping/ZeroPadding2D.cs

@@ -26,7 +26,7 @@ namespace Tensorflow.Keras.Layers
this.input_spec = new InputSpec(ndim: 4);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
return keras.backend.spatial_2d_padding(inputs,
padding: padding,


+2 -2 src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs

@@ -26,9 +26,9 @@ namespace Tensorflow.Keras.Layers.Rnn
.ToArray();
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
return base.Call(inputs, state: state, training: training);
return base.Call(inputs, initial_state: initial_state, training: training);
}
}
}

+417 -5 src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs

@@ -1,9 +1,15 @@
using System;
using System.Collections;
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using System.Reflection;
using static Tensorflow.Keras.ArgsDefinition.Rnn.RNNArgs;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Util;
using OneOf;
using OneOf.Types;
using Tensorflow.Common.Extensions;
// from tensorflow.python.distribute import distribution_strategy_context as ds_context;

namespace Tensorflow.Keras.Layers.Rnn
@@ -19,11 +25,46 @@ namespace Tensorflow.Keras.Layers.Rnn
protected IVariableV1 kernel;
protected IVariableV1 bias;
protected ILayer cell;

public RNN(RNNArgs args) : base(PreConstruct(args))
{
this.args = args;
SupportsMasking = true;

// If a list of cells was passed, wrap it in a StackedRNNCells.
if (args.Cell.IsT0)
{
cell = new StackedRNNCells(new StackedRNNCellsArgs
{
Cells = args.Cell.AsT0,
});
}
else
{
cell = args.Cell.AsT1;
}





Type type = cell.GetType();
MethodInfo methodInfo = type.GetMethod("Call");
if (methodInfo == null)
{
throw new ValueError(@"Argument `cell` or `cells`should have a `call` method. ");
}

PropertyInfo propertyInfo = type.GetProperty("state_size");
if (propertyInfo == null)
{
throw new ValueError(@"The RNN cell should have a `state_size` attribute");
}



// get input_shape
this.args = PreConstruct(args);
// The input shape is unknown yet, it could have nested tensor inputs, and
// the input spec will be the list of specs for nested inputs, the structure
// of the input_spec will be the same as the input.
@@ -37,17 +78,384 @@ namespace Tensorflow.Keras.Layers.Rnn
//}
}

// States is a tuple consisting of each cell's state_size, e.g. (cell1.state_size, cell2.state_size, ...).
// A state_size can be a single integer, a list/tuple of integers, a TensorShape, or a list/tuple of TensorShapes.
public object States
{
get
{
if (_states == null)
{
var state = nest.map_structure(x => null, cell.state_size);
return nest.is_nested(state) ? state : new Tensors { state };
}
return _states;
}
set { _states = value; }
}

private OneOf<Shape, List<Shape>> compute_output_shape(Shape input_shape)
{
var batch = input_shape[0];
var time_step = input_shape[1];
if (args.TimeMajor)
{
(batch, time_step) = (time_step, batch);
}

// state_size is an array of ints or a positive integer
var state_size = cell.state_size;


// TODO(wanglongzhi2001): what type should flat_output_size be, Shape or Tensor?
Func<Shape, Shape> _get_output_shape;
_get_output_shape = (flat_output_size) =>
{
var output_dim = flat_output_size.as_int_list();
Shape output_shape;
if (args.ReturnSequences)
{
if (args.TimeMajor)
{
output_shape = new Shape(new int[] { (int)time_step, (int)batch }.concat(output_dim));
}
else
{
output_shape = new Shape(new int[] { (int)batch, (int)time_step }.concat(output_dim));

}
}
else
{
output_shape = new Shape(new int[] { (int)batch }.concat(output_dim));
}
return output_shape;
};

Shape output_shape;
if (cell.output_size != 0)
{
output_shape = nest.map_structure(_get_output_shape, cell.output_size);
// TODO(wanglongzhi2001): should output_shape simply be a tuple, or a Shape?
output_shape = (output_shape.Length == 1 ? (int)output_shape[0] : output_shape);
}
else
{
output_shape = _get_output_shape(state_size[0]);
}

if (args.ReturnState)
{
Func<Shape, Shape> _get_state_shape;
_get_state_shape = (flat_state) =>
{
var state_shape = new int[] { (int)batch }.concat(flat_state.as_int_list());
return new Shape(state_shape);
};
var state_shape = _get_state_shape(new Shape(state_size.ToArray()));

return new List<Shape> { output_shape, state_shape };
}
else
{
return output_shape;
}
}

private Tensors compute_mask(Tensors inputs, Tensors mask)
{
// Time step masks must be the same for each input.
// This is because the mask for an RNN is of size [batch, time_steps, 1],
// and specifies which time steps should be skipped, and a time step
// must be skipped for all inputs.

mask = nest.flatten(mask)[0];
var output_mask = args.ReturnSequences ? mask : null;
if (args.ReturnState)
{
var state_mask = new List<Tensor>();
for (int i = 0; i < len(States); i++)
{
state_mask.Add(null);
}
return new List<Tensor> { output_mask }.concat(state_mask);
}
else
{
return output_mask;
}


}

public override void build(KerasShapesWrapper input_shape)
{
object get_input_spec(Shape shape)
{
var input_spec_shape = shape.as_int_list();

var (batch_index, time_step_index) = args.TimeMajor ? (1, 0) : (0, 1);
if (!args.Stateful)
{
input_spec_shape[batch_index] = -1;
}
input_spec_shape[time_step_index] = -1;
return new InputSpec(shape: input_spec_shape);
}

Shape get_step_input_shape(Shape shape)
{

// return shape[1:] if self.time_major else (shape[0],) + shape[2:]
if (args.TimeMajor)
{
return shape.as_int_list().ToList().GetRange(1, shape.Length - 1).ToArray();
}
else
{
return new int[] { shape.as_int_list()[0] }.concat(shape.as_int_list().ToList().GetRange(2, shape.Length - 2).ToArray());
}


}

object get_state_spec(Shape shape)
{
var state_spec_shape = shape.as_int_list();
// prepend the batch dim
state_spec_shape = new int[] { -1 }.concat(state_spec_shape);
return new InputSpec(shape: state_spec_shape);

}

// Check whether the input shape contains any nested shapes. It could be
// (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from
// numpy inputs.


if (!cell.Built)
{
cell.build(input_shape);
}
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
// inputs: Tensors
// mask: Binary tensor of shape [batch_size, timesteps] indicating whether a given timestep should be masked
// training: bool
// initial_state: List of initial state tensors to be passed to the first call of the cell
// constants: List of constant tensors to be passed to the cell at each timestep
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
return base.Call(inputs, state, training);
//var (inputs_padded, row_length) = BackendImpl.convert_inputs_if_ragged(inputs);
//bool is_ragged_input = row_length != null;
//_validate_args_if_ragged(is_ragged_input, mask);
var (inputs_processed, initial_state_processed, constants_processed) = _process_inputs(inputs, initial_state, constants);

_maybe_reset_cell_dropout_mask(cell);
if (cell is StackedRNNCells stackedCells)
{
foreach (var innerCell in stackedCells.Cells)
{
_maybe_reset_cell_dropout_mask(innerCell);
}
}

if (mask != null)
{
// Time step masks must be the same for each input.
//mask = nest.flatten(mask)[0];
mask = mask[0];
}


Shape input_shape;
if (nest.is_nested(inputs))
{
// In the case of nested input, use the first element for shape check
// input_shape = nest.flatten(inputs)[0].shape;
input_shape = inputs[0].shape;
}
else
{
input_shape = inputs.shape;
}

var timesteps = args.TimeMajor ? input_shape[0] : input_shape[1];

if (args.Unroll && timesteps != null)
{
throw new ValueError(
"Cannot unroll a RNN if the " +
"time dimension is undefined. \n" +
"- If using a Sequential model, " +
"specify the time dimension by passing " +
"an `input_shape` or `batch_input_shape` " +
"argument to your first layer. If your " +
"first layer is an Embedding, you can " +
"also use the `input_length` argument.\n" +
"- If using the functional API, specify " +
"the time dimension by passing a `shape` " +
"or `batch_shape` argument to your Input layer."
);
}

// cell_call_fn = (self.cell.__call__ if callable(self.cell) else self.cell.call)
var cell_call_fn = cell.Call;
Func<Tensors, Tensors, (Tensors, Tensors)> step;
if (constants != null)
{
ParameterInfo[] parameters = cell_call_fn.GetMethodInfo().GetParameters();
bool hasParam = parameters.Any(p => p.Name == "constants");
if (!hasParam)
{
throw new ValueError(
$"RNN cell {cell} does not support constants." +
$"Received: constants={constants}");
}

// Lambda parameters renamed so they don't shadow the enclosing method's locals.
step = (stepInputs, stepStates) =>
{
// constants = states[-self._num_constants :]
constants = stepStates.numpy()[new Slice(stepStates.Length - _num_constants, stepStates.Length)];
// states = states[: -self._num_constants]
stepStates = stepStates.numpy()[new Slice(0, stepStates.Length - _num_constants)];
// states = (states[0] if len(states) == 1 and is_tf_rnn_cell else states)
stepStates = stepStates.Length == 1 ? stepStates[0] : stepStates;
var (output, new_states) = cell_call_fn(stepInputs, null, null, stepStates, constants);
if (!nest.is_nested(new_states))
{
return (output, new Tensors { new_states });
}
return (output, new_states);
};
}
else
{
step = (stepInputs, stepStates) =>
{
// states = (states[0] if len(states) == 1 and is_tf_rnn_cell else states)
stepStates = stepStates.Length == 1 ? stepStates[0] : stepStates;
var (output, new_states) = cell_call_fn(stepInputs, null, null, stepStates, constants);
if (!nest.is_nested(new_states))
{
return (output, new Tensors { new_states });
}
return (output, new_states);
};
}

var (last_output, outputs, states) = BackendImpl.rnn(step,
inputs_processed,
initial_state_processed,
constants: constants_processed,
go_backwards: args.GoBackwards,
mask: mask,
unroll: args.Unroll,
input_length: new Tensor(timesteps), // TODO: pass row_length here once ragged inputs are supported
time_major: args.TimeMajor,
zero_output_for_mask: args.ZeroOutputForMask,
return_all_outputs: args.ReturnSequences);

if (args.Stateful)
{
throw new NotImplementedException("this argument havn't been developed!");
}

Tensors output = new Tensors();
if (args.ReturnSequences)
{
throw new NotImplementedException("this argument havn't been developed!");

}
else
{
output = last_output;
}

if (args.ReturnState)
{

foreach (var state in states)
{
output.Add(state);
}
return output;
}
else
{
return output;
}
}

private (Tensors, Tensors, Tensors) _process_inputs(Tensor inputs, Tensors initial_state, Tensors constants)
{
bool IsSequence(object obj)
{
// Check if the object is an IEnumerable
if (obj is IEnumerable)
{
// If it is, check if it is a tuple
if (!(obj is Tuple))
{
return true;
}
}
// If it is not, return false
return false;
}

if (IsSequence(inputs))
{
if (_num_constants != 0)
{
// When constants are appended to `inputs`, stop before them.
initial_state = inputs[new Slice(1, len(inputs) - _num_constants)];
}
else
{
initial_state = inputs[new Slice(1, len(inputs))];
}
if (len(initial_state) == 0)
initial_state = null;
inputs = inputs[0];
}

if (args.Stateful)
{
throw new NotImplementedException("argument stateful has not been implemented!");

}

return (inputs, initial_state, constants);

}

private void _validate_args_if_ragged(bool is_ragged_input, Tensors mask)
{
if (is_ragged_input)
{
if (args.Unroll)
{
throw new ValueError("The input received contains RaggedTensors and does " +
"not support unrolling. Disable unrolling by passing " +
"`unroll=False` in the RNN Layer constructor.");
}
if (mask != null)
{
throw new ValueError($"The mask that was passed in was {mask}, which " +
"cannot be applied to RaggedTensor inputs. Please " +
"make sure that there is no mask injected by upstream " +
"layers.");
}
}
}

void _maybe_reset_cell_dropout_mask(ILayer cell)
{
//if (cell is DropoutRNNCellMixin)
//{
// cell.reset_dropout_mask();
// cell.reset_recurrent_dropout_mask();
//}
}

private static RNNArgs PreConstruct(RNNArgs args)
@@ -77,6 +485,10 @@ namespace Tensorflow.Keras.Layers.Rnn
return args;
}

public Tensors __call__(Tensors inputs, Tensor state = null, Tensor training = null)
{
throw new NotImplementedException();
}
public RNN New(LayerRnnCell cell,
bool return_sequences = false,
bool return_state = false,
@@ -95,7 +507,7 @@ namespace Tensorflow.Keras.Layers.Rnn
TimeMajor = time_major
});

public RNN New(IList<RnnCell> cell,
public RNN New(IList<IRnnArgCell> cell,
bool return_sequences = false,
bool return_state = false,
bool go_backwards = false,
@@ -125,7 +537,7 @@ namespace Tensorflow.Keras.Layers.Rnn
}

// Check whether the state_size contains multiple states.
public static bool _is_multiple_state(object state_size)
public static bool is_multiple_state(object state_size)
{
var myIndexerProperty = state_size.GetType().GetProperty("Item");
return myIndexerProperty != null
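
The renamed is_multiple_state helper decides multiplicity by reflecting for an indexer (an "Item" property), so roughly (illustrative):

RNN.is_multiple_state(new StateSizeWrapper(64, 64)); // true: the wrapper declares this[int]
RNN.is_multiple_state(64); // false: a boxed int has no indexer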


+2 -2 src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs

@@ -42,9 +42,9 @@ namespace Tensorflow.Keras.Layers.Rnn
built = true;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
return base.Call(inputs, state, training);
return base.Call(inputs, training: training, initial_state: initial_state);
}
}
}

+7 -5 src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs

@@ -2,15 +2,16 @@
using System.Collections.Generic;
using System.ComponentModel;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using static Tensorflow.Keras.ArgsDefinition.Rnn.RNNArgs;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.ArgsDefinition.Rnn;

namespace Tensorflow.Keras.Layers.Rnn
{
public class StackedRNNCells : Layer, RNNArgs.IRnnArgCell
public class StackedRNNCells : Layer
{
public IList<RnnCell> Cells { get; set; }
public IList<IRnnArgCell> Cells { get; set; }
public bool reverse_state_order;

public StackedRNNCells(StackedRNNCellsArgs args) : base(args)
@@ -51,7 +52,7 @@ namespace Tensorflow.Keras.Layers.Rnn
{
return lastCell.output_size;
}
else if (RNN._is_multiple_state(lastCell.state_size))
else if (RNN.is_multiple_state(lastCell.state_size))
{
// return ((dynamic)Cells[-1].state_size)[0];
throw new NotImplementedException("");
@@ -63,6 +64,7 @@ namespace Tensorflow.Keras.Layers.Rnn
}
}


public object get_initial_state()
{
throw new NotImplementedException();
@@ -80,7 +82,7 @@ namespace Tensorflow.Keras.Layers.Rnn
// return tuple(initial_states)
}

public object call()
public Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
throw new NotImplementedException();
// def call(self, inputs, states, constants= None, training= None, ** kwargs):


+1 -1 src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs

@@ -34,7 +34,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null)
{
if (tf.Context.executing_eagerly())
return DeFunCall(inputs);

