diff --git a/src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs b/src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs index df6222cd..d12ed1ad 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs @@ -9,11 +9,11 @@ namespace Tensorflow.Keras.Layers.Rnn { GeneralizedTensorShape StateSize { get; } GeneralizedTensorShape OutputSize { get; } + bool IsTFRnnCell { get; } /// /// Whether the optional RNN args are supported when appying the layer. /// In other words, whether `Apply` is overwrited with process of `RnnOptionalArgs`. /// bool SupportOptionalArgs { get; } - (Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null); } } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index 71fdc301..26646b76 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -183,6 +183,7 @@ namespace Tensorflow } public GeneralizedTensorShape StateSize => throw new NotImplementedException(); public GeneralizedTensorShape OutputSize => throw new NotImplementedException(); + public bool IsTFRnnCell => throw new NotImplementedException(); public bool SupportOptionalArgs => throw new NotImplementedException(); } } diff --git a/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs b/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs index cf1b50af..ed65a08d 100644 --- a/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs +++ b/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs @@ -17,6 +17,7 @@ using System; using System.Collections.Generic; using System.Linq; +using Tensorflow.Eager; using Tensorflow.Framework; using static Tensorflow.Binding; @@ -48,6 +49,7 @@ namespace Tensorflow.Operations public override Tensor flow => _flow; bool _clear_after_read; List _tensor_array; + List _previous_read_indices; public _EagerTensorArray(TF_DataType dtype, Tensor size, bool dynamic_size = false, bool clear_after_read = true, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, @@ -61,16 +63,20 @@ namespace Tensorflow.Operations _dtype = dtype.as_base_dtype(); _dynamic_size = dynamic_size; _clear_after_read = clear_after_read; - _tensor_array = new List(); + _tensor_array = Enumerable.Repeat(null, size.numpy()).ToList(); + _previous_read_indices = new(); } public override TensorArray unstack(Tensor value, string name = null) { - return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _handle, value }), delegate + var tensors = array_ops.unstack(value, name: name); + if(tensors.Length > _tensor_array.Count && !_dynamic_size) { - var num_elements = array_ops.shape(value)[0]; - return scatter(indices: math_ops.range(0, num_elements), value: value, name: name); - }); + throw new ValueError($"Cannot unstack {tensors.Length} tensors into a TensorArray of static size {_tensor_array.Count}"); + } + _tensor_array = tensors.ToList(); + // TODO(Rinne): revise the implementation. Here we should return `parent()`. + return this; } public TensorArray scatter(Tensor indices, Tensor value, string name = null) @@ -116,9 +122,19 @@ namespace Tensorflow.Operations _colocate_with.Add(value); } + private Tensor _maybe_zero(int ix) + { + var val = _tensor_array[ix]; + if(val is null) + { + val = _tensor_array[ix] = array_ops.zeros(_element_shape, _dtype); + } + return val; + } + public override Tensor read(T index, string name = null) { - int index_int = -1; + int index_int; if (index is int int_index) index_int = int_index; else if (index is Tensor tensor_index) @@ -126,27 +142,75 @@ namespace Tensorflow.Operations else throw new ValueError(""); + if(index_int >= _tensor_array.Count) + { + throw new OutOfRangeError($"Tried to read from index {index_int} but array size is: {_tensor_array.Count} "); + } + + var res = _tensor_array[index_int]; + if(res is null) + { + if (_previous_read_indices.Contains(index_int)) + { + throw new InvalidArgumentError($"Could not read index {index_int} twice because it was cleared after " + + $"a previous read (perhaps try setting clear_after_read = false?)"); + } + else + { + res = _maybe_zero(index_int); + } + } + if (_clear_after_read) { _tensor_array[index_int] = null; + _previous_read_indices.Add(index_int); } - - return _tensor_array[index_int]; + return res; } public override TensorArray write(Tensor index, Tensor value, string name = null) { - if (_infer_shape) - _element_shape = _element_shape.merge_with(value.shape); - _tensor_array.add(value); - return this; + int index_int; + if(index is EagerTensor eager) + { + return write(eager.numpy(), value, name); + } + throw new InvalidArgumentError("The index is supposed to be an EagerTensor"); } public override TensorArray write(int index, T value, string name = null) { - var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); - var index_tensor = ops.convert_to_tensor(index, name: "index"); - return write(index_tensor, value_tensor, name: name); + int size = _tensor_array.Count; + if(index >= size) + { + if (!_dynamic_size) + { + throw new OutOfRangeError($"Tried to write to index {index} but array is not resizeable and size " + + $"is: {size} "); + } + _tensor_array.AddRange(Enumerable.Repeat(null, index - size + 1)); + } + + Tensor tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); + + if(_dtype != tensor.dtype) + { + throw new InvalidArgumentError($"TensorArray dtype is {_dtype.as_python_name()} but Op is " + + $"trying to write dtype {tensor.dtype.as_python_name()} "); + } + + if (!_element_shape.is_compatible_with(tensor.shape)) + { + throw new ValueError($"Incompatible shape for value ({tensor.shape}), expected ({_element_shape})"); + } + + if (_infer_shape) + { + _element_shape = _element_shape.merge_with(tensor.shape); + } + _tensor_array[index] = tensor; + return this; } private Tensor size(string name = null) @@ -156,11 +220,26 @@ namespace Tensorflow.Operations public override Tensor stack(string name = null) { - ops.colocate_with(_handle); - return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate + if(_tensor_array.Count > 0) { - return gather(math_ops.range(0, size()), name: name); - }); + for(int i = 0; i < _tensor_array.Count; i++) + { + _maybe_zero(i); + } + } + if(_tensor_array.Count == 0 && _element_shape.IsFullyDefined) + { + return ops.convert_to_tensor(new Shape(new long[] { 0 }.Concat(_element_shape.dims).ToArray()), name: name, dtype: _dtype); + } + else + { + return ops.convert_to_tensor(_tensor_array, name: name, dtype: _dtype); + } + //ops.colocate_with(_handle); + //return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate + //{ + // return gather(math_ops.range(0, size()), name: name); + //}); } public override Tensor gather(Tensor indices, string name = null) diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs index a7c1bcad..30b73e82 100644 --- a/src/TensorFlowNET.Keras/BackendImpl.cs +++ b/src/TensorFlowNET.Keras/BackendImpl.cs @@ -20,9 +20,11 @@ using System.Linq; using System.Collections.Generic; using Tensorflow.Functions; using Tensorflow.Graphs; +using Tensorflow.Common.Extensions; using static Tensorflow.Binding; using static Tensorflow.Graphs.SubGraphUtility; using Tensorflow.Util; +using Tensorflow.Common.Types; namespace Tensorflow.Keras { @@ -452,7 +454,7 @@ namespace Tensorflow.Keras return x; } - public static (Tensors, Tensors, Tensors) rnn( + public (Tensors, Tensors, Tensors) rnn( Func step_function, // args:inputs, states, return:output, new_states Tensors inputs, // inputs is a tuple of tensors (one per input sequence) Tensors initial_states, @@ -466,7 +468,7 @@ namespace Tensorflow.Keras bool return_all_outputs = true) { - Tensors swap_batch_timestep(Tensors input_t) + Tensor swap_batch_timestep(Tensor input_t) { var axes = Enumerable.Range(0, input_t.rank).ToArray(); axes[0] = 1; @@ -476,13 +478,14 @@ namespace Tensorflow.Keras if (!time_major) { - inputs = nest.map_structure(swap_batch_timestep, inputs); + inputs = Nest.MapStructure(swap_batch_timestep, inputs).ToTensors(); } - var flatted_inptus = nest.flatten(inputs); - var time_steps = flatted_inptus[0].shape[0]; - var batch = flatted_inptus[0].shape[1]; - var time_step_t = tf.shape(flatted_inptus[0])[0]; + var flatted_inptus = Nest.Flatten(inputs).ToList(); + var first_flatted_input = flatted_inptus[0]; + var time_steps = first_flatted_input.shape[0]; + var batch = first_flatted_input.shape[1]; + var time_steps_t = (int)first_flatted_input.shape[0]; foreach (var input_ in flatted_inptus) { @@ -508,11 +511,6 @@ namespace Tensorflow.Keras } - if (constants == null) - { - constants = new List(); - } - // tf.where needs its condition tensor to be the same shape as its two // result tensors, but in our case the condition (mask) tensor is // (nsamples, 1), and inputs are (nsamples, ndimensions) or even more. @@ -522,12 +520,12 @@ namespace Tensorflow.Keras Tensors _expand_mask(Tensors mask_t, Tensors input_t, int fixed_dim = 1) { - if (nest.is_nested(mask_t)) + if (!mask_t.IsSingle()) { throw new ValueError($"mask_t is expected to be tensor, but got {mask_t}"); } - if (nest.is_nested(input_t)) + if (!input_t.IsSingle()) { throw new ValueError($"input_t is expected to be tensor, but got {input_t}"); } @@ -575,21 +573,21 @@ namespace Tensorflow.Keras - Tensors _process_single_input_t(Tensors input_t) + Tensors _process_single_input_t(Tensor input_t) { - input_t = tf.unstack(input_t); // unstack for time_step dim + var unstaked_input_t = array_ops.unstack(input_t); // unstack for time_step dim if (go_backwards) { - input_t.Reverse(); + unstaked_input_t = unstaked_input_t.Reverse().ToArray(); } - return input_t; + return unstaked_input_t; } // TODO(Wanglongzhi2001) Tensors processed_input; - if (nest.is_nested(inputs)) + if (!inputs.IsSingle()) { - processed_input = nest.map_structure(_process_single_input_t, inputs); + processed_input = inputs.MapStructure(_process_single_input_t).ReduceTo().ToTensors(); } else { @@ -603,334 +601,339 @@ namespace Tensorflow.Keras { inp.Add(t_[time]); } - return nest.pack_sequence_as(inputs, inp); + return Nest.PackSequenceAs(inputs, inp); } - //if (mask != null) - //{ - // var mask_list = tf.unstack(mask); - // if (go_backwards) - // { - // mask_list.Reverse(); - // } - - // for (int i = 0; i < time_steps; i++) - // { - // // TODO(Wanglongzhi2001),deal with _get_input_tensor - // var inp = _get_input_tensor(i); - // var mask_t = mask_list[i]; - // // TODO - // var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants }); - - // var tiled_mask_t = _expand_mask(mask_t, output); - - // Tensors prev_output; - // if (successive_outputs == null) - // { - // prev_output = tf.zeros_like(output); - // } - // else - // { - // prev_output = successive_outputs[successive_outputs.Length - 1]; - // } - - // output = tf.where(tiled_mask_t, output, prev_output); - - // //var flat_states = nest.flatten(states); - // //var flat_new_states = nest.flatten(newStates); - // var flat_states = states.ToList(); - // var flat_new_states = newStates.ToList(); - - // var tiledMaskT = flat_states - // .Select(s => _expand_mask(mask_t, s)) - // .ToArray(); - // var tuple = Tuple.Create(tiledMaskT); - - // List flat_final_states = new List(); - // foreach (var (m, s, ps) in Enumerable.Zip(tiled_mask_t, flat_new_states, flat_states)) - // { - // flat_final_states.Add(tf.where(m, s, ps)); - // } - - // states = (Tensors)nest.pack_sequence_as(states, flat_final_states); - // if (return_all_outputs) - // { - // successive_outputs.Add(output); - // successive_states.Add(states); - // } - // else - // { - // successive_outputs = new Tensors { output }; - // successive_states = new Tensors { states }; - // } - - // } - // last_output = successive_outputs[successive_outputs.Length - 1]; - // new_states = successive_states[successive_states.Length - 1]; - // outputs = tf.stack(successive_outputs); - - // if (zero_output_for_mask) - // { - // last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output)); - // outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs)); - // } - // else // mask is null - // { - // for (int i = 0; i < time_steps; i++) - // { - // var inp = _get_input_tensor(i); - // var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants }); - // states = newStates; - - // if (return_all_outputs) - // { - // successive_outputs.Add(output); - // successive_states.Add(newStates); - // } - // else - // { - // successive_outputs = new Tensors { output }; - // successive_states = new Tensors { newStates }; - // } - // } - // last_output = successive_outputs[successive_outputs.Length - 1]; - // new_states = successive_states[successive_states.Length - 1]; - // outputs = tf.stack(successive_outputs); - // } - //} + if (mask != null) + { + var mask_list = tf.unstack(mask); + if (go_backwards) + { + mask_list.Reverse(); + } + + for (int i = 0; i < time_steps; i++) + { + // TODO(Wanglongzhi2001),deal with _get_input_tensor + var inp = _get_input_tensor(i); + var mask_t = mask_list[i]; + // TODO + var (output, newStates) = step_function((Tensors)inp, states.MergeWith(constants)); + + var tiled_mask_t = _expand_mask(mask_t, output); + + Tensors prev_output; + if (successive_outputs == null) + { + prev_output = tf.zeros_like(output); + } + else + { + prev_output = successive_outputs[successive_outputs.Length - 1]; + } + + output = tf.where(tiled_mask_t, output, prev_output); + + var flat_states = Nest.Flatten(states).ToList(); + var flat_new_states = Nest.Flatten(newStates).ToList(); + + var tiledMaskT = flat_states + .Select(s => _expand_mask(mask_t, s)) + .ToArray(); + var tuple = Tuple.Create(tiledMaskT); + + List flat_final_states = new List(); + foreach (var (m, s, ps) in zip(tiled_mask_t.ToList(), flat_new_states, flat_states)) + { + flat_final_states.Add(tf.where(m, s, ps)); + } + + states = Nest.PackSequenceAs(states, flat_final_states).ToTensors(); + if (return_all_outputs) + { + successive_outputs.Add(output); + successive_states.Add(states); + } + else + { + successive_outputs = new Tensors { output }; + successive_states = new Tensors { states }; + } + + } + last_output = successive_outputs[successive_outputs.Length - 1]; + new_states = successive_states[successive_states.Length - 1]; + outputs = tf.stack(successive_outputs); + + if (zero_output_for_mask) + { + last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output)); + outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs)); + } + else // mask is null + { + for (int i = 0; i < time_steps; i++) + { + var inp = _get_input_tensor(i); + var (output, newStates) = step_function((Tensors)inp, states.MergeWith(constants)); + states = newStates; + + if (return_all_outputs) + { + successive_outputs.Add(output); + successive_states.Add(newStates); + } + else + { + successive_outputs = new Tensors { output }; + successive_states = new Tensors { newStates }; + } + } + last_output = successive_outputs[successive_outputs.Length - 1]; + new_states = successive_states[successive_states.Length - 1]; + outputs = tf.stack(successive_outputs); + } + } + } + else // unroll == false + { + var states = initial_states; + // Create input tensor array, if the inputs is nested tensors, then it + // will be flattened first, and tensor array will be created one per + // flattened tensor. + var input_ta = new List(); + for (int i = 0; i < flatted_inptus.Count; i++) + { + input_ta.Add(tf.TensorArray(dtype: flatted_inptus[i].dtype, size: time_steps_t)); + } + + foreach(var (ta, input_) in zip(input_ta, flatted_inptus)) + { + if (!go_backwards) + { + ta.unstack(input_); + } + else + { + ta.unstack(reverse(input_, 0)); + } + } + + // Get the time(0) input and compute the output for that, the output will + // be used to determine the dtype of output tensor array. Don't read from + // input_ta due to TensorArray clear_after_read default to True. + var inps = new Tensors(); + foreach (var inp in flatted_inptus) + { + inps.Add(inp[0]); + } + var input_time_zero = Nest.PackSequenceAs(inputs, inps).ToTensors(); + + // output_time_zero is used to determine the cell output shape and its + // dtype. the value is discarded. + (output_time_zero, _) = step_function((Tensor)input_time_zero, + constants is null ? initial_states : initial_states.MergeWith(constants)); + + int output_ta_size = return_all_outputs ? time_steps_t : 1; + var output_ta = new List(); + for (int i = 0; i < output_time_zero.ToList().Count; i++) + { + var Out = output_time_zero.ToList()[i]; + output_ta.Add(tf.TensorArray(dtype: Out.dtype, size: output_ta_size, element_shape: Out.shape)); + } + + var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time"); + + + + Func? masking_fn; + Func? compute_masked_output = null; + if (mask != null) + { + if (go_backwards) + { + mask = tf.reverse(mask, axis: new[] { 0 }); + } + var mask_ta = tf.TensorArray(dtype: TF_DataType.TF_BOOL, size: time_steps_t); + mask_ta = mask_ta.unstack(mask); + + masking_fn = (time) => + { + return mask_ta.read(time); + }; + + compute_masked_output = (mask_t, flat_out, flat_mask) => + { + var tiled_mask_t = new Tensors(); + foreach (var o in flat_out) + { + tiled_mask_t.Add(_expand_mask(mask_t, o, fixed_dim: mask_t.rank)); + } + + Tensors res = new Tensors(); + foreach (var (m, o, fm) in zip(tiled_mask_t.ToList(), flat_out.ToList(), flat_mask.ToList())) + { + res.Add(tf.where(m, o, fm)); + } + return res; + }; + } + // TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor)? + else if (input_length is Tensor) + { + if (go_backwards) + { + var max_len = tf.reduce_max(input_length, axis: 0); + var rev_input_length = tf.subtract(max_len - 1, input_length); + + masking_fn = (time) => + { + return tf.less(rev_input_length, time); + }; + } + else + { + masking_fn = (time) => + { + return tf.greater(input_length, time); + }; + } + + compute_masked_output = (mask_t, flat_out, flat_mask) => + { + var res = new List(); + foreach (var (o, zo) in zip(flat_out, flat_mask)) + { + res.Add(tf.where(mask_t, o, zo)); + } + return res; + }; + } + else + { + masking_fn = null; + } + + Func cond = (time) => (time < time_steps_t); + int parallel_iterations = 32; + if (masking_fn != null) + { + // Mask for the T output will be base on the output of T - 1. In the + // case T = 0, a zero filled tensor will be used. + var flat_zero_output = new Tensors(); + foreach (var o in Nest.Flatten(output_time_zero)) + { + flat_zero_output.Add(tf.zeros_like(o)); + } + + var prev_output = flat_zero_output; + var output_ta_t = output_ta; + Tensor _step(Tensor time) + { + /* + RNN step function. + Args: + time: Current timestep value. + output_ta_t: TensorArray. + prev_output: tuple of outputs from time - 1. + *states: List of states. + Returns: + Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)` + */ + + var flat_current_input = input_ta.Select(x => x.read(time)).ToList(); + // maybe set shape + // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type + var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors(); + var mask_t = masking_fn(time); + var (output, new_states_internal) = step_function(current_input, states.MergeWith(constants)); + // mask output + var flat_output = Nest.Flatten(output).ToList(); + + var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.ToList(); + + // TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type + var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output); + + // mask states + var flat_state = states.ToList(); + var flat_new_state = new_states_internal.ToList(); + + foreach (var (state, new_state) in zip(flat_state, flat_new_state)) + { + if (new_state is Tensor) + { + new_state.shape = state.shape; + } + } + + var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state); + new_states_internal = Nest.PackSequenceAs(new_states, flat_final_state).ToTensors(); + + var ta_index_to_write = return_all_outputs ? time : tf.constant(0); + // TODO(Wanglongzhi2001),deal with zip output_ta_t + foreach (var (ta, Out) in zip(output_ta_t, flat_new_output)) + { + output_ta_t.Add(ta.write(ta_index_to_write, Out)); + } + + new_states_internal = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors(); + + output_ta = output_ta_t; + new_states = new_states_internal; + return time + 1; + + } + var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations); + } + else + { + var output_ta_t = output_ta; + new_states = states; + Tensor _step(Tensor time) + { + var flat_current_input = input_ta.Select(x => x.read(time)).ToList(); + // maybe set shape + // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type + var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors(); + var (output, new_states_internal) = step_function(current_input, new_states.MergeWith(constants)); + var flat_state = new_states.Flatten().ToList(); + var flat_new_state = new_states_internal.Flatten().ToList(); + foreach (var (state, new_state) in zip(flat_state, flat_new_state)) + { + if (new_state is Tensor) + { + new_state.shape = state.shape; + } + } + var flat_output = Nest.Flatten(output); + var ta_index_to_write = return_all_outputs ? time : tf.constant(0); + output_ta_t = zip(output_ta_t, flat_output).Select(item => + { + var (ta, out_) = item; + return ta.write(ta_index_to_write, out_); + }).ToList(); + + new_states_internal = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors(); + output_ta = output_ta_t; + new_states = new_states_internal; + return time + 1; + } + var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations); + } + //Tensors outputs = new Tensors(); + foreach (var o in output_ta) + { + outputs.Add(o.stack()); + } + foreach (var o in outputs) + { + last_output.Add(o[-1]); + } + outputs = Nest.PackSequenceAs(output_time_zero, outputs).ToTensors(); + last_output = Nest.PackSequenceAs(output_time_zero, last_output).ToTensors(); + } - //else // unroll == false - //{ - // var states = initial_states; - // // Create input tensor array, if the inputs is nested tensors, then it - // // will be flattened first, and tensor array will be created one per - // // flattened tensor. - // var input_ta = new List(); - // for (int i = 0; i < flatted_inptus.Count; i++) - // { - // input_ta.Add(tf.TensorArray(dtype: flatted_inptus[i].dtype, size: time_step_t)); - // } - - // // Get the time(0) input and compute the output for that, the output will - // // be used to determine the dtype of output tensor array. Don't read from - // // input_ta due to TensorArray clear_after_read default to True. - // var inps = new Tensors(); - // foreach (var inp in flatted_inptus) - // { - // inps.Add(inp[0]); - // } - // var input_time_zero = nest.pack_sequence_as(inputs, inps); - - // // output_time_zero is used to determine the cell output shape and its - // // dtype. the value is discarded. - // (output_time_zero, _) = step_function((Tensor)input_time_zero, new Tensors { initial_states, constants }); - - // var output_ta_size = return_all_outputs ? time_step_t : tf.constant(1); - // var output_ta = new List(); - // for (int i = 0; i < output_time_zero.ToList().Count; i++) - // { - // var Out = output_time_zero.ToList()[i]; - // output_ta.Add(tf.TensorArray(dtype: Out.dtype, size: output_ta_size, element_shape: Out.shape)); - // } - - // var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time"); - - - - // Func? masking_fn; - // Func? compute_masked_output = null; - // if (mask != null) - // { - // if (go_backwards) - // { - // mask = tf.reverse(mask, axis: new[] { 0 }); - // } - // var mask_ta = tf.TensorArray(dtype: TF_DataType.TF_BOOL, size: time_step_t); - // mask_ta = mask_ta.unstack(mask); - - // masking_fn = (time) => - // { - // return mask_ta.read(time); - // }; - - // compute_masked_output = (mask_t, flat_out, flat_mask) => - // { - // var tiled_mask_t = new Tensors(); - // foreach (var o in flat_out) - // { - // tiled_mask_t.Add(_expand_mask(mask_t, o, fixed_dim: mask_t.rank)); - // } - - // Tensors res = new Tensors(); - // foreach (var (m, o, fm) in Enumerable.Zip(tiled_mask_t, flat_out, flat_mask)) - // { - // res.Add(tf.where(m, o, fm)); - // } - // return res; - // }; - // } - // // TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor)? - // else if (input_length is Tensor) - // { - // if (go_backwards) - // { - // var max_len = tf.reduce_max(input_length, axis: 0); - // var rev_input_length = tf.subtract(max_len - 1, input_length); - - // masking_fn = (time) => - // { - // return tf.less(rev_input_length, time); - // }; - // } - // else - // { - // masking_fn = (time) => - // { - // return tf.greater(input_length, time); - // }; - // } - - // compute_masked_output = (mask_t, flat_out, flat_mask) => - // { - // var res = new List(); - // foreach (var (o, zo) in zip(flat_out, flat_mask)) - // { - // res.Add(tf.where(mask_t, o, zo)); - // } - // return res; - // }; - // } - // else - // { - // masking_fn = null; - // } - - - // if (masking_fn != null) - // { - // // Mask for the T output will be base on the output of T - 1. In the - // // case T = 0, a zero filled tensor will be used. - // var flat_zero_output = new Tensors(); - // foreach (var o in nest.flatten(output_time_zero)) - // { - // flat_zero_output.Add(tf.zeros_like(o)); - // } - - - // (Tensor, List, Tensors, Tensors) _step(Tensor time, List output_ta_t, Tensors prev_output, Tensors states) - // { - // /* - // RNN step function. - // Args: - // time: Current timestep value. - // output_ta_t: TensorArray. - // prev_output: tuple of outputs from time - 1. - // *states: List of states. - // Returns: - // Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)` - // */ - - // var current_input = input_ta.Select(x => x.read(time)).ToList(); - // // maybe set shape - // // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type - // current_input = (List)nest.pack_sequence_as(inputs, current_input); - // var mask_t = masking_fn(time); - // var (output, new_states) = step_function(current_input, new Tensors { states, constants }); - // // mask output - // //var flat_output = nest.flatten(output); - // var flat_output = output.ToList(); - - // var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.ToList(); - - // // TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type - // var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output); - - // // mask states - // var flat_state = states.ToList(); - // var flat_new_state = new_states.ToList(); - - // foreach (var (state, new_state) in zip(flat_state, flat_new_state)) - // { - // if (new_state is Tensor) - // { - // new_state.set_shape(state.shape); - // } - // } - - // var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state); - // new_states = (Tensors)nest.pack_sequence_as(new_states, flat_final_state); - - // var ta_index_to_write = return_all_outputs ? time : tf.constant(0); - // var Output_ta_t = new List(); - // // TODO(Wanglongzhi2001),deal with zip output_ta_t - // foreach (var (ta, Out) in zip(output_ta_t, flat_new_output)) - // { - // Output_ta_t.Add(ta.write(ta_index_to_write, Out)); - // } - - - - // //new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state); - - - // return (time + 1, Output_ta_t, flat_new_output, new_states); - - // } - // Func cond = (time) => (time < time_step_t); - - // var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, flat_zero_output, states)); - // new_states = final_outputs.Item4; - // output_ta = final_outputs.Item2; - - // } - // else - // { - // (Tensor, List, Tensors) _step(Tensor time, List output_ta_t, Tensors states) - // { - // var current_input = input_ta.Select(x => x.read(time)).ToList(); - // // maybe set shape - // // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type - // current_input = (List)nest.pack_sequence_as(inputs, current_input); - // var (output, new_states) = step_function(current_input, new Tensors { states, constants }); - // var flat_state = states.ToList(); - // var flat_new_state = new_states.ToList(); - // foreach (var (state, new_state) in zip(flat_state, flat_new_state)) - // { - // if (new_state is Tensor) - // { - // new_state.set_shape(state.shape); - // } - // } - // var flat_output = output.ToList(); - // var ta_index_to_write = return_all_outputs ? time : tf.constant(0); - // var Output_ta_t = new List(); - // foreach (var (ta, out_) in zip(output_ta_t, flat_output)) - // { - // Output_ta_t.Add(ta.write(ta_index_to_write, out_)); - // } - - // new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state); - // return (time + 1, Output_ta_t, new_states); - // } - // Func cond = (time) => (time < time_step_t); - // var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, states)); - // new_states = final_outputs.Item3; - // output_ta = final_outputs.Item2; - - // } - // //Tensors outputs = new Tensors(); - // foreach (var o in output_ta) - // { - // outputs.Add(o.stack()); - // } - // foreach (var o in outputs) - // { - // last_output.Add(o[-1]); - // } - // outputs = (Tensors)nest.pack_sequence_as(output_time_zero, outputs); - // last_output = (Tensors)nest.pack_sequence_as(output_time_zero, last_output); - - //} Func set_shape; set_shape = (output_) => @@ -947,18 +950,38 @@ namespace Tensorflow.Keras shape[0] = 1; } shape[1] = (int)batch; - output_.set_shape(new Tensor(shape)); + output_.shape = shape; } return output_; }; - var Outputs = (Tensors)nest.map_structure(set_shape, outputs); + outputs = Nest.MapStructure(set_shape, outputs).ToTensors(); if (!time_major) { - Outputs = nest.map_structure(swap_batch_timestep, outputs); + outputs = Nest.MapStructure(swap_batch_timestep, outputs).ToTensors(); + } + return (last_output, outputs, new_states); + + } + + public Tensor reverse(Tensor input, int axis) + { + return reverse(input, new int[] { axis }); + } + + public Tensor reverse(Tensor input, int[] axes) + { + return tf.reverse(input, axes); + } + + public Tensor maybe_convert_to_ragged(bool is_ragged_output, Tensor output, int nested_row_lengths, bool go_backwards = false) + { + if (!is_ragged_output) + { + return output; } - return (last_output, Outputs, new_states); + throw new NotImplementedException("Not implemented currently, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); } } } diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs index b014737f..ab4cef12 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs @@ -55,8 +55,8 @@ namespace Tensorflow.Keras.Layers.Rnn if (_states == null) { // CHECK(Rinne): check if this is correct. - var state = nest.map_structure(x => null, _cell.StateSize); - return new Tensors { state }; + var nested = _cell.StateSize.MapStructure(x => null); + _states = nested.AsNest().ToTensors(); } return _states; } @@ -230,7 +230,7 @@ namespace Tensorflow.Keras.Layers.Rnn Tensors? mask = rnn_optional_args?.Mask; //var (inputs_padded, row_length) = BackendImpl.convert_inputs_if_ragged(inputs); // 暂时先不接受ragged tensor - int? row_length = null; + int row_length = 0; // TODO(Rinne): support this param. bool is_ragged_input = false; _validate_args_if_ragged(is_ragged_input, mask); @@ -249,16 +249,16 @@ namespace Tensorflow.Keras.Layers.Rnn if (mask != null) { // Time step masks must be the same for each input. - mask = nest.flatten(mask)[0]; + mask = mask.Flatten().First(); } Shape input_shape; - if (nest.is_nested(inputs)) + if (!inputs.IsSingle()) { // In the case of nested input, use the first element for shape check // input_shape = nest.flatten(inputs)[0].shape; // TODO(Wanglongzhi2001) - input_shape = nest.flatten(inputs)[0].shape; + input_shape = inputs.Flatten().First().shape; } else { @@ -286,6 +286,7 @@ namespace Tensorflow.Keras.Layers.Rnn // cell_call_fn = (self.cell.__call__ if callable(self.cell) else self.cell.call) Func step; + bool is_tf_rnn_cell = _cell.IsTFRnnCell; if (constants is not null) { if (!_cell.SupportOptionalArgs) @@ -299,7 +300,8 @@ namespace Tensorflow.Keras.Layers.Rnn { constants = new Tensors(states.TakeLast(_num_constants)); states = new Tensors(states.SkipLast(_num_constants)); - var(output, new_states) = _cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants }); + states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states; + var (output, new_states) = _cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants }); // TODO(Wanglongzhi2001),should cell_call_fn's return value be Tensors, Tensors? return (output, new_states.Single); }; @@ -308,13 +310,13 @@ namespace Tensorflow.Keras.Layers.Rnn { step = (inputs, states) => { - // states = (states[0] if len(states) == 1 and is_tf_rnn_cell else states) + states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states; var (output, new_states) = _cell.Apply(inputs, states); return (output, new_states.Single); }; } - var (last_output, outputs, states) = BackendImpl.rnn(step, + var (last_output, outputs, states) = keras.backend.rnn(step, inputs, initial_state, constants: constants, @@ -334,8 +336,8 @@ namespace Tensorflow.Keras.Layers.Rnn Tensors output = new Tensors(); if (_args.ReturnSequences) { - throw new NotImplementedException("this argument havn't been developed."); - + // TODO(Rinne): add go_backwards parameter and revise the `row_length` param + output = keras.backend.maybe_convert_to_ragged(is_ragged_input, outputs, row_length, false); } else { diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RnnCellBase.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RnnCellBase.cs index fcb5d1eb..751312e5 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/RnnCellBase.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/RnnCellBase.cs @@ -14,8 +14,8 @@ namespace Tensorflow.Keras.Layers.Rnn public RnnCellBase(LayerArgs args) : base(args) { } public abstract GeneralizedTensorShape StateSize { get; } public abstract GeneralizedTensorShape OutputSize { get; } + public abstract bool IsTFRnnCell { get; } public abstract bool SupportOptionalArgs { get; } - public abstract (Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null); public virtual Tensors GetInitialState(Tensors inputs, long batch_size, TF_DataType dtype) { return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size, dtype); diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs index abb57d8a..f0b2ed4d 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs @@ -5,6 +5,7 @@ using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; using Tensorflow.Common.Types; +using Tensorflow.Common.Extensions; namespace Tensorflow.Keras.Layers.Rnn { @@ -26,6 +27,7 @@ namespace Tensorflow.Keras.Layers.Rnn public override GeneralizedTensorShape StateSize => _state_size; public override GeneralizedTensorShape OutputSize => _output_size; + public override bool IsTFRnnCell => true; public override bool SupportOptionalArgs => false; public SimpleRNNCell(SimpleRNNCellArgs args) : base(args) @@ -66,37 +68,22 @@ namespace Tensorflow.Keras.Layers.Rnn built = true; } - public override (Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null) + // TODO(Rinne): revise the trining param (with refactoring of the framework) + protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null) { // TODO(Rinne): check if it will have multiple tensors when not nested. - Tensor prev_output = states[0]; + Tensors prev_output = Nest.IsNested(states) ? new Tensors(states[0]) : states; var dp_mask = get_dropout_maskcell_for_cell(inputs, training.Value); var rec_dp_mask = get_recurrent_dropout_maskcell_for_cell(prev_output, training.Value); Tensor h; - var ranks = inputs.rank; if (dp_mask != null) { - if (ranks > 2) - { - // 因为multiply函数会自动添加第一个维度,所以加上下标0 - h = tf.linalg.tensordot(math_ops.multiply(inputs, dp_mask)[0], _kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } }); - } - else - { - h = math_ops.matmul(math_ops.multiply(inputs, dp_mask)[0], _kernel.AsTensor()); - } + h = math_ops.matmul(math_ops.multiply(inputs.Single, dp_mask.Single), _kernel.AsTensor()); } else { - if (ranks > 2) - { - h = tf.linalg.tensordot(inputs, _kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } }); - } - else - { - h = math_ops.matmul(inputs, _kernel.AsTensor()); - } + h = math_ops.matmul(inputs, _kernel.AsTensor()); } if (_bias != null) @@ -106,26 +93,25 @@ namespace Tensorflow.Keras.Layers.Rnn if (rec_dp_mask != null) { - prev_output = math_ops.multiply(prev_output, rec_dp_mask)[0]; + prev_output = math_ops.multiply(prev_output, rec_dp_mask); } - ranks = prev_output.rank; - Tensor output; - if (ranks > 2) + Tensor output = h + math_ops.matmul(prev_output, _recurrent_kernel.AsTensor()); + + if (_args.Activation != null) { - output = h + tf.linalg.tensordot(prev_output[0], _recurrent_kernel.AsTensor(), new[,] { { ranks - 1 }, { 0 } }); + output = _args.Activation.Apply(output); } - else + if (Nest.IsNested(states)) { - output = h + math_ops.matmul(prev_output, _recurrent_kernel.AsTensor()); + return new Nest(new List> { + new Nest(new List> { new Nest(output) }), new Nest(output) }) + .ToTensors(); } - Console.WriteLine($"shape of output: {output.shape}"); - - if (_args.Activation != null) + else { - output = _args.Activation.Apply(output); + return new Tensors(output, output); } - return (output, new Tensors { output }); } } } diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs index 7923192f..0b92fd3c 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs @@ -170,6 +170,7 @@ namespace Tensorflow.Keras.Layers.Rnn } public GeneralizedTensorShape StateSize => throw new NotImplementedException(); public GeneralizedTensorShape OutputSize => throw new NotImplementedException(); + public bool IsTFRnnCell => throw new NotImplementedException(); public bool SupportOptionalArgs => throw new NotImplementedException(); } } diff --git a/src/TensorflowNET.Hub/KerasLayer.cs b/src/TensorflowNET.Hub/KerasLayer.cs index b9ca949b..20d9851b 100644 --- a/src/TensorflowNET.Hub/KerasLayer.cs +++ b/src/TensorflowNET.Hub/KerasLayer.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using Tensorflow.Common.Types; using Tensorflow.Keras.Engine; using Tensorflow.Train; using Tensorflow.Training; @@ -89,7 +90,7 @@ namespace Tensorflow.Hub } } - protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optionalArgs = null) { _check_trainability();