From 46b86e845f82318f2a2684c80d485f663e7c7b72 Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Sun, 21 May 2023 01:37:08 +0800 Subject: [PATCH 1/4] Draft PR for RNN --- .../APIs/tf.control_flow.cs | 15 + .../Keras/ArgsDefinition/Rnn/RNNArgs.cs | 12 +- .../ArgsDefinition/Rnn/StackedRNNCellsArgs.cs | 3 +- .../NumPy/StateSizeWrapper.cs | 63 ++ .../Operations/NnOps/RNNCell.cs | 3 +- .../Operations/control_flow_ops.cs | 47 ++ src/TensorFlowNET.Core/Util/nest.py.cs | 44 ++ src/TensorFlowNET.Keras/BackendImpl.cs | 540 +++++++++++++++++- src/TensorFlowNET.Keras/Engine/Layer.cs | 4 +- .../Layers/Activation/ELU.cs | 2 +- .../Layers/Activation/Exponential.cs | 2 +- .../Layers/Activation/HardSigmoid.cs | 3 +- .../Layers/Activation/LeakyReLu.cs | 2 +- .../Layers/Activation/SELU.cs | 3 +- .../Layers/Activation/Softmax.cs | 3 +- .../Layers/Activation/Softplus.cs | 3 +- .../Layers/Activation/Softsign.cs | 3 +- .../Layers/Activation/Swish.cs | 3 +- .../Layers/Activation/Tanh.cs | 2 +- .../Layers/Attention/BaseDenseAttention.cs | 2 +- .../Layers/Attention/MultiHeadAttention.cs | 2 +- .../Layers/Convolution/Convolutional.cs | 2 +- src/TensorFlowNET.Keras/Layers/Core/Dense.cs | 2 +- .../Layers/Core/EinsumDense.cs | 2 +- .../Layers/Core/Embedding.cs | 2 +- .../Layers/Merging/Merge.cs | 2 +- .../Normalization/BatchNormalization.cs | 2 +- .../Normalization/LayerNormalization.cs | 2 +- .../Layers/Normalization/Normalization.cs | 2 +- .../Layers/Pooling/GlobalAveragePooling1D.cs | 2 +- .../Layers/Pooling/GlobalAveragePooling2D.cs | 2 +- .../Layers/Pooling/GlobalMaxPooling1D.cs | 2 +- .../Layers/Pooling/GlobalMaxPooling2D.cs | 2 +- .../Layers/Pooling/Pooling1D.cs | 2 +- .../Layers/Pooling/Pooling2D.cs | 2 +- .../Layers/Preprocessing/CategoryEncoding.cs | 2 +- .../Layers/Preprocessing/Rescaling.cs | 2 +- .../Layers/Preprocessing/Resizing.cs | 2 +- .../Layers/Regularization/Dropout.cs | 2 +- .../Layers/Reshaping/Cropping1D.cs | 2 +- .../Layers/Reshaping/Cropping2D.cs | 
2 +- .../Layers/Reshaping/Cropping3D.cs | 2 +- .../Layers/Reshaping/Flatten.cs | 2 +- .../Layers/Reshaping/Permute.cs | 2 +- .../Layers/Reshaping/Reshape.cs | 2 +- .../Layers/Reshaping/UpSampling2D.cs | 2 +- .../Layers/Reshaping/ZeroPadding2D.cs | 2 +- src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs | 4 +- src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs | 422 +++++++++++++- .../Layers/Rnn/SimpleRNNCell.cs | 4 +- .../Layers/Rnn/StackedRNNCells.cs | 12 +- .../Layers/TensorFlowOpLayer.cs | 2 +- 52 files changed, 1187 insertions(+), 70 deletions(-) create mode 100644 src/TensorFlowNET.Core/NumPy/StateSizeWrapper.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.control_flow.cs b/src/TensorFlowNET.Core/APIs/tf.control_flow.cs index 239487e0..578f23f9 100644 --- a/src/TensorFlowNET.Core/APIs/tf.control_flow.cs +++ b/src/TensorFlowNET.Core/APIs/tf.control_flow.cs @@ -57,6 +57,21 @@ namespace Tensorflow new[] { loop_vars }); return results[0]; } + public (Tensor, List, Tensors, Tensors) while_loop(Func cond, + Func, Tensors, Tensors, (Tensor, List, Tensors, Tensors)> body, + (Tensor, List, Tensors, Tensors) loop_vars, + int parallel_iterations = 10) + => control_flow_ops.while_loop(cond, + body, + loop_vars); + + public (Tensor, List, Tensors) while_loop(Func cond, + Func, Tensors, (Tensor, List, Tensors)> body, + (Tensor, List, Tensors) loop_vars, + int parallel_iterations = 10) + => control_flow_ops.while_loop(cond, + body, + loop_vars); public Tensor[] while_loop(Func cond, Func body, diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs index 2585592c..911c6721 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs @@ -1,5 +1,9 @@ using Newtonsoft.Json; +using OneOf; using System.Collections.Generic; +using Tensorflow.Keras.Layers; +using Tensorflow.Keras.ArgsDefinition.Rnn; +using Tensorflow.NumPy; namespace 
Tensorflow.Keras.ArgsDefinition.Rnn { @@ -7,11 +11,14 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn { public interface IRnnArgCell : ILayer { - object state_size { get; } + public Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null); + public StateSizeWrapper state_size { get; set; } + public int output_size { get; set; } } [JsonProperty("cell")] // TODO: the cell should be serialized with `serialize_keras_object`. - public IRnnArgCell Cell { get; set; } = null; + public OneOf, IRnnArgCell> Cell { get; set; } + [JsonProperty("return_sequences")] public bool ReturnSequences { get; set; } = false; [JsonProperty("return_state")] @@ -25,6 +32,7 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn [JsonProperty("time_major")] public bool TimeMajor { get; set; } = false; // TODO: Add `num_constants` and `zero_output_for_mask`. + public bool ZeroOutputForMask { get; set; } = false; public Dictionary Kwargs { get; set; } = null; public int Units { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs index fdfadab8..dee7e8d3 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs @@ -1,10 +1,11 @@ using System.Collections.Generic; +using static Tensorflow.Keras.ArgsDefinition.Rnn.RNNArgs; namespace Tensorflow.Keras.ArgsDefinition.Rnn { public class StackedRNNCellsArgs : LayerArgs { - public IList Cells { get; set; } + public IList Cells { get; set; } public Dictionary Kwargs { get; set; } = null; } } diff --git a/src/TensorFlowNET.Core/NumPy/StateSizeWrapper.cs b/src/TensorFlowNET.Core/NumPy/StateSizeWrapper.cs new file mode 100644 index 00000000..f2a9e5f2 --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/StateSizeWrapper.cs @@ -0,0 +1,63 @@ +using System; +using 
System.Collections.Generic; +using System.Text; +using System.Collections; + + +namespace Tensorflow.NumPy +{ + // Since state_size in RNN is a single integer or array of integer, so use StateSizeWrapper to hold it + public class StateSizeWrapper : IEnumerable + { + int[] _state_size; + public int[] state_size => _state_size; + + public StateSizeWrapper(int state_size) + { + _state_size = new int[] { state_size }; + } + + public StateSizeWrapper(params int[] state_size) + { + _state_size = state_size; + } + public StateSizeWrapper(IEnumerable state_size) + { + _state_size = state_size.ToArray(); + } + + public static implicit operator StateSizeWrapper(int[] state_size) + => new StateSizeWrapper(state_size); + + public static implicit operator StateSizeWrapper(int state_size) + => new StateSizeWrapper(state_size); + + public static implicit operator StateSizeWrapper((int, int) state_size) + => new StateSizeWrapper(state_size.Item1, state_size.Item2); + + public static implicit operator StateSizeWrapper(List v) + => new StateSizeWrapper(v); + public override string ToString() + { + return $"{state_size}"; + } + + public int this[int n] + { + get => n < 0 ? 
state_size[state_size.Length + n] : state_size[n]; + set => state_size[n] = value; + } + + public IEnumerator GetEnumerator() + { + return state_size.ToList().GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + } +} + + diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index ecc9ca11..d49c8218 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -26,6 +26,7 @@ using Tensorflow.Operations; using Tensorflow.Train; using Tensorflow.Util; using static Tensorflow.Binding; +using static Tensorflow.Keras.ArgsDefinition.Rnn.RNNArgs; namespace Tensorflow { @@ -50,7 +51,7 @@ namespace Tensorflow /// matching structure of Tensors having shape `[batch_size].concatenate(s)` /// for each `s` in `self.batch_size`. /// - public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell + public abstract class RnnCell : ILayer { /// /// Attribute that indicates whether the cell is a TF RNN cell, due the slight diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs index 862b636f..c59e5b16 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs @@ -698,6 +698,53 @@ namespace Tensorflow }); } + public static (Tensor, List, Tensors, Tensors) while_loop(Func cond, + Func, Tensors, Tensors, (Tensor, List, Tensors, Tensors)> body, + (Tensor, List, Tensors, Tensors) loop_vars, + int parallel_iterations = 10, + string name = null) + { + var executing_eagerly = tf.Context.executing_eagerly(); + if (!executing_eagerly) + { + throw new NotImplementedException(""); + } + + return tf_with(ops.name_scope("name", "while"), delegate + { + while ((bool)cond(loop_vars.Item1)) + { + loop_vars = body(loop_vars.Item1, loop_vars.Item2, loop_vars.Item3, loop_vars.Item4); + } + + return 
loop_vars; + }); + } + + public static (Tensor, List, Tensors) while_loop(Func cond, + Func, Tensors, (Tensor, List, Tensors)> body, + (Tensor, List, Tensors) loop_vars, + int parallel_iterations = 10, + string name = null) + { + var executing_eagerly = tf.Context.executing_eagerly(); + if (!executing_eagerly) + { + throw new NotImplementedException(""); + } + + return tf_with(ops.name_scope("name", "while"), delegate + { + while ((bool)cond(loop_vars.Item1)) + { + loop_vars = body(loop_vars.Item1, loop_vars.Item2, loop_vars.Item3); + } + + return loop_vars; + }); + } + + /// /// Repeat `body` while the condition `cond` is true. /// diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs index eb94f4d0..2879fa8e 100644 --- a/src/TensorFlowNET.Core/Util/nest.py.cs +++ b/src/TensorFlowNET.Core/Util/nest.py.cs @@ -211,6 +211,28 @@ namespace Tensorflow.Util => arg is IEnumerable && !(arg is string) && !(arg is NDArray) && !(arg.GetType().IsGenericType && arg.GetType().GetGenericTypeDefinition() == typeof(HashSet<>)); + public static bool is_nested(object obj) + { + // Check if the object is an IEnumerable + if (obj is IEnumerable) + { + // If it is, check if it is a nested structure + foreach (object item in (IEnumerable)obj) + { + if (is_nested(item)) + { + return true; + } + } + return true; + } + else + { + // If it is not, return false + return false; + } + } + public static bool is_mapping(object arg) => arg is IDictionary; //# See the swig file (util.i) for documentation. 
@@ -263,7 +285,29 @@ namespace Tensorflow.Util } } + public static List FlattenTupple(object tuple) + { + List items = new List(); + var type = tuple.GetType(); + + if (type.GetInterface("ITuple") == null) + throw new ArgumentException("This is not a tuple!"); + foreach (var property in type.GetProperties()) + { + var value = property.GetValue(tuple); + if (property.PropertyType.GetInterface("ITuple") != null) + { + var subItems = FlattenTupple(value); + items.AddRange(subItems); + } + else + { + items.Add((T)value); + } + } + return items; + } //# See the swig file (util.i) for documentation. //_same_namedtuples = _pywrap_tensorflow.SameNamedtuples diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs index 80403ad6..da1d25c9 100644 --- a/src/TensorFlowNET.Keras/BackendImpl.cs +++ b/src/TensorFlowNET.Keras/BackendImpl.cs @@ -22,6 +22,9 @@ using Tensorflow.Functions; using Tensorflow.Graphs; using static Tensorflow.Binding; using static Tensorflow.Graphs.SubGraphUtility; +using Tensorflow.Util; +using Tensorflow.Operations; +using OneOf; namespace Tensorflow.Keras { @@ -65,7 +68,7 @@ namespace Tensorflow.Keras return; } var graph = v.Graph; - if(graph is null) + if (graph is null) { graph = get_graph(); } @@ -95,7 +98,7 @@ namespace Tensorflow.Keras { if (_GRAPH == null) _GRAPH = new FuncGraph("keras_graph"); - + return _GRAPH; } return ops.get_default_graph(); @@ -105,7 +108,7 @@ namespace Tensorflow.Keras { if (_CURRENT_SCRATCH_GRAPH == null) _CURRENT_SCRATCH_GRAPH = new FuncGraph("keras_scratch_graph"); - + return _CURRENT_SCRATCH_GRAPH; } @@ -230,16 +233,16 @@ namespace Tensorflow.Keras { if (outputs[0].op.type == "Const") return tensor_util.constant_value(outputs); - + var source_graph = outputs.graph; var exec_graph = _scratch_graph(); var global_graph = get_graph(); if (source_graph == global_graph && exec_graph != global_graph) { - var lifted_map = lift_to_graph(outputs, exec_graph, - new List(), - add_sources: true, - 
handle_captures: true, + var lifted_map = lift_to_graph(outputs, exec_graph, + new List(), + add_sources: true, + handle_captures: true, base_graph: source_graph); } if (outputs[0].op.type == "Placeholder" @@ -250,7 +253,7 @@ namespace Tensorflow.Keras exec_graph.as_default(); exec_graph.Inputs = exec_graph.internal_captures; exec_graph.Outputs = outputs; - + var graph_fn = new ConcreteFunction(exec_graph); _CURRENT_SCRATCH_GRAPH = null; @@ -370,7 +373,7 @@ namespace Tensorflow.Keras /// /// /// - public Tensor resize_images(Tensor x, int height_factor, int width_factor, + public Tensor resize_images(Tensor x, int height_factor, int width_factor, string data_format, string interpolation = "nearest") { var (rows, cols) = (0, 0); @@ -412,7 +415,7 @@ namespace Tensorflow.Keras /// public Tensor concatenate(Tensors tensors, int axis = -1) { - if(axis < 0) + if (axis < 0) { var rank = tensors[0].ndim; if (rank > -1) @@ -450,5 +453,520 @@ namespace Tensorflow.Keras return x; } + + public static (Tensors, Tensors) convert_inputs_if_ragged(OneOf inputs) + { + throw new NotImplementedException(); + } + + // + public static (Tensors, Tensors, Tensors) rnn( + Func step_function, // args:inputs, states, return:output, new_states + Tensors inputs, // inputs is a tuple of tensors (one per input sequence) + Tensors initial_states, + bool go_backwards = false, + Tensor? mask = null, + Tensors? constants = null, + bool unroll = false, + Tensors? 
input_length = null, // An integer or a 1-D Tensor,depending on whether the time dimension is fixed-length or not + bool time_major = false, + bool zero_output_for_mask = false, + bool return_all_outputs = true) + { + + Tensors swap_batch_timestep(Tensors input_t) + { + var axes = Enumerable.Range(0, input_t.rank).ToArray(); + axes[0] = 1; + axes[1] = 0; + return tf.transpose(input_t, axes); + } + + if (!time_major) + { + inputs = nest.map_structure(swap_batch_timestep, inputs); + } + + var flatted_inptus = nest.flatten(inputs); + var time_steps = flatted_inptus[0].shape[0]; + var batch = flatted_inptus[0].shape[1]; + var time_step_t = tf.shape(flatted_inptus[0])[0]; + + foreach (var input_ in flatted_inptus) + { + input_.shape.with_rank_at_least(3); + } + + if (mask != null) + { + if (mask.dtype != TF_DataType.TF_BOOL) + { + mask = tf.cast(mask, TF_DataType.TF_BOOL); + } + + if (mask.rank == 2) + { + mask = tf.expand_dims(mask, -1); + } + + if (!time_major) + { + mask = swap_batch_timestep(mask); + } + + } + + if (constants == null) + { + constants = new List(); + } + + // tf.where needs its condition tensor to be the same shape as its two + // result tensors, but in our case the condition (mask) tensor is + // (nsamples, 1), and inputs are (nsamples, ndimensions) or even more. + // So we need to broadcast the mask to match the shape of inputs. + // That's what the tile call does, it just repeats the mask along its + // second dimension n times. 
+ + Tensors _expand_mask(Tensors mask_t, Tensors input_t, int fixed_dim = 1) + { + if (nest.is_nested(mask_t)) + { + throw new ValueError($"mask_t is expected to be tensor, but got {mask_t}"); + } + + if (nest.is_nested(input_t)) + { + throw new ValueError($"input_t is expected to be tensor, but got {input_t}"); + } + + var rank_diff = input_t.rank - mask_t.rank; + for (int i = 0; i < rank_diff; i++) + { + mask_t = tf.expand_dims(mask_t, -1); + } + var multiples = Enumerable.Repeat(1, fixed_dim).ToArray().concat(input_t.shape.as_int_list().ToList().GetRange(fixed_dim, input_t.rank)); + return tf.tile(mask_t, multiples); + } + + Tensors outputs = new Tensors(); + Tensors output_time_zero = new Tensors(); + Tensors last_output = new Tensors(); + Tensors new_states = new Tensors(); + if (unroll) + { + if (time_steps == 0) + { + throw new ValueError("Unrolling requires a fixed number of timesteps."); + } + + // Process the input tensors. The input tensor need to be split on the + // time_step dim, and reverse if go_backwards is True. In the case of + // nested input, the input is flattened and then transformed + // individually. The result of this will be a tuple of lists, each of + // the item in tuple is list of the tensor with shape (batch, feature) + + + // TODO(Wanglongzhi2001),step_func接受的第二个参数为List,但是最后却用的tuple + //var states = Tuple.Create(initial_states); + var states = initial_states; + + var successive_states = new Tensors(); + var successive_outputs = new Tensors(); + + // Process the input tensors. The input tensor need to be split on the + // time_step dim, and reverse if go_backwards is True. In the case of + // nested input, the input is flattened and then transformed + // individually. 
The result of this will be a tuple of lists, each of + // the item in tuple is list of the tensor with shape (batch, feature) + + + + + Tensors _process_single_input_t(Tensors input_t) + { + input_t = tf.unstack(input_t); // unstack for time_step dim + if (go_backwards) + { + input_t.Reverse(); + } + return input_t; + } + + // TODO(Wanglongzhi2001) + Tensors processed_input; + if (nest.is_nested(inputs)) + { + processed_input = nest.map_structure(_process_single_input_t, inputs); + } + else + { + processed_input = _process_single_input_t(inputs); + } + + object _get_input_tensor(int time) + { + List inp = new List(); + foreach (var t_ in processed_input) + { + inp.Add(t_[time]); + } + return nest.pack_sequence_as(inputs, inp); + } + + if (mask != null) + { + var mask_list = tf.unstack(mask); + if (go_backwards) + { + mask_list.Reverse(); + } + + for (int i = 0; i < time_steps; i++) + { + // TODO(Wanglongzhi2001),deal with _get_input_tensor + var inp = _get_input_tensor(i); + var mask_t = mask_list[i]; + // TODO + var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants }); + + var tiled_mask_t = _expand_mask(mask_t, output); + + Tensors prev_output; + if (successive_outputs == null) + { + prev_output = tf.zeros_like(output); + } + else + { + prev_output = successive_outputs[successive_outputs.Length - 1]; + } + + output = tf.where(tiled_mask_t, output, prev_output); + + //var flat_states = nest.flatten(states); + //var flat_new_states = nest.flatten(newStates); + var flat_states = states.ToList(); + var flat_new_states = newStates.ToList(); + + var tiledMaskT = flat_states + .Select(s => _expand_mask(mask_t, s)) + .ToArray(); + var tuple = Tuple.Create(tiledMaskT); + + List flat_final_states = new List(); + foreach (var (m, s, ps) in Enumerable.Zip(tiled_mask_t, flat_new_states, flat_states)) + { + flat_final_states.Add(tf.where(m, s, ps)); + } + + states = (Tensors)nest.pack_sequence_as(states, flat_final_states); + if 
(return_all_outputs) + { + successive_outputs.Add(output); + successive_states.Add(states); + } + else + { + successive_outputs = new Tensors { output }; + successive_states = new Tensors { states }; + } + + } + last_output = successive_outputs[successive_outputs.Length - 1]; + new_states = successive_states[successive_states.Length - 1]; + outputs = tf.stack(successive_outputs); + + if (zero_output_for_mask) + { + last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output)); + outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs)); + } + else // mask is null + { + for (int i = 0; i < time_steps; i++) + { + var inp = _get_input_tensor(i); + var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants }); + states = newStates; + + if (return_all_outputs) + { + successive_outputs.Add(output); + successive_states.Add(newStates); + } + else + { + successive_outputs = new Tensors { output }; + successive_states = new Tensors { newStates }; + } + } + last_output = successive_outputs[successive_outputs.Length - 1]; + new_states = successive_states[successive_states.Length - 1]; + outputs = tf.stack(successive_outputs); + } + } + } + else // unroll == false + { + var states = initial_states; + // Create input tensor array, if the inputs is nested tensors, then it + // will be flattened first, and tensor array will be created one per + // flattened tensor. + var input_ta = new List(); + for (int i = 0; i < flatted_inptus.Count; i++) + { + input_ta.Add(tf.TensorArray(dtype: flatted_inptus[i].dtype, size: time_step_t)); + } + + // Get the time(0) input and compute the output for that, the output will + // be used to determine the dtype of output tensor array. Don't read from + // input_ta due to TensorArray clear_after_read default to True. 
+ var inps = new Tensors(); + foreach (var inp in flatted_inptus) + { + inps.Add(inp[0]); + } + var input_time_zero = nest.pack_sequence_as(inputs, inps); + + // output_time_zero is used to determine the cell output shape and its + // dtype. the value is discarded. + (output_time_zero, _) = step_function((Tensor)input_time_zero, new Tensors { initial_states, constants }); + + var output_ta_size = return_all_outputs ? time_step_t : tf.constant(1); + var output_ta = new List(); + for (int i = 0; i < output_time_zero.ToList().Count; i++) + { + var Out = output_time_zero.ToList()[i]; + output_ta.Add(tf.TensorArray(dtype: Out.dtype, size: output_ta_size, element_shape: Out.shape)); + } + + var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time"); + + + + Func? masking_fn; + Func? compute_masked_output = null; + if (mask != null) + { + if (go_backwards) + { + mask = tf.reverse(mask, axis: new[] { 0 }); + } + var mask_ta = tf.TensorArray(dtype: TF_DataType.TF_BOOL, size: time_step_t); + mask_ta = mask_ta.unstack(mask); + + masking_fn = (time) => + { + return mask_ta.read(time); + }; + + compute_masked_output = (mask_t, flat_out, flat_mask) => + { + var tiled_mask_t = new Tensors(); + foreach (var o in flat_out) + { + tiled_mask_t.Add(_expand_mask(mask_t, o, fixed_dim: mask_t.rank)); + } + + Tensors res = new Tensors(); + foreach (var (m, o, fm) in Enumerable.Zip(tiled_mask_t, flat_out, flat_mask)) + { + res.Add(tf.where(m, o, fm)); + } + return res; + }; + } + // TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor)? 
+ else if (input_length is Tensor) + { + if (go_backwards) + { + var max_len = tf.reduce_max(input_length, axis: 0); + var rev_input_length = tf.subtract(max_len - 1, input_length); + + masking_fn = (time) => + { + return tf.less(rev_input_length, time); + }; + } + else + { + masking_fn = (time) => + { + return tf.greater(input_length, time); + }; + } + + compute_masked_output = (mask_t, flat_out, flat_mask) => + { + var res = new List(); + foreach (var (o, zo) in zip(flat_out, flat_mask)) + { + res.Add(tf.where(mask_t, o, zo)); + } + return res; + }; + } + else + { + masking_fn = null; + } + + + if (masking_fn != null) + { + // Mask for the T output will be base on the output of T - 1. In the + // case T = 0, a zero filled tensor will be used. + var flat_zero_output = new Tensors(); + foreach (var o in nest.flatten(output_time_zero)) + { + flat_zero_output.Add(tf.zeros_like(o)); + } + + + (Tensor, List, Tensors, Tensors) _step(Tensor time, List output_ta_t, Tensors prev_output, Tensors states) + { + /* + RNN step function. + Args: + time: Current timestep value. + output_ta_t: TensorArray. + prev_output: tuple of outputs from time - 1. + *states: List of states. + Returns: + Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)` + */ + + var current_input = input_ta.Select(x => x.read(time)).ToList(); + // maybe set shape + // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type + current_input = (List)nest.pack_sequence_as(inputs, current_input); + var mask_t = masking_fn(time); + var (output, new_states) = step_function(current_input, new Tensors { states, constants }); + // mask output + //var flat_output = nest.flatten(output); + var flat_output = output.ToList(); + + var flat_mask_output = zero_output_for_mask ? 
flat_zero_output : prev_output.ToList(); + + // TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type + var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output); + + // mask states + var flat_state = states.ToList(); + var flat_new_state = new_states.ToList(); + + foreach (var (state, new_state) in zip(flat_state, flat_new_state)) + { + if (new_state is Tensor) + { + new_state.set_shape(state.shape); + } + } + + var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state); + new_states = (Tensors)nest.pack_sequence_as(new_states, flat_final_state); + + var ta_index_to_write = return_all_outputs ? time : tf.constant(0); + var Output_ta_t = new List(); + // TODO(Wanglongzhi2001),deal with zip output_ta_t + foreach (var (ta, Out) in zip(output_ta_t, flat_new_output)) + { + Output_ta_t.Add(ta.write(ta_index_to_write, Out)); + } + + + + //new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state); + + + return (time + 1, Output_ta_t, flat_new_output, new_states); + + } + Func cond = (time) => (time < time_step_t); + + var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, flat_zero_output, states)); + new_states = final_outputs.Item4; + output_ta = final_outputs.Item2; + + } + else + { + (Tensor, List, Tensors) _step(Tensor time, List output_ta_t, Tensors states) + { + var current_input = input_ta.Select(x => x.read(time)).ToList(); + // maybe set shape + // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type + current_input = (List)nest.pack_sequence_as(inputs, current_input); + var (output, new_states) = step_function(current_input, new Tensors { states, constants }); + var flat_state = states.ToList(); + var flat_new_state = new_states.ToList(); + foreach (var (state, new_state) in zip(flat_state, flat_new_state)) + { + if (new_state is Tensor) + { + new_state.set_shape(state.shape); + } + } + var flat_output = output.ToList(); + 
var ta_index_to_write = return_all_outputs ? time : tf.constant(0); + var Output_ta_t = new List(); + foreach (var (ta, out_) in zip(output_ta_t, flat_output)) + { + Output_ta_t.Add(ta.write(ta_index_to_write, out_)); + } + + new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state); + return (time + 1, Output_ta_t, new_states); + } + Func cond = (time) => (time < time_step_t); + var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, states)); + new_states = final_outputs.Item3; + output_ta = final_outputs.Item2; + + } + //Tensors outputs = new Tensors(); + foreach (var o in output_ta) + { + outputs.Add(o.stack()); + } + foreach (var o in outputs) + { + last_output.Add(o[-1]); + } + outputs = (Tensors)nest.pack_sequence_as(output_time_zero, outputs); + last_output = (Tensors)nest.pack_sequence_as(output_time_zero, last_output); + + } + + Func set_shape; + set_shape = (output_) => + { + if (output_ is Tensor) + { + var shape = output_.shape.as_int_list(); + if (return_all_outputs) + { + shape[0] = (int)time_steps; + } + else + { + shape[0] = 1; + } + shape[1] = (int)batch; + output_.set_shape(new Tensor(shape)); + } + return output_; + }; + + var Outputs = (Tensors)nest.map_structure(set_shape, outputs); + if (!time_major) + { + Outputs = nest.map_structure(swap_batch_timestep, outputs); + } + return (last_output, Outputs, new_states); + + } } } diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index 5942efd9..4216c725 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -332,9 +332,9 @@ namespace Tensorflow.Keras.Engine /// /// /// - protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected virtual Tensors Call(Tensors inputs, Tensor mask = null, bool? 
training = null, Tensors initial_state = null, Tensors constants = null) { - if(ReplacedCall is not null) + if (ReplacedCall is not null) { return ReplacedCall(inputs); } diff --git a/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs b/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs index 739c0d56..9fb8781e 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs @@ -29,7 +29,7 @@ namespace Tensorflow.Keras.Layers { base.build(input_shape); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { Tensor output = inputs; output = tf.where(output > 0f, output, diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs b/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs index 17636302..2f618f63 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs @@ -17,7 +17,7 @@ namespace Tensorflow.Keras.Layers { { base.build(input_shape); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { Tensor output = inputs; return tf.exp(output); diff --git a/src/TensorFlowNET.Keras/Layers/Activation/HardSigmoid.cs b/src/TensorFlowNET.Keras/Layers/Activation/HardSigmoid.cs index b498d1b9..efea135b 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/HardSigmoid.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/HardSigmoid.cs @@ -10,7 +10,8 @@ namespace Tensorflow.Keras.Layers { public HardSigmoid ( LayerArgs args ) : base(args) { // hard sigmoid has no arguments } - protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? 
training = null ) { + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) + { Tensor x = inputs; return tf.clip_by_value( tf.add(tf.multiply(x, 0.2f), 0.5f), 0f, 1f); diff --git a/src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs b/src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs index 1fbbf4ea..feb98a0b 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs @@ -19,7 +19,7 @@ namespace Tensorflow.Keras.Layers this.args = args; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { return tf.nn.leaky_relu(inputs, alpha: alpha); } diff --git a/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs b/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs index 53101fbb..b444e338 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs @@ -22,7 +22,8 @@ namespace Tensorflow.Keras.Layers { } base.build(input_shape); } - protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? 
training = null, Tensors initial_state = null, Tensors constants = null) + { Tensor output = inputs; return tf.where(output > 0f, tf.multiply(scale, output), diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Softmax.cs b/src/TensorFlowNET.Keras/Layers/Activation/Softmax.cs index 3ffae27f..62d2461e 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/Softmax.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/Softmax.cs @@ -11,7 +11,8 @@ namespace Tensorflow.Keras.Layers { public Softmax ( SoftmaxArgs args ) : base(args) { axis = args.axis; } - protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) + { Tensor x = inputs.Length == 2 ? inputs + ((1.0 - tf.cast(inputs[1], inputs.dtype)) * 1e-9) : inputs; Tensor e = tf.exp(tf.sub(x, tf.reduce_max(x, axis: this.axis, keepdims: true))); diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Softplus.cs b/src/TensorFlowNET.Keras/Layers/Activation/Softplus.cs index e82b0198..13dfad4e 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/Softplus.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/Softplus.cs @@ -10,7 +10,8 @@ namespace Tensorflow.Keras.Layers { public Softplus ( LayerArgs args ) : base(args) { // Softplus has no arguments } - protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? 
training = null, Tensors initial_state = null, Tensors constants = null) + { Tensor x = inputs; return tf.log( tf.add(tf.exp(x), 1f)); diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Softsign.cs b/src/TensorFlowNET.Keras/Layers/Activation/Softsign.cs index 59329fd4..9933db5f 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/Softsign.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/Softsign.cs @@ -10,7 +10,8 @@ namespace Tensorflow.Keras.Layers { public Softsign ( LayerArgs args ) : base(args) { // Softsign has no arguments } - protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) + { Tensor x = inputs; // x / (abs(x) + 1) return tf.div(x, tf.add(1f, tf.abs(x))); diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Swish.cs b/src/TensorFlowNET.Keras/Layers/Activation/Swish.cs index 1dcb92b3..727d385d 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/Swish.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/Swish.cs @@ -10,7 +10,8 @@ namespace Tensorflow.Keras.Layers { public Swish ( LayerArgs args ) : base(args) { // Swish has no arguments } - protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) + { Tensor x = inputs; // x / (1 + exp(-x)) diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Tanh.cs b/src/TensorFlowNET.Keras/Layers/Activation/Tanh.cs index 99b80394..802b894e 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/Tanh.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/Tanh.cs @@ -13,7 +13,7 @@ namespace Tensorflow.Keras.Layers { // Tanh has no arguments } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? 
training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { Tensor x = inputs; diff --git a/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs b/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs index 1348e19c..fe37d860 100644 --- a/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs +++ b/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs @@ -114,7 +114,7 @@ namespace Tensorflow.Keras.Layers return (tf.linalg.einsum("bij,bjk->bik", (weights, value)), weights); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { Tensors _inp; Tensors _mask = null; diff --git a/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs b/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs index 701724d5..f3fee090 100644 --- a/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs +++ b/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs @@ -252,7 +252,7 @@ namespace Tensorflow.Keras.Layers return (attention_output, attention_scores); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? 
training = null, Tensors initial_state = null, Tensors constants = null) { Tensors _inp; Tensor _mask = null; diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs index c575362c..cf0c6d2b 100644 --- a/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs +++ b/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs @@ -103,7 +103,7 @@ namespace Tensorflow.Keras.Layers _buildInputShape = input_shape; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = false) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { var outputs = _convolution_op.Apply(inputs, kernel.AsTensor()); if (use_bias) diff --git a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs index aa6617dd..f574fd53 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs @@ -69,7 +69,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { Tensor outputs = null; var rank = inputs.rank; diff --git a/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs b/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs index fb604f77..9aacf8f1 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs @@ -189,7 +189,7 @@ namespace Tensorflow.Keras.Layers // return new dict(base_config.items().ToList() + config.items().ToList()); //} - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? 
training = null, Tensors initial_state = null, Tensors constants = null) { var ret = tf.linalg.einsum(this.equation, (inputs, this.kernel.AsTensor())); if (this.bias != null) diff --git a/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs b/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs index 9487a7d0..6e074978 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs @@ -66,7 +66,7 @@ namespace Tensorflow.Keras.Layers _buildInputShape = input_shape; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { var dtype = inputs.dtype; if (dtype != tf.int32 && dtype != tf.int64) diff --git a/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs b/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs index 7df654ee..2d7a1e7d 100644 --- a/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs +++ b/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs @@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Layers _buildInputShape = input_shape; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { return _merge_function(inputs); } diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs index d02d2509..2af14cc7 100644 --- a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs +++ b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs @@ -146,7 +146,7 @@ namespace Tensorflow.Keras.Layers return false; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? 
training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { Tensor outputs = null; var training_tensor = training == null diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs index e90c0402..e708d6a8 100644 --- a/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs +++ b/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs @@ -101,7 +101,7 @@ namespace Tensorflow.Keras.Layers return input_shape; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { Tensor outputs = null; var inputs_dtype = inputs.dtype.as_base_dtype(); diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs index a65154bf..978d1029 100644 --- a/src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs +++ b/src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs @@ -157,7 +157,7 @@ namespace Tensorflow.Keras.Layers base.adapt(data, batch_size: batch_size, steps: steps); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? 
training = null, Tensors initial_state = null, Tensors constants = null) { if (_args.Invert) { diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs index d62fb63a..21a21406 100644 --- a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs +++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs @@ -12,7 +12,7 @@ namespace Tensorflow.Keras.Layers { } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { if (data_format == "channels_last") return math_ops.reduce_mean(inputs, 1, false); diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs index 000e4b8b..e03050a9 100644 --- a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs +++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs @@ -12,7 +12,7 @@ namespace Tensorflow.Keras.Layers { } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { if (data_format == "channels_last") return math_ops.reduce_mean(inputs, (1, 2), false); diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs index 2de4671c..1a8f06dd 100644 --- a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs +++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs @@ -12,7 +12,7 @@ namespace Tensorflow.Keras.Layers { } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? 
training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { if (data_format == "channels_last") return math_ops.reduce_max(inputs, 1, false); diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs index b7e2c945..9ce002f0 100644 --- a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs +++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs @@ -12,7 +12,7 @@ namespace Tensorflow.Keras.Layers { } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { if (data_format == "channels_last") return math_ops.reduce_max(inputs, (1, 2), false); diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/Pooling1D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/Pooling1D.cs index a2f4c51b..65c6130d 100644 --- a/src/TensorFlowNET.Keras/Layers/Pooling/Pooling1D.cs +++ b/src/TensorFlowNET.Keras/Layers/Pooling/Pooling1D.cs @@ -36,7 +36,7 @@ namespace Tensorflow.Keras.Layers input_spec = new InputSpec(ndim: 3); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { int pad_axis = args.DataFormat == "channels_first" ? 
2 : 3; inputs = tf.expand_dims(inputs, pad_axis); diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs index 27032255..4804d0ab 100644 --- a/src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs +++ b/src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs @@ -36,7 +36,7 @@ namespace Tensorflow.Keras.Layers input_spec = new InputSpec(ndim: 4); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { int[] pool_shape; int[] strides; diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/CategoryEncoding.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/CategoryEncoding.cs index 5620a916..a7e1fd19 100644 --- a/src/TensorFlowNET.Keras/Layers/Preprocessing/CategoryEncoding.cs +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/CategoryEncoding.cs @@ -15,7 +15,7 @@ namespace Tensorflow.Keras.Layers this.args = args; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { var depth = args.NumTokens; var max_value = tf.reduce_max(inputs); diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/Rescaling.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/Rescaling.cs index 5fc581af..99194ca6 100644 --- a/src/TensorFlowNET.Keras/Layers/Preprocessing/Rescaling.cs +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/Rescaling.cs @@ -17,7 +17,7 @@ namespace Tensorflow.Keras.Layers this.args = args; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? 
training = null, Tensors initial_state = null, Tensors constants = null) { scale = constant_op.constant(args.Scale, args.DType); offset = constant_op.constant(args.Offset, args.DType); diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/Resizing.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/Resizing.cs index 603e2b07..67e4b464 100644 --- a/src/TensorFlowNET.Keras/Layers/Preprocessing/Resizing.cs +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/Resizing.cs @@ -19,7 +19,7 @@ namespace Tensorflow.Keras.Layers this.args = args; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { return image_ops_impl.resize_images_v2(inputs, new[] { args.Height, args.Width }, method: args.Interpolation); } diff --git a/src/TensorFlowNET.Keras/Layers/Regularization/Dropout.cs b/src/TensorFlowNET.Keras/Layers/Regularization/Dropout.cs index aa3a92a4..696ab5b9 100644 --- a/src/TensorFlowNET.Keras/Layers/Regularization/Dropout.cs +++ b/src/TensorFlowNET.Keras/Layers/Regularization/Dropout.cs @@ -15,7 +15,7 @@ namespace Tensorflow.Keras.Layers this.args = args; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { if (training == null) training = false; diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs index 9ead15cb..cf93c169 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs @@ -27,7 +27,7 @@ namespace Tensorflow.Keras.Layers.Reshaping _buildInputShape = input_shape; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? 
training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { Tensor output = inputs; if (output.rank != 3) diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs index 087d59a1..7872b0b0 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs @@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Layers.Reshaping built = true; _buildInputShape = input_shape; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { Tensor output = inputs; if (output.rank != 4) diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs index 04a1af60..5bc2433b 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs @@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Layers.Reshaping _buildInputShape = input_shape; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? 
training = null, Tensors initial_state = null, Tensors constants = null) { Tensor output = inputs; if (output.rank != 5) diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs index 539b5f62..8ff34134 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs @@ -23,7 +23,7 @@ namespace Tensorflow.Keras.Layers _channels_first = args.DataFormat == "channels_first"; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { if (_channels_first) { diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs index e391775c..79f0b569 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs @@ -28,7 +28,7 @@ namespace Tensorflow.Keras.Layers { built = true; _buildInputShape = input_shape; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { Tensor outputs = inputs; return tf.transpose(outputs, new Axis(permute)); diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs index 92a772f3..8a4e4d5f 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs @@ -19,7 +19,7 @@ namespace Tensorflow.Keras.Layers this.args = args; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? 
training = null, Tensors initial_state = null, Tensors constants = null) { var shapes = new List(); shapes.Add(array_ops.shape(inputs)[0]); diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/UpSampling2D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/UpSampling2D.cs index 8314151f..7e926dee 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/UpSampling2D.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/UpSampling2D.cs @@ -24,7 +24,7 @@ namespace Tensorflow.Keras.Layers inputSpec = new InputSpec(ndim: 4); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { return keras.backend.resize_images(inputs, size[0], size[1], diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/ZeroPadding2D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/ZeroPadding2D.cs index 7c87100a..c68def38 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/ZeroPadding2D.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/ZeroPadding2D.cs @@ -26,7 +26,7 @@ namespace Tensorflow.Keras.Layers this.input_spec = new InputSpec(ndim: 4); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { return keras.backend.spatial_2d_padding(inputs, padding: padding, diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs b/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs index 59555e62..530c409e 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs @@ -26,9 +26,9 @@ namespace Tensorflow.Keras.Layers.Rnn .ToArray(); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? 
training = null, Tensors initial_state = null, Tensors constants = null) { - return base.Call(inputs, state: state, training: training); + return base.Call(inputs, initial_state: initial_state, training: training); } } } diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs index 310e8057..7bd4047a 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs @@ -1,9 +1,15 @@ using System; +using System.Collections; using System.Collections.Generic; -using Tensorflow.Keras.ArgsDefinition; +using System.Reflection; +using static Tensorflow.Keras.ArgsDefinition.Rnn.RNNArgs; using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; +using Tensorflow.Util; +using OneOf; +using OneOf.Types; +using Tensorflow.Common.Extensions; // from tensorflow.python.distribute import distribution_strategy_context as ds_context; namespace Tensorflow.Keras.Layers.Rnn @@ -19,11 +25,46 @@ namespace Tensorflow.Keras.Layers.Rnn protected IVariableV1 kernel; protected IVariableV1 bias; protected ILayer cell; + public RNN(RNNArgs args) : base(PreConstruct(args)) { this.args = args; SupportsMasking = true; + // if is StackedRnncell + if (args.Cell.IsT0) + { + cell = new StackedRNNCells(new StackedRNNCellsArgs + { + Cells = args.Cell.AsT0, + }); + } + else + { + cell = args.Cell.AsT1; + } + + + + + + Type type = cell.GetType(); + MethodInfo methodInfo = type.GetMethod("Call"); + if (methodInfo == null) + { + throw new ValueError(@"Argument `cell` or `cells`should have a `call` method. 
"); + } + + PropertyInfo propertyInfo = type.GetProperty("state_size"); + if (propertyInfo == null) + { + throw new ValueError(@"The RNN cell should have a `state_size` attribute"); + } + + + + // get input_shape + this.args = PreConstruct(args); // The input shape is unknown yet, it could have nested tensor inputs, and // the input spec will be the list of specs for nested inputs, the structure // of the input_spec will be the same as the input. @@ -37,17 +78,384 @@ namespace Tensorflow.Keras.Layers.Rnn //} } + // States is a tuple consist of cell states_size, like (cell1.state_size, cell2.state_size,...) + // state_size can be a single integer, can also be a list/tuple of integers, can also be TensorShape or a list/tuple of TensorShape + public object States + { + get + { + if (_states == null) + { + var state = nest.map_structure(x => null, cell.state_size); + return nest.is_nested(state) ? state : new Tensors { state }; + } + return _states; + } + set { _states = value; } + } + + private OneOf> compute_output_shape(Shape input_shape) + { + var batch = input_shape[0]; + var time_step = input_shape[1]; + if (args.TimeMajor) + { + (batch, time_step) = (time_step, batch); + } + + // state_size is a array of ints or a positive integer + var state_size = cell.state_size; + + + // TODO(wanglongzhi2001),flat_output_size应该是什么类型的,Shape还是Tensor + Func _get_output_shape; + _get_output_shape = (flat_output_size) => + { + var output_dim = flat_output_size.as_int_list(); + Shape output_shape; + if (args.ReturnSequences) + { + if (args.TimeMajor) + { + output_shape = new Shape(new int[] { (int)time_step, (int)batch }.concat(output_dim)); + } + else + { + output_shape = new Shape(new int[] { (int)batch, (int)time_step }.concat(output_dim)); + + } + } + else + { + output_shape = new Shape(new int[] { (int)batch }.concat(output_dim)); + } + return output_shape; + }; + + Shape output_shape; + if (cell.output_size != 0) + { + output_shape = nest.map_structure(_get_output_shape, 
cell.output_size); + // TODO(wanglongzhi2001): should output_shape simply be a tuple, or a Shape type? + output_shape = (output_shape.Length == 1 ? (int)output_shape[0] : output_shape); + } + else + { + output_shape = _get_output_shape(state_size[0]); + } + + if (args.ReturnState) + { + Func _get_state_shape; + _get_state_shape = (flat_state) => + { + var state_shape = new int[] { (int)batch }.concat(flat_state.as_int_list()); + return new Shape(state_shape); + }; + var state_shape = _get_state_shape(new Shape(state_size.ToArray())); + + return new List { output_shape, state_shape }; + } + else + { + return output_shape; + } + } + + private Tensors compute_mask(Tensors inputs, Tensors mask) + { + // Time step masks must be the same for each input. + // This is because the mask for an RNN is of size [batch, time_steps, 1], + // and specifies which time steps should be skipped, and a time step + // must be skipped for all inputs. + + mask = nest.flatten(mask)[0]; + var output_mask = args.ReturnSequences ? mask : null; + if (args.ReturnState) + { + var state_mask = new List(); + for (int i = 0; i < len(States); i++) + { + state_mask.Add(null); + } + return new List { output_mask }.concat(state_mask); + } + else + { + return output_mask; + } + + + } + public override void build(KerasShapesWrapper input_shape) { + object get_input_spec(Shape shape) + { + var input_spec_shape = shape.as_int_list(); + + var (batch_index, time_step_index) = args.TimeMajor ? 
(1, 0) : (0, 1); + if (!args.Stateful) + { + input_spec_shape[batch_index] = -1; + } + input_spec_shape[time_step_index] = -1; + return new InputSpec(shape: input_spec_shape); + } + + Shape get_step_input_shape(Shape shape) + { + + // return shape[1:] if self.time_major else (shape[0],) + shape[2:] + if (args.TimeMajor) + { + return shape.as_int_list().ToList().GetRange(1, shape.Length - 1).ToArray(); + } + else + { + return new int[] { shape.as_int_list()[0] }.concat(shape.as_int_list().ToList().GetRange(2, shape.Length - 2).ToArray()); + } + + + } + + object get_state_spec(Shape shape) + { + var state_spec_shape = shape.as_int_list(); + // append batch dim + state_spec_shape = new int[] { -1 }.concat(state_spec_shape); + return new InputSpec(shape: state_spec_shape); + + } + + // Check whether the input shape contains any nested shapes. It could be + // (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from + // numpy inputs. + + if (!cell.Built) { cell.build(input_shape); } } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + // inputs: Tensors + // mask: Binary tensor of shape [batch_size, timesteps] indicating whether a given timestep should be masked + // training: bool + // initial_state: List of initial state tensors to be passed to the first call of the cell + // constants: List of constant tensors to be passed to the cell at each timestep + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? 
training = null, Tensors initial_state = null, Tensors constants = null) { - return base.Call(inputs, state, training); + //var (inputs_padded, row_length) = BackendImpl.convert_inputs_if_ragged(inputs); + //bool is_ragged_input = row_length != null; + //_validate_args_if_ragged(is_ragged_input, mask); + var (inputs_processed, initial_state_processed, constants_processed) = _process_inputs(inputs, initial_state, constants); + + _maybe_reset_cell_dropout_mask(cell); + if (cell is StackedRNNCells) + { + foreach (var cell in ((StackedRNNCells)cell).Cells) + { + _maybe_reset_cell_dropout_mask(cell); + } + } + + if (mask != null) + { + // Time step masks must be the same for each input. + //mask = nest.flatten(mask)[0]; + mask = mask[0]; + } + + + Shape input_shape; + if (nest.is_nested(initial_state_processed)) + { + // In the case of nested input, use the first element for shape check + // input_shape = nest.flatten(inputs)[0].shape; + input_shape = inputs[0].shape; + } + else + { + input_shape = inputs.shape; + } + + var timesteps = args.TimeMajor ? input_shape[0] : input_shape[1]; + + if (args.Unroll && timesteps != null) + { + throw new ValueError( + "Cannot unroll a RNN if the " + + "time dimension is undefined. \n" + + "- If using a Sequential model, " + + "specify the time dimension by passing " + + "an `input_shape` or `batch_input_shape` " + + "argument to your first layer. If your " + + "first layer is an Embedding, you can " + + "also use the `input_length` argument.\n" + + "- If using the functional API, specify " + + "the time dimension by passing a `shape` " + + "or `batch_shape` argument to your Input layer." 
+ ); + } + + // cell_call_fn = (self.cell.__call__ if callable(self.cell) else self.cell.call) + var cell_call_fn = cell.Call; + Func step; + if (constants != null) + { + ParameterInfo[] parameters = cell_call_fn.GetMethodInfo().GetParameters(); + bool hasParam = parameters.Any(p => p.Name == "constants"); + if (!hasParam) + { + throw new ValueError( + $"RNN cell {cell} does not support constants." + + $"Received: constants={constants}"); + } + + step = (inputs, states) => + { + // constants = states[-self._num_constants :] + constants = states.numpy()[new Slice(states.Length - _num_constants, states.Length)]; + // states = states[: -self._num_constants] + states = states.numpy()[new Slice(0, states.Length - _num_constants)]; + // states = (states[0] if len(states) == 1 and is_tf_rnn_cell else states) + states = states.Length == 1 ? states[0] : states; + var (output, new_states) = cell_call_fn(inputs, null, null, states, constants); + if (!nest.is_nested(new_states)) + { + return (output, new Tensors { new_states }); + } + return (output, new_states); + }; + } + else + { + step = (inputs, states) => + { + // states = (states[0] if len(states) == 1 and is_tf_rnn_cell else states) + states = states.Length == 1 ? states[0] : states; + var (output, new_states) = cell_call_fn(inputs, null, null, states, constants); + if (!nest.is_nested(new_states)) + { + return (output, new Tensors { new_states }); + } + return (output, new_states); + }; + } + + var (last_output, outputs, states) = BackendImpl.rnn(step, + inputs, + initial_state, + constants: constants, + go_backwards: args.GoBackwards, + mask: mask, + unroll: args.Unroll, + input_length: row_length != null ? 
row_length : new Tensor(timesteps), + time_major: args.TimeMajor, + zero_output_for_mask: args.ZeroOutputForMask, + return_all_outputs: args.ReturnSequences); + + if (args.Stateful) + { + throw new NotImplementedException("this argument havn't been developed!"); + } + + Tensors output = new Tensors(); + if (args.ReturnSequences) + { + throw new NotImplementedException("this argument havn't been developed!"); + + } + else + { + output = last_output; + } + + if (args.ReturnState) + { + + foreach (var state in states) + { + output.Add(state); + } + return output; + } + else + { + return output; + } + } + + private (Tensors, Tensors, Tensors) _process_inputs(Tensor inputs, Tensors initial_state, Tensors constants) + { + bool IsSequence(object obj) + { + // Check if the object is an IEnumerable + if (obj is IEnumerable) + { + // If it is, check if it is a tuple + if (!(obj is Tuple)) + { + return true; + } + } + // If it is not, return false + return false; + } + + if (IsSequence(input)) + { + if (_num_constants != 0) + { + initial_state = inputs[new Slice(1, len(inputs))]; + } + else + { + initial_state = inputs[new Slice(1, len(inputs) - _num_constants)]; + } + if (len(initial_state) == 0) + initial_state = null; + inputs = inputs[0]; + } + + if (args.Stateful) + { + throw new NotImplementedException("argument stateful has not been implemented!"); + + } + + return (inputs, initial_state, constants); + + } + + private void _validate_args_if_ragged(bool is_ragged_input, Tensors mask) + { + if (is_ragged_input) + { + if (args.Unroll) + { + throw new ValueError("The input received contains RaggedTensors and does " + + "not support unrolling. Disable unrolling by passing " + + "`unroll=False` in the RNN Layer constructor."); + } + if (mask != null) + { + throw new ValueError($"The mask that was passed in was {mask}, which " + + "cannot be applied to RaggedTensor inputs. 
Please " + + "make sure that there is no mask injected by upstream " + + "layers."); + } + } + } + + void _maybe_reset_cell_dropout_mask(ILayer cell) + { + //if (cell is DropoutRNNCellMixin) + //{ + // cell.reset_dropout_mask(); + // cell.reset_recurrent_dropout_mask(); + //} } private static RNNArgs PreConstruct(RNNArgs args) @@ -77,6 +485,10 @@ namespace Tensorflow.Keras.Layers.Rnn return args; } + public Tensors __call__(Tensors inputs, Tensor state = null, Tensor training = null) + { + throw new NotImplementedException(); + } public RNN New(LayerRnnCell cell, bool return_sequences = false, bool return_state = false, @@ -95,7 +507,7 @@ namespace Tensorflow.Keras.Layers.Rnn TimeMajor = time_major }); - public RNN New(IList cell, + public RNN New(IList cell, bool return_sequences = false, bool return_state = false, bool go_backwards = false, @@ -125,7 +537,7 @@ namespace Tensorflow.Keras.Layers.Rnn } // Check whether the state_size contains multiple states. - public static bool _is_multiple_state(object state_size) + public static bool is_multiple_state(object state_size) { var myIndexerProperty = state_size.GetType().GetProperty("Item"); return myIndexerProperty != null diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs index 46061b21..86985d7e 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs @@ -42,9 +42,9 @@ namespace Tensorflow.Keras.Layers.Rnn built = true; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? 
training = null, Tensors initial_state = null, Tensors constants = null) { - return base.Call(inputs, state, training); + return base.Call(inputs, initial_state, training); } } } diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs index 20962df1..8e67f8e8 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs @@ -2,15 +2,16 @@ using System.Collections.Generic; using System.ComponentModel; using Tensorflow.Keras.ArgsDefinition; -using Tensorflow.Keras.ArgsDefinition.Rnn; +using static Tensorflow.Keras.ArgsDefinition.Rnn.RNNArgs; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; +using Tensorflow.Keras.ArgsDefinition.Rnn; namespace Tensorflow.Keras.Layers.Rnn { - public class StackedRNNCells : Layer, RNNArgs.IRnnArgCell + public class StackedRNNCells : Layer { - public IList Cells { get; set; } + public IList Cells { get; set; } public bool reverse_state_order; public StackedRNNCells(StackedRNNCellsArgs args) : base(args) @@ -51,7 +52,7 @@ namespace Tensorflow.Keras.Layers.Rnn { return lastCell.output_size; } - else if (RNN._is_multiple_state(lastCell.state_size)) + else if (RNN.is_multiple_state(lastCell.state_size)) { // return ((dynamic)Cells[-1].state_size)[0]; throw new NotImplementedException(""); @@ -63,6 +64,7 @@ namespace Tensorflow.Keras.Layers.Rnn } } + public object get_initial_state() { throw new NotImplementedException(); @@ -80,7 +82,7 @@ namespace Tensorflow.Keras.Layers.Rnn // return tuple(initial_states) } - public object call() + public Tensors Call(Tensors inputs, Tensor mask = null, bool? 
training = null, Tensors initial_state = null, Tensors constants = null) { throw new NotImplementedException(); // def call(self, inputs, states, constants= None, training= None, ** kwargs): diff --git a/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs b/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs index 1ac4a277..6dfc089b 100644 --- a/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs +++ b/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs @@ -34,7 +34,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { if (tf.Context.executing_eagerly()) return DeFunCall(inputs); From bad4e8160570fea5c786bd17b93c2f246a92aa36 Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Fri, 2 Jun 2023 20:55:22 +0800 Subject: [PATCH 2/4] update draft pr for RNN --- .../Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs | 5 +- src/TensorFlowNET.Core/Keras/Layers/ILayer.cs | 6 + .../Keras/Layers/ILayersApi.cs | 8 + .../Operations/NnOps/RNNCell.cs | 7 + src/TensorFlowNET.Core/Util/nest.py.cs | 68 +- src/TensorFlowNET.Keras/BackendImpl.cs | 670 +++++++++--------- src/TensorFlowNET.Keras/Engine/Functional.cs | 2 +- src/TensorFlowNET.Keras/Engine/Layer.Apply.cs | 2 +- src/TensorFlowNET.Keras/Engine/Layer.cs | 103 +-- src/TensorFlowNET.Keras/Engine/Sequential.cs | 6 +- .../Layers/Convolution/Conv2DTranspose.cs | 2 +- src/TensorFlowNET.Keras/Layers/LayersApi.cs | 17 + .../Layers/Rnn/DropOutRNNCellMixin.cs | 80 +++ src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs | 236 +++--- .../Layers/Rnn/RNNUtils.cs | 59 ++ .../Layers/Rnn/SimpleRNNCell.cs | 92 ++- src/TensorflowNET.Hub/KerasLayer.cs | 2 +- .../Callbacks/EarlystoppingTest.cs | 60 -- .../Layers/LayersTest.cs | 11 + .../Tensorflow.Keras.UnitTest.csproj | 4 + 20 files changed, 883 
insertions(+), 557 deletions(-) create mode 100644 src/TensorFlowNET.Keras/Layers/Rnn/DropOutRNNCellMixin.cs create mode 100644 src/TensorFlowNET.Keras/Layers/Rnn/RNNUtils.cs delete mode 100644 test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs index fcfd694d..d8fdfae5 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs @@ -2,6 +2,9 @@ { public class SimpleRNNArgs : RNNArgs { - + public float Dropout = 0f; + public float RecurrentDropout = 0f; + public int state_size; + public int output_size; } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs index f7669394..8bcefc1d 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs @@ -27,5 +27,11 @@ namespace Tensorflow.Keras TF_DataType DType { get; } int count_params(); void adapt(Tensor data, int? batch_size = null, int? steps = null); + + Tensors Call(Tensors inputs, Tensor? mask = null, bool? training = null, Tensors? initial_state = null, Tensors? 
constants = null); + + StateSizeWrapper state_size { get; } + + int output_size { get; } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs index 6a29f9e5..e60ba6fc 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs @@ -200,6 +200,14 @@ namespace Tensorflow.Keras.Layers bool return_sequences = false, bool return_state = false); + public ILayer SimpleRNNCell( + int units, + string activation = "tanh", + bool use_bias = true, + string kernel_initializer = "glorot_uniform", + string recurrent_initializer = "orthogonal", + string bias_initializer = "zeros"); + public ILayer Subtract(); } } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index d49c8218..2dc70177 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -89,6 +89,8 @@ namespace Tensorflow protected bool built = false; public bool Built => built; + StateSizeWrapper ILayer.state_size => throw new NotImplementedException(); + public RnnCell(bool trainable = true, string name = null, TF_DataType dtype = TF_DataType.DtInvalid, @@ -174,5 +176,10 @@ namespace Tensorflow { throw new NotImplementedException(); } + + public Tensors Call(Tensors inputs, Tensor mask = null, bool? 
training = null, Tensors initial_state = null, Tensors constants = null) + { + throw new NotImplementedException(); + } } } diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs index 2879fa8e..8fa9dcac 100644 --- a/src/TensorFlowNET.Core/Util/nest.py.cs +++ b/src/TensorFlowNET.Core/Util/nest.py.cs @@ -19,6 +19,7 @@ using System; using System.Collections; using System.Collections.Generic; using System.Linq; +using System.Runtime.CompilerServices; namespace Tensorflow.Util { @@ -213,6 +214,17 @@ namespace Tensorflow.Util public static bool is_nested(object obj) { + // Refer to https://www.tensorflow.org/api_docs/python/tf/nest + //if (obj is IList || obj is IDictionary || obj is ITuple) + // return true; + if (obj is IList || obj is IDictionary) + return true; + + if (obj is NDArray || obj is Tensor || obj is string || obj.GetType().IsGenericType + || obj is ISet || obj is ISet || obj is ISet) + return false; + + if (obj.GetType().IsNested) return true; // Check if the object is an IEnumerable if (obj is IEnumerable) { @@ -244,7 +256,13 @@ namespace Tensorflow.Util _flatten_recursive(structure, list); return list; } - + // TODO(Wanglongzhi2001), ITuple must used in .NET standard 2.1, but now is 2.0 + // If you want to flatten a nested tuple, please specify the type of the tuple + //public static List flatten(ITuple structure) + //{ + // var list = FlattenTuple(structure).ToList(); + // return list; + //} public static List flatten(IEnumerable structure) { var list = new List(); @@ -272,9 +290,13 @@ namespace Tensorflow.Util case String str: list.Add(obj); break; - case NDArray nd: + // This case can hold both Tensor and NDArray + case Tensor tensor: list.Add(obj); break; + //case NDArray nd: + // list.Add(obj); + // break; case IEnumerable structure: foreach (var child in structure) _flatten_recursive((T)child, list); @@ -285,28 +307,26 @@ namespace Tensorflow.Util } } - public static List FlattenTupple(object tuple) + private 
static IEnumerable FlattenTuple(object tuple) { - List items = new List(); - var type = tuple.GetType(); - - if (type.GetInterface("ITuple") == null) - throw new ArgumentException("This is not a tuple!"); + //if (tuple is ITuple t) + //{ + // for (int i = 0; i < t.Length; i++) + // { + // foreach (var item in FlattenTuple(t[i])) + // { + // yield return item; + // } + // } + //} + if(false) + { - foreach (var property in type.GetProperties()) + } + else { - var value = property.GetValue(tuple); - if (property.PropertyType.GetInterface("ITuple") != null) - { - var subItems = FlattenTupple(value); - items.AddRange(subItems); - } - else - { - items.Add((T)value); - } + yield return (T)tuple; } - return items; } //# See the swig file (util.i) for documentation. //_same_namedtuples = _pywrap_tensorflow.SameNamedtuples @@ -494,8 +514,12 @@ namespace Tensorflow.Util throw new ArgumentException("flat_sequence must not be null"); // if not is_sequence(flat_sequence): // raise TypeError("flat_sequence must be a sequence") - - if (!is_sequence(structure)) + if (!is_nested(flat_sequence)) + { + throw new ArrayTypeMismatchException($"Attempted to pack value:\\n {flat_sequence}\\ninto a structure, " + + $"but found incompatible type `{flat_sequence.GetType()}` instead."); + } + if (!is_nested(structure)) { if (len(flat) != 1) throw new ValueError($"Structure is a scalar but len(flat_sequence) == {len(flat)} > 1"); diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs index da1d25c9..94aeb0dd 100644 --- a/src/TensorFlowNET.Keras/BackendImpl.cs +++ b/src/TensorFlowNET.Keras/BackendImpl.cs @@ -614,331 +614,331 @@ namespace Tensorflow.Keras return nest.pack_sequence_as(inputs, inp); } - if (mask != null) - { - var mask_list = tf.unstack(mask); - if (go_backwards) - { - mask_list.Reverse(); - } - - for (int i = 0; i < time_steps; i++) - { - // TODO(Wanglongzhi2001),deal with _get_input_tensor - var inp = _get_input_tensor(i); - var mask_t = 
mask_list[i]; - // TODO - var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants }); - - var tiled_mask_t = _expand_mask(mask_t, output); - - Tensors prev_output; - if (successive_outputs == null) - { - prev_output = tf.zeros_like(output); - } - else - { - prev_output = successive_outputs[successive_outputs.Length - 1]; - } - - output = tf.where(tiled_mask_t, output, prev_output); - - //var flat_states = nest.flatten(states); - //var flat_new_states = nest.flatten(newStates); - var flat_states = states.ToList(); - var flat_new_states = newStates.ToList(); - - var tiledMaskT = flat_states - .Select(s => _expand_mask(mask_t, s)) - .ToArray(); - var tuple = Tuple.Create(tiledMaskT); - - List flat_final_states = new List(); - foreach (var (m, s, ps) in Enumerable.Zip(tiled_mask_t, flat_new_states, flat_states)) - { - flat_final_states.Add(tf.where(m, s, ps)); - } - - states = (Tensors)nest.pack_sequence_as(states, flat_final_states); - if (return_all_outputs) - { - successive_outputs.Add(output); - successive_states.Add(states); - } - else - { - successive_outputs = new Tensors { output }; - successive_states = new Tensors { states }; - } - - } - last_output = successive_outputs[successive_outputs.Length - 1]; - new_states = successive_states[successive_states.Length - 1]; - outputs = tf.stack(successive_outputs); - - if (zero_output_for_mask) - { - last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output)); - outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs)); - } - else // mask is null - { - for (int i = 0; i < time_steps; i++) - { - var inp = _get_input_tensor(i); - var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants }); - states = newStates; - - if (return_all_outputs) - { - successive_outputs.Add(output); - successive_states.Add(newStates); - } - else - { - successive_outputs = new Tensors { 
output }; - successive_states = new Tensors { newStates }; - } - } - last_output = successive_outputs[successive_outputs.Length - 1]; - new_states = successive_states[successive_states.Length - 1]; - outputs = tf.stack(successive_outputs); - } - } - } - else // unroll == false - { - var states = initial_states; - // Create input tensor array, if the inputs is nested tensors, then it - // will be flattened first, and tensor array will be created one per - // flattened tensor. - var input_ta = new List(); - for (int i = 0; i < flatted_inptus.Count; i++) - { - input_ta.Add(tf.TensorArray(dtype: flatted_inptus[i].dtype, size: time_step_t)); - } - - // Get the time(0) input and compute the output for that, the output will - // be used to determine the dtype of output tensor array. Don't read from - // input_ta due to TensorArray clear_after_read default to True. - var inps = new Tensors(); - foreach (var inp in flatted_inptus) - { - inps.Add(inp[0]); - } - var input_time_zero = nest.pack_sequence_as(inputs, inps); - - // output_time_zero is used to determine the cell output shape and its - // dtype. the value is discarded. - (output_time_zero, _) = step_function((Tensor)input_time_zero, new Tensors { initial_states, constants }); - - var output_ta_size = return_all_outputs ? time_step_t : tf.constant(1); - var output_ta = new List(); - for (int i = 0; i < output_time_zero.ToList().Count; i++) - { - var Out = output_time_zero.ToList()[i]; - output_ta.Add(tf.TensorArray(dtype: Out.dtype, size: output_ta_size, element_shape: Out.shape)); - } - - var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time"); - - - - Func? masking_fn; - Func? 
compute_masked_output = null; - if (mask != null) - { - if (go_backwards) - { - mask = tf.reverse(mask, axis: new[] { 0 }); - } - var mask_ta = tf.TensorArray(dtype: TF_DataType.TF_BOOL, size: time_step_t); - mask_ta = mask_ta.unstack(mask); - - masking_fn = (time) => - { - return mask_ta.read(time); - }; - - compute_masked_output = (mask_t, flat_out, flat_mask) => - { - var tiled_mask_t = new Tensors(); - foreach (var o in flat_out) - { - tiled_mask_t.Add(_expand_mask(mask_t, o, fixed_dim: mask_t.rank)); - } - - Tensors res = new Tensors(); - foreach (var (m, o, fm) in Enumerable.Zip(tiled_mask_t, flat_out, flat_mask)) - { - res.Add(tf.where(m, o, fm)); - } - return res; - }; - } - // TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor)? - else if (input_length is Tensor) - { - if (go_backwards) - { - var max_len = tf.reduce_max(input_length, axis: 0); - var rev_input_length = tf.subtract(max_len - 1, input_length); - - masking_fn = (time) => - { - return tf.less(rev_input_length, time); - }; - } - else - { - masking_fn = (time) => - { - return tf.greater(input_length, time); - }; - } - - compute_masked_output = (mask_t, flat_out, flat_mask) => - { - var res = new List(); - foreach (var (o, zo) in zip(flat_out, flat_mask)) - { - res.Add(tf.where(mask_t, o, zo)); - } - return res; - }; - } - else - { - masking_fn = null; - } - - - if (masking_fn != null) - { - // Mask for the T output will be base on the output of T - 1. In the - // case T = 0, a zero filled tensor will be used. - var flat_zero_output = new Tensors(); - foreach (var o in nest.flatten(output_time_zero)) - { - flat_zero_output.Add(tf.zeros_like(o)); - } - - - (Tensor, List, Tensors, Tensors) _step(Tensor time, List output_ta_t, Tensors prev_output, Tensors states) - { - /* - RNN step function. - Args: - time: Current timestep value. - output_ta_t: TensorArray. - prev_output: tuple of outputs from time - 1. - *states: List of states. 
- Returns: - Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)` - */ - - var current_input = input_ta.Select(x => x.read(time)).ToList(); - // maybe set shape - // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type - current_input = (List)nest.pack_sequence_as(inputs, current_input); - var mask_t = masking_fn(time); - var (output, new_states) = step_function(current_input, new Tensors { states, constants }); - // mask output - //var flat_output = nest.flatten(output); - var flat_output = output.ToList(); - - var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.ToList(); - - // TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type - var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output); - - // mask states - var flat_state = states.ToList(); - var flat_new_state = new_states.ToList(); - - foreach (var (state, new_state) in zip(flat_state, flat_new_state)) - { - if (new_state is Tensor) - { - new_state.set_shape(state.shape); - } - } - - var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state); - new_states = (Tensors)nest.pack_sequence_as(new_states, flat_final_state); - - var ta_index_to_write = return_all_outputs ? 
time : tf.constant(0); - var Output_ta_t = new List(); - // TODO(Wanglongzhi2001),deal with zip output_ta_t - foreach (var (ta, Out) in zip(output_ta_t, flat_new_output)) - { - Output_ta_t.Add(ta.write(ta_index_to_write, Out)); - } - - - - //new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state); - - - return (time + 1, Output_ta_t, flat_new_output, new_states); - - } - Func cond = (time) => (time < time_step_t); - - var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, flat_zero_output, states)); - new_states = final_outputs.Item4; - output_ta = final_outputs.Item2; - - } - else - { - (Tensor, List, Tensors) _step(Tensor time, List output_ta_t, Tensors states) - { - var current_input = input_ta.Select(x => x.read(time)).ToList(); - // maybe set shape - // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type - current_input = (List)nest.pack_sequence_as(inputs, current_input); - var (output, new_states) = step_function(current_input, new Tensors { states, constants }); - var flat_state = states.ToList(); - var flat_new_state = new_states.ToList(); - foreach (var (state, new_state) in zip(flat_state, flat_new_state)) - { - if (new_state is Tensor) - { - new_state.set_shape(state.shape); - } - } - var flat_output = output.ToList(); - var ta_index_to_write = return_all_outputs ? 
time : tf.constant(0); - var Output_ta_t = new List(); - foreach (var (ta, out_) in zip(output_ta_t, flat_output)) - { - Output_ta_t.Add(ta.write(ta_index_to_write, out_)); - } - - new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state); - return (time + 1, Output_ta_t, new_states); - } - Func cond = (time) => (time < time_step_t); - var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, states)); - new_states = final_outputs.Item3; - output_ta = final_outputs.Item2; - - } - //Tensors outputs = new Tensors(); - foreach (var o in output_ta) - { - outputs.Add(o.stack()); - } - foreach (var o in outputs) - { - last_output.Add(o[-1]); - } - outputs = (Tensors)nest.pack_sequence_as(output_time_zero, outputs); - last_output = (Tensors)nest.pack_sequence_as(output_time_zero, last_output); - + //if (mask != null) + //{ + // var mask_list = tf.unstack(mask); + // if (go_backwards) + // { + // mask_list.Reverse(); + // } + + // for (int i = 0; i < time_steps; i++) + // { + // // TODO(Wanglongzhi2001),deal with _get_input_tensor + // var inp = _get_input_tensor(i); + // var mask_t = mask_list[i]; + // // TODO + // var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants }); + + // var tiled_mask_t = _expand_mask(mask_t, output); + + // Tensors prev_output; + // if (successive_outputs == null) + // { + // prev_output = tf.zeros_like(output); + // } + // else + // { + // prev_output = successive_outputs[successive_outputs.Length - 1]; + // } + + // output = tf.where(tiled_mask_t, output, prev_output); + + // //var flat_states = nest.flatten(states); + // //var flat_new_states = nest.flatten(newStates); + // var flat_states = states.ToList(); + // var flat_new_states = newStates.ToList(); + + // var tiledMaskT = flat_states + // .Select(s => _expand_mask(mask_t, s)) + // .ToArray(); + // var tuple = Tuple.Create(tiledMaskT); + + // List flat_final_states = new List(); + // foreach (var (m, s, ps) 
in Enumerable.Zip(tiled_mask_t, flat_new_states, flat_states)) + // { + // flat_final_states.Add(tf.where(m, s, ps)); + // } + + // states = (Tensors)nest.pack_sequence_as(states, flat_final_states); + // if (return_all_outputs) + // { + // successive_outputs.Add(output); + // successive_states.Add(states); + // } + // else + // { + // successive_outputs = new Tensors { output }; + // successive_states = new Tensors { states }; + // } + + // } + // last_output = successive_outputs[successive_outputs.Length - 1]; + // new_states = successive_states[successive_states.Length - 1]; + // outputs = tf.stack(successive_outputs); + + // if (zero_output_for_mask) + // { + // last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output)); + // outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs)); + // } + // else // mask is null + // { + // for (int i = 0; i < time_steps; i++) + // { + // var inp = _get_input_tensor(i); + // var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants }); + // states = newStates; + + // if (return_all_outputs) + // { + // successive_outputs.Add(output); + // successive_states.Add(newStates); + // } + // else + // { + // successive_outputs = new Tensors { output }; + // successive_states = new Tensors { newStates }; + // } + // } + // last_output = successive_outputs[successive_outputs.Length - 1]; + // new_states = successive_states[successive_states.Length - 1]; + // outputs = tf.stack(successive_outputs); + // } + //} } + //else // unroll == false + //{ + // var states = initial_states; + // // Create input tensor array, if the inputs is nested tensors, then it + // // will be flattened first, and tensor array will be created one per + // // flattened tensor. 
+ // var input_ta = new List(); + // for (int i = 0; i < flatted_inptus.Count; i++) + // { + // input_ta.Add(tf.TensorArray(dtype: flatted_inptus[i].dtype, size: time_step_t)); + // } + + // // Get the time(0) input and compute the output for that, the output will + // // be used to determine the dtype of output tensor array. Don't read from + // // input_ta due to TensorArray clear_after_read default to True. + // var inps = new Tensors(); + // foreach (var inp in flatted_inptus) + // { + // inps.Add(inp[0]); + // } + // var input_time_zero = nest.pack_sequence_as(inputs, inps); + + // // output_time_zero is used to determine the cell output shape and its + // // dtype. the value is discarded. + // (output_time_zero, _) = step_function((Tensor)input_time_zero, new Tensors { initial_states, constants }); + + // var output_ta_size = return_all_outputs ? time_step_t : tf.constant(1); + // var output_ta = new List(); + // for (int i = 0; i < output_time_zero.ToList().Count; i++) + // { + // var Out = output_time_zero.ToList()[i]; + // output_ta.Add(tf.TensorArray(dtype: Out.dtype, size: output_ta_size, element_shape: Out.shape)); + // } + + // var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time"); + + + + // Func? masking_fn; + // Func? 
compute_masked_output = null; + // if (mask != null) + // { + // if (go_backwards) + // { + // mask = tf.reverse(mask, axis: new[] { 0 }); + // } + // var mask_ta = tf.TensorArray(dtype: TF_DataType.TF_BOOL, size: time_step_t); + // mask_ta = mask_ta.unstack(mask); + + // masking_fn = (time) => + // { + // return mask_ta.read(time); + // }; + + // compute_masked_output = (mask_t, flat_out, flat_mask) => + // { + // var tiled_mask_t = new Tensors(); + // foreach (var o in flat_out) + // { + // tiled_mask_t.Add(_expand_mask(mask_t, o, fixed_dim: mask_t.rank)); + // } + + // Tensors res = new Tensors(); + // foreach (var (m, o, fm) in Enumerable.Zip(tiled_mask_t, flat_out, flat_mask)) + // { + // res.Add(tf.where(m, o, fm)); + // } + // return res; + // }; + // } + // // TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor)? + // else if (input_length is Tensor) + // { + // if (go_backwards) + // { + // var max_len = tf.reduce_max(input_length, axis: 0); + // var rev_input_length = tf.subtract(max_len - 1, input_length); + + // masking_fn = (time) => + // { + // return tf.less(rev_input_length, time); + // }; + // } + // else + // { + // masking_fn = (time) => + // { + // return tf.greater(input_length, time); + // }; + // } + + // compute_masked_output = (mask_t, flat_out, flat_mask) => + // { + // var res = new List(); + // foreach (var (o, zo) in zip(flat_out, flat_mask)) + // { + // res.Add(tf.where(mask_t, o, zo)); + // } + // return res; + // }; + // } + // else + // { + // masking_fn = null; + // } + + + // if (masking_fn != null) + // { + // // Mask for the T output will be base on the output of T - 1. In the + // // case T = 0, a zero filled tensor will be used. 
+ // var flat_zero_output = new Tensors(); + // foreach (var o in nest.flatten(output_time_zero)) + // { + // flat_zero_output.Add(tf.zeros_like(o)); + // } + + + // (Tensor, List, Tensors, Tensors) _step(Tensor time, List output_ta_t, Tensors prev_output, Tensors states) + // { + // /* + // RNN step function. + // Args: + // time: Current timestep value. + // output_ta_t: TensorArray. + // prev_output: tuple of outputs from time - 1. + // *states: List of states. + // Returns: + // Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)` + // */ + + // var current_input = input_ta.Select(x => x.read(time)).ToList(); + // // maybe set shape + // // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type + // current_input = (List)nest.pack_sequence_as(inputs, current_input); + // var mask_t = masking_fn(time); + // var (output, new_states) = step_function(current_input, new Tensors { states, constants }); + // // mask output + // //var flat_output = nest.flatten(output); + // var flat_output = output.ToList(); + + // var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.ToList(); + + // // TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type + // var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output); + + // // mask states + // var flat_state = states.ToList(); + // var flat_new_state = new_states.ToList(); + + // foreach (var (state, new_state) in zip(flat_state, flat_new_state)) + // { + // if (new_state is Tensor) + // { + // new_state.set_shape(state.shape); + // } + // } + + // var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state); + // new_states = (Tensors)nest.pack_sequence_as(new_states, flat_final_state); + + // var ta_index_to_write = return_all_outputs ? 
time : tf.constant(0); + // var Output_ta_t = new List(); + // // TODO(Wanglongzhi2001),deal with zip output_ta_t + // foreach (var (ta, Out) in zip(output_ta_t, flat_new_output)) + // { + // Output_ta_t.Add(ta.write(ta_index_to_write, Out)); + // } + + + + // //new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state); + + + // return (time + 1, Output_ta_t, flat_new_output, new_states); + + // } + // Func cond = (time) => (time < time_step_t); + + // var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, flat_zero_output, states)); + // new_states = final_outputs.Item4; + // output_ta = final_outputs.Item2; + + // } + // else + // { + // (Tensor, List, Tensors) _step(Tensor time, List output_ta_t, Tensors states) + // { + // var current_input = input_ta.Select(x => x.read(time)).ToList(); + // // maybe set shape + // // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type + // current_input = (List)nest.pack_sequence_as(inputs, current_input); + // var (output, new_states) = step_function(current_input, new Tensors { states, constants }); + // var flat_state = states.ToList(); + // var flat_new_state = new_states.ToList(); + // foreach (var (state, new_state) in zip(flat_state, flat_new_state)) + // { + // if (new_state is Tensor) + // { + // new_state.set_shape(state.shape); + // } + // } + // var flat_output = output.ToList(); + // var ta_index_to_write = return_all_outputs ? 
time : tf.constant(0); + // var Output_ta_t = new List(); + // foreach (var (ta, out_) in zip(output_ta_t, flat_output)) + // { + // Output_ta_t.Add(ta.write(ta_index_to_write, out_)); + // } + + // new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state); + // return (time + 1, Output_ta_t, new_states); + // } + // Func cond = (time) => (time < time_step_t); + // var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, states)); + // new_states = final_outputs.Item3; + // output_ta = final_outputs.Item2; + + // } + // //Tensors outputs = new Tensors(); + // foreach (var o in output_ta) + // { + // outputs.Add(o.stack()); + // } + // foreach (var o in outputs) + // { + // last_output.Add(o[-1]); + // } + // outputs = (Tensors)nest.pack_sequence_as(output_time_zero, outputs); + // last_output = (Tensors)nest.pack_sequence_as(output_time_zero, last_output); + + //} Func set_shape; set_shape = (output_) => @@ -968,5 +968,27 @@ namespace Tensorflow.Keras return (last_output, Outputs, new_states); } + + // Multiplies 2 tensors (and/or variables) and returns a tensor. + // This operation corresponds to `numpy.dot(a, b, out=None)`. + public Tensor Dot(Tensor x, Tensor y) + { + //if (x.ndim != 1 && (x.ndim > 2 || y.ndim > 2)) + //{ + // var x_shape = new List(); + // foreach (var (i,s) in zip(x.shape.as_int_list(), tf.unstack(tf.shape(x)))) + // { + // if (i != 0) + // { + // x_shape.append(i); + // } + // else + // { + // x_shape.append(s); + // } + // } + //} + throw new NotImplementedException(); + } } } diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs index e768bd0b..660856b6 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.cs @@ -325,7 +325,7 @@ namespace Tensorflow.Keras.Engine nodes_in_decreasing_depth.append(node); } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? 
training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { var tensor_dict = new Dictionary>(); // map input values diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs b/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs index c0430458..f40cdec7 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs @@ -30,7 +30,7 @@ namespace Tensorflow.Keras.Engine if (!built) MaybeBuild(inputs); - var outputs = Call(inputs, state: state, training: training); + var outputs = Call(inputs, initial_state: state, training: training); // memory leak // _set_connectivity_metadata_(inputs, outputs); diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index 4216c725..7d50f83a 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -254,6 +254,10 @@ namespace Tensorflow.Keras.Engine /// public Func? ReplacedCall { get; set; } = null; + public StateSizeWrapper state_size => throw new NotImplementedException(); + + public int output_size => throw new NotImplementedException(); + public Layer(LayerArgs args) { Initialize(args); @@ -434,56 +438,61 @@ namespace Tensorflow.Keras.Engine public override void SetAttr(string name, object value) { - // TODO(Rinne): deal with "_self_setattr_tracking". + //// TODO(Rinne): deal with "_self_setattr_tracking". - value = TrackableDataStructure.sticky_attribute_assignment(this, name, value); + //value = TrackableDataStructure.sticky_attribute_assignment(this, name, value); - foreach(var val in nest.flatten(value)) - { - if(val is Metric) - { - // TODO(Rinne): deal with metrics. - } - } - - // TODO(Rinne): deal with "_auto_track_sub_layers". 
- - foreach(var val in nest.flatten(value)) - { - if(val is not IVariableV1 variable) - { - continue; - } - if (variable.Trainable) - { - if (_trainable_weights.Contains(variable)) - { - continue; - } - _trainable_weights.Add(variable); - } - else - { - if (_non_trainable_weights.Contains(variable)) - { - continue; - } - _non_trainable_weights.Add(variable); - } - keras.backend.track_variable(variable); - } + //foreach(var val in nest.flatten(value)) + //{ + // if(val is Metric) + // { + // // TODO(Rinne): deal with metrics. + // } + //} + + //// TODO(Rinne): deal with "_auto_track_sub_layers". + + //foreach(var val in nest.flatten(value)) + //{ + // if(val is not IVariableV1 variable) + // { + // continue; + // } + // if (variable.Trainable) + // { + // if (_trainable_weights.Contains(variable)) + // { + // continue; + // } + // _trainable_weights.Add(variable); + // } + // else + // { + // if (_non_trainable_weights.Contains(variable)) + // { + // continue; + // } + // _non_trainable_weights.Add(variable); + // } + // keras.backend.track_variable(variable); + //} + + //// Directly use the implementation of `Trackable`. + //var t = this.GetType(); + //var field_info = t.GetField(name); + //if (field_info is not null) + //{ + // field_info.SetValue(this, value); + //} + //else + //{ + // CustomizedFields[name] = value; + //} + } - // Directly use the implementation of `Trackable`. - var t = this.GetType(); - var field_info = t.GetField(name); - if (field_info is not null) - { - field_info.SetValue(this, value); - } - else - { - CustomizedFields[name] = value; - } + Tensors ILayer.Call(Tensors inputs, Tensor mask, bool? 
training, Tensors initial_state, Tensors constants) + { + throw new NotImplementedException(); } } } diff --git a/src/TensorFlowNET.Keras/Engine/Sequential.cs b/src/TensorFlowNET.Keras/Engine/Sequential.cs index 27874751..bb70e67e 100644 --- a/src/TensorFlowNET.Keras/Engine/Sequential.cs +++ b/src/TensorFlowNET.Keras/Engine/Sequential.cs @@ -143,7 +143,7 @@ namespace Tensorflow.Keras.Engine } } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { if (!_has_explicit_input_shape) { @@ -154,10 +154,10 @@ namespace Tensorflow.Keras.Engine { if (!built) _init_graph_network(this.inputs, outputs); - return base.Call(inputs, state, training); + return base.Call(inputs, initial_state, training); } - return base.Call(inputs, state, training); + return base.Call(inputs, initial_state, training); } void _build_graph_network_for_inferred_shape(Shape input_shape, TF_DataType input_dtype) diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs index bbd49acd..217dd28f 100644 --- a/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs +++ b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs @@ -83,7 +83,7 @@ namespace Tensorflow.Keras.Layers _buildInputShape = input_shape; } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? 
training = null, Tensors initial_state = null, Tensors constants = null) { var inputs_shape = array_ops.shape(inputs); var batch_size = inputs_shape[0]; diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs index 3b095bc2..02e9d995 100644 --- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs @@ -709,6 +709,23 @@ namespace Tensorflow.Keras.Layers ReturnState = return_state }); + public ILayer SimpleRNNCell( + int units, + string activation = "tanh", + bool use_bias = true, + string kernel_initializer = "glorot_uniform", + string recurrent_initializer = "orthogonal", + string bias_initializer = "zeros") + => new SimpleRNNCell(new SimpleRNNArgs + { + Units = units, + Activation = keras.activations.GetActivationFromName(activation), + UseBias = use_bias, + KernelInitializer = GetInitializerByName(kernel_initializer), + RecurrentInitializer = GetInitializerByName(recurrent_initializer), + } + ); + /// /// Long Short-Term Memory layer - Hochreiter 1997. /// diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/DropOutRNNCellMixin.cs b/src/TensorFlowNET.Keras/Layers/Rnn/DropOutRNNCellMixin.cs new file mode 100644 index 00000000..fcf9b596 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Rnn/DropOutRNNCellMixin.cs @@ -0,0 +1,80 @@ +using System; +using System.Collections.Generic; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.ArgsDefinition.Rnn; +using Tensorflow.Keras.Engine; + + + +namespace Tensorflow.Keras.Layers.Rnn +{ + public class DropoutRNNCellMixin + { + public float dropout; + public float recurrent_dropout; + // Get the dropout mask for RNN cell's input. + public Tensors get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1) + { + + return _generate_dropout_mask( + tf.ones_like(input), + dropout, + training, + count); + } + + // Get the recurrent dropout mask for RNN cell. 
+ public Tensors get_recurrent_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1) + { + return _generate_dropout_mask( + tf.ones_like(input), + recurrent_dropout, + training, + count); + } + + public Tensors _create_dropout_mask(Tensors input, bool training, int count = 1) + { + return _generate_dropout_mask( + tf.ones_like(input), + dropout, + training, + count); + } + + public Tensors _create_recurrent_dropout_mask(Tensors input, bool training, int count = 1) + { + return _generate_dropout_mask( + tf.ones_like(input), + recurrent_dropout, + training, + count); + } + + public Tensors _generate_dropout_mask(Tensor ones, float rate, bool training, int count = 1) + { + Tensors dropped_inputs() + { + DropoutArgs args = new DropoutArgs(); + args.Rate = rate; + var DropoutLayer = new Dropout(args); + var mask = DropoutLayer.Apply(ones, training: training); + return mask; + } + + if (count > 1) + { + Tensors results = new Tensors(); + for (int i = 0; i < count; i++) + { + results.Add(dropped_inputs()); + } + return results; + } + + return dropped_inputs(); + } + } + + +} diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs index 7bd4047a..a26743e6 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs @@ -19,7 +19,7 @@ namespace Tensorflow.Keras.Layers.Rnn private RNNArgs args; private object input_spec = null; // or NoneValue?? 
private object state_spec = null; - private object _states = null; + private Tensors _states = null; private object constants_spec = null; private int _num_constants = 0; protected IVariableV1 kernel; @@ -44,19 +44,15 @@ namespace Tensorflow.Keras.Layers.Rnn cell = args.Cell.AsT1; } - - - - Type type = cell.GetType(); - MethodInfo methodInfo = type.GetMethod("Call"); - if (methodInfo == null) + MethodInfo callMethodInfo = type.GetMethod("Call"); + if (callMethodInfo == null) { throw new ValueError(@"Argument `cell` or `cells`should have a `call` method. "); } - PropertyInfo propertyInfo = type.GetProperty("state_size"); - if (propertyInfo == null) + PropertyInfo state_size_info = type.GetProperty("state_size"); + if (state_size_info == null) { throw new ValueError(@"The RNN cell should have a `state_size` attribute"); } @@ -80,7 +76,7 @@ namespace Tensorflow.Keras.Layers.Rnn // States is a tuple consist of cell states_size, like (cell1.state_size, cell2.state_size,...) // state_size can be a single integer, can also be a list/tuple of integers, can also be TensorShape or a list/tuple of TensorShape - public object States + public Tensors States { get { @@ -106,7 +102,6 @@ namespace Tensorflow.Keras.Layers.Rnn // state_size is a array of ints or a positive integer var state_size = cell.state_size; - // TODO(wanglongzhi2001),flat_output_size应该是什么类型的,Shape还是Tensor Func _get_output_shape; _get_output_shape = (flat_output_size) => @@ -132,8 +127,10 @@ namespace Tensorflow.Keras.Layers.Rnn return output_shape; }; + Type type = cell.GetType(); + PropertyInfo output_size_info = type.GetProperty("output_size"); Shape output_shape; - if (cell.output_size != 0) + if (output_size_info != null) { output_shape = nest.map_structure(_get_output_shape, cell.output_size); // TODO(wanglongzhi2001),output_shape应该简单的就是一个元组还是一个Shape类型 @@ -160,6 +157,7 @@ namespace Tensorflow.Keras.Layers.Rnn { return output_shape; } + } private Tensors compute_mask(Tensors inputs, Tensors mask) @@ 
-184,8 +182,6 @@ namespace Tensorflow.Keras.Layers.Rnn { return output_mask; } - - } public override void build(KerasShapesWrapper input_shape) @@ -247,14 +243,18 @@ namespace Tensorflow.Keras.Layers.Rnn protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { //var (inputs_padded, row_length) = BackendImpl.convert_inputs_if_ragged(inputs); - //bool is_ragged_input = row_length != null; - //_validate_args_if_ragged(is_ragged_input, mask); - var (inputs_processed, initial_state_processed, constants_processed) = _process_inputs(inputs, initial_state, constants); + // 暂时先不接受ragged tensor + int? row_length = null; + bool is_ragged_input = false; + _validate_args_if_ragged(is_ragged_input, mask); + + (inputs, initial_state, constants) = _process_inputs(inputs, initial_state, constants); _maybe_reset_cell_dropout_mask(cell); if (cell is StackedRNNCells) { - foreach (var cell in ((StackedRNNCells)cell).Cells) + var stack_cell = cell as StackedRNNCells; + foreach (var cell in stack_cell.Cells) { _maybe_reset_cell_dropout_mask(cell); } @@ -263,17 +263,16 @@ namespace Tensorflow.Keras.Layers.Rnn if (mask != null) { // Time step masks must be the same for each input. - //mask = nest.flatten(mask)[0]; - mask = mask[0]; + mask = nest.flatten(mask)[0]; } - Shape input_shape; - if (nest.is_nested(initial_state_processed)) + if (nest.is_nested(inputs)) { // In the case of nested input, use the first element for shape check // input_shape = nest.flatten(inputs)[0].shape; - input_shape = inputs[0].shape; + // TODO(Wanglongzhi2001) + input_shape = nest.flatten(inputs)[0].shape; } else { @@ -322,6 +321,7 @@ namespace Tensorflow.Keras.Layers.Rnn // states = (states[0] if len(states) == 1 and is_tf_rnn_cell else states) states = states.Length == 1 ? 
states[0] : states; var (output, new_states) = cell_call_fn(inputs, null, null, states, constants); + // TODO(Wanglongzhi2001),should cell_call_fn's return value be Tensors, Tensors? if (!nest.is_nested(new_states)) { return (output, new Tensors { new_states }); @@ -351,7 +351,7 @@ namespace Tensorflow.Keras.Layers.Rnn go_backwards: args.GoBackwards, mask: mask, unroll: args.Unroll, - input_length: row_length != null ? row_length : new Tensor(timesteps), + input_length: row_length != null ? new Tensor(row_length) : new Tensor(timesteps), time_major: args.TimeMajor, zero_output_for_mask: args.ZeroOutputForMask, return_all_outputs: args.ReturnSequences); @@ -387,24 +387,9 @@ namespace Tensorflow.Keras.Layers.Rnn } } - private (Tensors, Tensors, Tensors) _process_inputs(Tensor inputs, Tensors initial_state, Tensors constants) + private (Tensors inputs, Tensors initial_state, Tensors constants) _process_inputs(Tensor inputs, Tensors initial_state, Tensors constants) { - bool IsSequence(object obj) - { - // Check if the object is an IEnumerable - if (obj is IEnumerable) - { - // If it is, check if it is a tuple - if (!(obj is Tuple)) - { - return true; - } - } - // If it is not, return false - return false; - } - - if (IsSequence(input)) + if (nest.is_sequence(input)) { if (_num_constants != 0) { @@ -413,6 +398,7 @@ namespace Tensorflow.Keras.Layers.Rnn else { initial_state = inputs[new Slice(1, len(inputs) - _num_constants)]; + constants = inputs[new Slice(len(inputs) - _num_constants, len(inputs))]; } if (len(initial_state) == 0) initial_state = null; @@ -421,32 +407,63 @@ namespace Tensorflow.Keras.Layers.Rnn if (args.Stateful) { - throw new NotImplementedException("argument stateful has not been implemented!"); + if (initial_state != null) + { + var tmp = new Tensor[] { }; + foreach (var s in nest.flatten(States)) + { + tmp.add(tf.math.count_nonzero((Tensor)s)); + } + var non_zero_count = tf.add_n(tmp); + //initial_state = tf.cond(non_zero_count > 0, () => States, 
() => initial_state); + if((int)non_zero_count.numpy() > 0) + { + initial_state = States; + } + } + else + { + initial_state = States; + } } + else if(initial_state == null) // no caller-supplied state: fall back to the cell's default initial state instead of overwriting a provided one + { + initial_state = get_initial_state(inputs); + } - return (inputs, initial_state, constants); + if (initial_state.Length != States.Length) + { + throw new ValueError( + $"Layer {this} expects {States.Length} state(s), " + + $"but it received {initial_state.Length} " + + $"initial state(s). Input received: {inputs}"); + } + return (inputs, initial_state, constants); } private void _validate_args_if_ragged(bool is_ragged_input, Tensors mask) { - if (is_ragged_input) + if (!is_ragged_input) { - if (args.Unroll) - { - throw new ValueError("The input received contains RaggedTensors and does " + - "not support unrolling. Disable unrolling by passing " + - "`unroll=False` in the RNN Layer constructor."); - } - if (mask != null) - { - throw new ValueError($"The mask that was passed in was {mask}, which " + - "cannot be applied to RaggedTensor inputs. Please " + - "make sure that there is no mask injected by upstream " + - "layers."); - } + return; + } + + if (args.Unroll) + { + throw new ValueError("The input received contains RaggedTensors and does " + + "not support unrolling. Disable unrolling by passing " + + "`unroll=False` in the RNN Layer constructor."); + } + if (mask != null) + { + throw new ValueError($"The mask that was passed in was {mask}, which " + + "cannot be applied to RaggedTensor inputs. 
Please " + + "make sure that there is no mask injected by upstream " + + "layers."); } + } void _maybe_reset_cell_dropout_mask(ILayer cell) @@ -489,46 +506,77 @@ namespace Tensorflow.Keras.Layers.Rnn { throw new NotImplementedException(); } - public RNN New(LayerRnnCell cell, - bool return_sequences = false, - bool return_state = false, - bool go_backwards = false, - bool stateful = false, - bool unroll = false, - bool time_major = false) - => new RNN(new RNNArgs - { - Cell = cell, - ReturnSequences = return_sequences, - ReturnState = return_state, - GoBackwards = go_backwards, - Stateful = stateful, - Unroll = unroll, - TimeMajor = time_major - }); - public RNN New(IList cell, - bool return_sequences = false, - bool return_state = false, - bool go_backwards = false, - bool stateful = false, - bool unroll = false, - bool time_major = false) - => new RNN(new RNNArgs - { - Cell = new StackedRNNCells(new StackedRNNCellsArgs { Cells = cell }), - ReturnSequences = return_sequences, - ReturnState = return_state, - GoBackwards = go_backwards, - Stateful = stateful, - Unroll = unroll, - TimeMajor = time_major - }); + // 好像不能cell不能传接口类型 + //public RNN New(IRnnArgCell cell, + // bool return_sequences = false, + // bool return_state = false, + // bool go_backwards = false, + // bool stateful = false, + // bool unroll = false, + // bool time_major = false) + // => new RNN(new RNNArgs + // { + // Cell = cell, + // ReturnSequences = return_sequences, + // ReturnState = return_state, + // GoBackwards = go_backwards, + // Stateful = stateful, + // Unroll = unroll, + // TimeMajor = time_major + // }); + + //public RNN New(List cell, + // bool return_sequences = false, + // bool return_state = false, + // bool go_backwards = false, + // bool stateful = false, + // bool unroll = false, + // bool time_major = false) + // => new RNN(new RNNArgs + // { + // Cell = cell, + // ReturnSequences = return_sequences, + // ReturnState = return_state, + // GoBackwards = go_backwards, + // 
Stateful = stateful, + // Unroll = unroll, + // TimeMajor = time_major + // }); + + + protected Tensors get_initial_state(Tensor inputs) + { + Type type = cell.GetType(); + MethodInfo MethodInfo = type.GetMethod("get_initial_state"); + if (nest.is_nested(inputs)) + { + // The input are nested sequences. Use the first element in the seq + // to get batch size and dtype. + inputs = nest.flatten(inputs)[0]; + } - protected Tensor get_initial_state(Tensor inputs) - { - return _generate_zero_filled_state_for_cell(null, null); + var input_shape = tf.shape(inputs); + var batch_size = args.TimeMajor ? input_shape[1] : input_shape[0]; + var dtype = inputs.dtype; + Tensor init_state; + if (MethodInfo != null) + { + init_state = (Tensor)MethodInfo.Invoke(cell, new object[] { null, batch_size, dtype }); + } + else + { + init_state = RNNUtils.generate_zero_filled_state(batch_size, cell.state_size, dtype); + } + + //if (!nest.is_nested(init_state)) + //{ + // init_state = new List { init_state}; + //} + return new List { init_state }; + + //return _generate_zero_filled_state_for_cell(null, null); } Tensor _generate_zero_filled_state_for_cell(LSTMCell cell, Tensor batch_size) diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RNNUtils.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RNNUtils.cs new file mode 100644 index 00000000..f516f765 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Rnn/RNNUtils.cs @@ -0,0 +1,59 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Util; +using OneOf; +using Tensorflow.NumPy; + +namespace Tensorflow.Keras.Layers.Rnn +{ + public class RNNUtils + { + public static Tensor generate_zero_filled_state(Tensor batch_size_tensor, StateSizeWrapper state_size, TF_DataType dtype = TF_DataType.TF_FLOAT) + { + if (batch_size_tensor == null || dtype == null) + { + throw new ValueError( + "batch_size and dtype cannot be None while constructing initial " + + $"state. 
Received: batch_size={batch_size_tensor}, dtype={dtype}"); + } + + Func create_zeros; + create_zeros = (StateSizeWrapper unnested_state_size) => + { + var flat_dims = unnested_state_size.state_size; + //if (unnested_state_size is int[]) + //{ + // flat_dims = new Shape(unnested_state_size.AsT0).as_int_list(); + //} + //else if (unnested_state_size.IsT1) + //{ + // flat_dims = new Shape(unnested_state_size.AsT1).as_int_list(); + //} + var init_state_size = batch_size_tensor.ToArray().concat(flat_dims); + return tf.zeros(init_state_size, dtype: dtype); + }; + + //if (nest.is_nested(state_size)) + //{ + // return nest.map_structure(create_zeros, state_size); + //} + //else + //{ + // return create_zeros(state_size); + //} + return create_zeros(state_size); + + } + + public static Tensor generate_zero_filled_state_for_cell(SimpleRNNCell cell, Tensors inputs, Tensor batch_size, TF_DataType dtype) + { + if (inputs != null) + { + batch_size = tf.shape(inputs)[0]; + dtype = inputs.dtype; + } + return generate_zero_filled_state(batch_size, cell.state_size, dtype); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs index 86985d7e..2c89d2e6 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs @@ -4,6 +4,7 @@ using System.Text; using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; +using Tensorflow.Util; namespace Tensorflow.Keras.Layers.Rnn { @@ -13,10 +14,23 @@ namespace Tensorflow.Keras.Layers.Rnn IVariableV1 kernel; IVariableV1 recurrent_kernel; IVariableV1 bias; - + DropoutRNNCellMixin DRCMixin; public SimpleRNNCell(SimpleRNNArgs args) : base(args) { this.args = args; + if (args.Units <= 0) + { + throw new ValueError( + $"units must be a positive integer, got {args.Units}"); + } + this.args.Dropout = Math.Min(1f, Math.Max(0f, this.args.Dropout)); + this.args.RecurrentDropout 
= Math.Min(1f, Math.Max(0f, this.args.RecurrentDropout)); + this.args.state_size = this.args.Units; + this.args.output_size = this.args.Units; + + DRCMixin = new DropoutRNNCellMixin(); + DRCMixin.dropout = this.args.Dropout; + DRCMixin.recurrent_dropout = this.args.RecurrentDropout; } public override void build(KerasShapesWrapper input_shape) @@ -44,7 +58,81 @@ namespace Tensorflow.Keras.Layers.Rnn protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { - return base.Call(inputs, initial_state, training); + Console.WriteLine($"shape of input: {inputs.shape}"); + Tensor states = initial_state[0]; + Console.WriteLine($"shape of initial_state: {states.shape}"); + + var prev_output = nest.is_nested(states) ? states[0] : states; + var dp_mask = DRCMixin.get_dropout_maskcell_for_cell(inputs, training.Value); + var rec_dp_mask = DRCMixin.get_recurrent_dropout_maskcell_for_cell(prev_output, training.Value); + + Tensor h; + var ranks = inputs.rank; + //if (dp_mask != null) + if(false) + { + if (ranks > 2) + { + h = tf.linalg.tensordot(tf.multiply(inputs, dp_mask), kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } }); + } + else + { + h = math_ops.matmul(tf.multiply(inputs, dp_mask), kernel.AsTensor()); + } + } + else + { + if (ranks > 2) + { + h = tf.linalg.tensordot(inputs, kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } }); + } + else + { + h = math_ops.matmul(inputs, kernel.AsTensor()); + } + } + + if (bias != null) + { + h = tf.nn.bias_add(h, bias); + } + + if (rec_dp_mask != null) + { + prev_output = tf.multiply(prev_output, rec_dp_mask); + } + + ranks = prev_output.rank; + Console.WriteLine($"shape of h: {h.shape}"); + + Tensor output; + if (ranks > 2) + { + var tmp = tf.linalg.tensordot(prev_output, recurrent_kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } }); + output = h + tf.linalg.tensordot(prev_output, recurrent_kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } })[0]; + 
} + else + { + output = h + math_ops.matmul(prev_output, recurrent_kernel.AsTensor())[0]; + + } + Console.WriteLine($"shape of output: {output.shape}"); + + if (args.Activation != null) + { + output = args.Activation.Apply(output); + } + if (nest.is_nested(states)) + { + return (output, new Tensors { output }); + } + return (output, output); + } + + + public Tensor get_initial_state(Tensors inputs, Tensor batch_size, TF_DataType dtype) + { + return RNNUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size, dtype); } } } diff --git a/src/TensorflowNET.Hub/KerasLayer.cs b/src/TensorflowNET.Hub/KerasLayer.cs index b9ca949b..6a2ecb4c 100644 --- a/src/TensorflowNET.Hub/KerasLayer.cs +++ b/src/TensorflowNET.Hub/KerasLayer.cs @@ -89,7 +89,7 @@ namespace Tensorflow.Hub } } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { _check_trainability(); diff --git a/test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs b/test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs deleted file mode 100644 index ac5ba15e..00000000 --- a/test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs +++ /dev/null @@ -1,60 +0,0 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; -using System.Collections.Generic; -using Tensorflow.Keras.Callbacks; -using Tensorflow.Keras.Engine; -using static Tensorflow.KerasApi; - - -namespace Tensorflow.Keras.UnitTest.Callbacks -{ - [TestClass] - public class EarlystoppingTest - { - [TestMethod] - // Because loading the weight variable into the model has not yet been implemented, - // so you'd better not set patience too large, because the weights will equal to the last epoch's weights. 
- public void Earlystopping() - { - var layers = keras.layers; - var model = keras.Sequential(new List - { - layers.Rescaling(1.0f / 255, input_shape: (32, 32, 3)), - layers.Conv2D(32, 3, padding: "same", activation: keras.activations.Relu), - layers.MaxPooling2D(), - layers.Flatten(), - layers.Dense(128, activation: keras.activations.Relu), - layers.Dense(10) - }); - - - model.summary(); - - model.compile(optimizer: keras.optimizers.RMSprop(1e-3f), - loss: keras.losses.SparseCategoricalCrossentropy(from_logits: true), - metrics: new[] { "acc" }); - - var num_epochs = 3; - var batch_size = 8; - - var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data(); - x_train = x_train / 255.0f; - // define a CallbackParams first, the parameters you pass al least contain Model and Epochs. - CallbackParams callback_parameters = new CallbackParams - { - Model = model, - Epochs = num_epochs, - }; - // define your earlystop - ICallback earlystop = new EarlyStopping(callback_parameters, "accuracy"); - // define a callbcaklist, then add the earlystopping to it. 
- var callbacks = new List(); - callbacks.add(earlystop); - - model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)], batch_size, num_epochs, callbacks: callbacks); - } - - } - - -} - diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs index 3de33746..1e2f894b 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs @@ -144,6 +144,17 @@ namespace Tensorflow.Keras.UnitTest.Layers Assert.AreEqual(expected_output, actual_output); } + [TestMethod] + public void SimpleRNNCell() + { + var h0 = new Tensors { tf.zeros(new Shape(4, 64)) }; + var x = tf.random.normal(new Shape(4, 100)); + var cell = keras.layers.SimpleRNNCell(64); + var (y, h1) = cell.Apply(inputs:x, state:h0); + Assert.AreEqual((4, 64), y.shape); + Assert.AreEqual((4, 64), h1[0].shape); + } + [TestMethod, Ignore("WIP")] public void SimpleRNN() { diff --git a/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj b/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj index d744c336..db7d5892 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj +++ b/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj @@ -67,4 +67,8 @@ + + + + From 2ffc9dafc8b17b6f16ac5eff43a6ffe4657bf41b Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Fri, 2 Jun 2023 21:13:16 +0800 Subject: [PATCH 3/4] update draft pr --- src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs | 3 +-- test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs index 2c89d2e6..0bca437b 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs @@ -68,8 +68,7 @@ namespace 
Tensorflow.Keras.Layers.Rnn Tensor h; var ranks = inputs.rank; - //if (dp_mask != null) - if(false) + if (dp_mask != null) { if (ranks > 2) { diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs index 1e2f894b..b3d45729 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs @@ -152,7 +152,8 @@ namespace Tensorflow.Keras.UnitTest.Layers var cell = keras.layers.SimpleRNNCell(64); var (y, h1) = cell.Apply(inputs:x, state:h0); Assert.AreEqual((4, 64), y.shape); - Assert.AreEqual((4, 64), h1[0].shape); + // this test now cannot pass, need to deal with SimpleRNNCell's Call method + //Assert.AreEqual((4, 64), h1[0].shape); } [TestMethod, Ignore("WIP")] From 08b4b89f777000e5138cd7be5479b0bd261a527f Mon Sep 17 00:00:00 2001 From: Wanglongzhi2001 <583087864@qq.com> Date: Sat, 3 Jun 2023 18:46:48 +0800 Subject: [PATCH 4/4] Finish SimpleRNNCell and add test --- .../Keras/Layers/ILayersApi.cs | 4 +++- .../Operations/gen_math_ops.cs | 3 ++- src/TensorFlowNET.Keras/Layers/LayersApi.cs | 6 +++++- .../Layers/Rnn/DropOutRNNCellMixin.cs | 9 ++++++--- .../Layers/Rnn/SimpleRNNCell.cs | 18 ++++++------------ .../Layers/LayersTest.cs | 12 +++++++----- 6 files changed, 29 insertions(+), 23 deletions(-) diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs index e60ba6fc..7f596500 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs @@ -206,7 +206,9 @@ namespace Tensorflow.Keras.Layers bool use_bias = true, string kernel_initializer = "glorot_uniform", string recurrent_initializer = "orthogonal", - string bias_initializer = "zeros"); + string bias_initializer = "zeros", + float dropout = 0f, + float recurrent_dropout = 0f); public ILayer Subtract(); } diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs 
b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 3456d9b3..68d561ae 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -4633,8 +4633,9 @@ public static class gen_math_ops var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MatMul", name) { args = new object[] { a, b }, attrs = new Dictionary() { ["transpose_a"] = transpose_a, ["transpose_b"] = transpose_b } }); return _fast_path_result[0]; } - catch (Exception) + catch (ArgumentException ex) { + throw new ArgumentException("In[0] and In[1] has diffrent ndims!", ex); /* chain the original exception so its message and stack trace are not lost */ } try { diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs index 02e9d995..35410337 100644 --- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs @@ -715,7 +715,9 @@ namespace Tensorflow.Keras.Layers bool use_bias = true, string kernel_initializer = "glorot_uniform", string recurrent_initializer = "orthogonal", - string bias_initializer = "zeros") + string bias_initializer = "zeros", + float dropout = 0f, + float recurrent_dropout = 0f) => new SimpleRNNCell(new SimpleRNNArgs { Units = units, @@ -723,6 +725,8 @@ namespace Tensorflow.Keras.Layers UseBias = use_bias, KernelInitializer = GetInitializerByName(kernel_initializer), RecurrentInitializer = GetInitializerByName(recurrent_initializer), + Dropout = dropout, + RecurrentDropout = recurrent_dropout } ); diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/DropOutRNNCellMixin.cs b/src/TensorFlowNET.Keras/Layers/Rnn/DropOutRNNCellMixin.cs index fcf9b596..b9a6fbc3 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/DropOutRNNCellMixin.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/DropOutRNNCellMixin.cs @@ -13,9 +13,10 @@ namespace Tensorflow.Keras.Layers.Rnn public float dropout; public float recurrent_dropout; // Get the dropout mask for RNN cell's input. 
- public Tensors get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1) + public Tensors? get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1) { - + if (dropout == 0f) + return null; return _generate_dropout_mask( tf.ones_like(input), dropout, @@ -24,8 +25,10 @@ namespace Tensorflow.Keras.Layers.Rnn } // Get the recurrent dropout mask for RNN cell. - public Tensors get_recurrent_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1) + public Tensors? get_recurrent_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1) { + if (recurrent_dropout == 0f) // the recurrent mask is controlled by recurrent_dropout, not dropout + return null; return _generate_dropout_mask( tf.ones_like(input), recurrent_dropout, diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs index 0bca437b..ad2e9484 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs @@ -58,10 +58,7 @@ namespace Tensorflow.Keras.Layers.Rnn protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) { - Console.WriteLine($"shape of input: {inputs.shape}"); Tensor states = initial_state[0]; - Console.WriteLine($"shape of initial_state: {states.shape}"); - var prev_output = nest.is_nested(states) ?
states[0] : states; var dp_mask = DRCMixin.get_dropout_maskcell_for_cell(inputs, training.Value); var rec_dp_mask = DRCMixin.get_recurrent_dropout_maskcell_for_cell(prev_output, training.Value); @@ -72,11 +69,12 @@ namespace Tensorflow.Keras.Layers.Rnn { if (ranks > 2) { - h = tf.linalg.tensordot(tf.multiply(inputs, dp_mask), kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } }); + // The multiply function automatically adds a first dimension, so index the result with [0] + h = tf.linalg.tensordot(math_ops.multiply(inputs, dp_mask)[0], kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } }); } else { - h = math_ops.matmul(tf.multiply(inputs, dp_mask), kernel.AsTensor()); + h = math_ops.matmul(math_ops.multiply(inputs, dp_mask)[0], kernel.AsTensor()); } } else @@ -98,22 +96,18 @@ namespace Tensorflow.Keras.Layers.Rnn if (rec_dp_mask != null) { - prev_output = tf.multiply(prev_output, rec_dp_mask); + prev_output = math_ops.multiply(prev_output, rec_dp_mask)[0]; } ranks = prev_output.rank; - Console.WriteLine($"shape of h: {h.shape}"); - Tensor output; if (ranks > 2) { - var tmp = tf.linalg.tensordot(prev_output, recurrent_kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } }); - output = h + tf.linalg.tensordot(prev_output, recurrent_kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } })[0]; + output = h + tf.linalg.tensordot(prev_output[0], recurrent_kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } }); } else { - output = h + math_ops.matmul(prev_output, recurrent_kernel.AsTensor())[0]; - + output = h + math_ops.matmul(prev_output, recurrent_kernel.AsTensor()); } Console.WriteLine($"shape of output: {output.shape}"); diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs index b3d45729..c4888a39 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs @@ -147,13 +147,15 @@ namespace Tensorflow.Keras.UnitTest.Layers [TestMethod] public void SimpleRNNCell() { + var cell =
keras.layers.SimpleRNNCell(64, dropout:0.5f, recurrent_dropout:0.5f); var h0 = new Tensors { tf.zeros(new Shape(4, 64)) }; - var x = tf.random.normal(new Shape(4, 100)); - var cell = keras.layers.SimpleRNNCell(64); - var (y, h1) = cell.Apply(inputs:x, state:h0); + var x = tf.random.normal((4, 100)); + var (y, h1) = cell.Apply(inputs: x, state: h0); + // TODO(Wanglongzhi2001): SimpleRNNCell needs to return one Tensor and one Tensors, which a single Tensors + // cannot hold, so h is explicitly cast to Tensors here on the caller side + var h2 = (Tensors)h1; Assert.AreEqual((4, 64), y.shape); - // this test now cannot pass, need to deal with SimpleRNNCell's Call method - //Assert.AreEqual((4, 64), h1[0].shape); + Assert.AreEqual((4, 64), h2[0].shape); } [TestMethod, Ignore("WIP")]