|
|
@@ -20,9 +20,11 @@ using System.Linq; |
|
|
|
using System.Collections.Generic; |
|
|
|
using Tensorflow.Functions; |
|
|
|
using Tensorflow.Graphs; |
|
|
|
using Tensorflow.Common.Extensions; |
|
|
|
using static Tensorflow.Binding; |
|
|
|
using static Tensorflow.Graphs.SubGraphUtility; |
|
|
|
using Tensorflow.Util; |
|
|
|
using Tensorflow.Common.Types; |
|
|
|
|
|
|
|
namespace Tensorflow.Keras |
|
|
|
{ |
|
|
@@ -452,7 +454,7 @@ namespace Tensorflow.Keras |
|
|
|
return x; |
|
|
|
} |
|
|
|
|
|
|
|
public static (Tensors, Tensors, Tensors) rnn( |
|
|
|
public (Tensors, Tensors, Tensors) rnn( |
|
|
|
Func<Tensors, Tensors, (Tensors, Tensors)> step_function, // args:inputs, states, return:output, new_states |
|
|
|
Tensors inputs, // inputs is a tuple of tensors (one per input sequence) |
|
|
|
Tensors initial_states, |
|
|
@@ -466,7 +468,7 @@ namespace Tensorflow.Keras |
|
|
|
bool return_all_outputs = true) |
|
|
|
{ |
|
|
|
|
|
|
|
Tensors swap_batch_timestep(Tensors input_t) |
|
|
|
Tensor swap_batch_timestep(Tensor input_t) |
|
|
|
{ |
|
|
|
var axes = Enumerable.Range(0, input_t.rank).ToArray(); |
|
|
|
axes[0] = 1; |
|
|
@@ -476,13 +478,14 @@ namespace Tensorflow.Keras |
|
|
|
|
|
|
|
if (!time_major) |
|
|
|
{ |
|
|
|
inputs = nest.map_structure(swap_batch_timestep, inputs); |
|
|
|
inputs = Nest.MapStructure(swap_batch_timestep, inputs).ToTensors(); |
|
|
|
} |
|
|
|
|
|
|
|
var flatted_inptus = nest.flatten(inputs); |
|
|
|
var time_steps = flatted_inptus[0].shape[0]; |
|
|
|
var batch = flatted_inptus[0].shape[1]; |
|
|
|
var time_step_t = tf.shape(flatted_inptus[0])[0]; |
|
|
|
var flatted_inptus = Nest.Flatten(inputs).ToList(); |
|
|
|
var first_flatted_input = flatted_inptus[0]; |
|
|
|
var time_steps = first_flatted_input.shape[0]; |
|
|
|
var batch = first_flatted_input.shape[1]; |
|
|
|
var time_steps_t = (int)first_flatted_input.shape[0]; |
|
|
|
|
|
|
|
foreach (var input_ in flatted_inptus) |
|
|
|
{ |
|
|
@@ -508,11 +511,6 @@ namespace Tensorflow.Keras |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
if (constants == null) |
|
|
|
{ |
|
|
|
constants = new List<Tensor>(); |
|
|
|
} |
|
|
|
|
|
|
|
// tf.where needs its condition tensor to be the same shape as its two |
|
|
|
// result tensors, but in our case the condition (mask) tensor is |
|
|
|
// (nsamples, 1), and inputs are (nsamples, ndimensions) or even more. |
|
|
@@ -522,12 +520,12 @@ namespace Tensorflow.Keras |
|
|
|
|
|
|
|
Tensors _expand_mask(Tensors mask_t, Tensors input_t, int fixed_dim = 1) |
|
|
|
{ |
|
|
|
if (nest.is_nested(mask_t)) |
|
|
|
if (!mask_t.IsSingle()) |
|
|
|
{ |
|
|
|
throw new ValueError($"mask_t is expected to be tensor, but got {mask_t}"); |
|
|
|
} |
|
|
|
|
|
|
|
if (nest.is_nested(input_t)) |
|
|
|
if (!input_t.IsSingle()) |
|
|
|
{ |
|
|
|
throw new ValueError($"input_t is expected to be tensor, but got {input_t}"); |
|
|
|
} |
|
|
@@ -575,21 +573,21 @@ namespace Tensorflow.Keras |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Tensors _process_single_input_t(Tensors input_t) |
|
|
|
Tensors _process_single_input_t(Tensor input_t) |
|
|
|
{ |
|
|
|
input_t = tf.unstack(input_t); // unstack for time_step dim |
|
|
|
var unstaked_input_t = array_ops.unstack(input_t); // unstack for time_step dim |
|
|
|
if (go_backwards) |
|
|
|
{ |
|
|
|
input_t.Reverse(); |
|
|
|
unstaked_input_t = unstaked_input_t.Reverse().ToArray(); |
|
|
|
} |
|
|
|
return input_t; |
|
|
|
return unstaked_input_t; |
|
|
|
} |
|
|
|
|
|
|
|
// TODO(Wanglongzhi2001) |
|
|
|
Tensors processed_input; |
|
|
|
if (nest.is_nested(inputs)) |
|
|
|
if (!inputs.IsSingle()) |
|
|
|
{ |
|
|
|
processed_input = nest.map_structure(_process_single_input_t, inputs); |
|
|
|
processed_input = inputs.MapStructure(_process_single_input_t).ReduceTo<Tensors, Tensor>().ToTensors(); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
@@ -603,334 +601,339 @@ namespace Tensorflow.Keras |
|
|
|
{ |
|
|
|
inp.Add(t_[time]); |
|
|
|
} |
|
|
|
return nest.pack_sequence_as(inputs, inp); |
|
|
|
return Nest.PackSequenceAs(inputs, inp); |
|
|
|
} |
|
|
|
|
|
|
|
//if (mask != null) |
|
|
|
//{ |
|
|
|
// var mask_list = tf.unstack(mask); |
|
|
|
// if (go_backwards) |
|
|
|
// { |
|
|
|
// mask_list.Reverse(); |
|
|
|
// } |
|
|
|
|
|
|
|
// for (int i = 0; i < time_steps; i++) |
|
|
|
// { |
|
|
|
// // TODO(Wanglongzhi2001),deal with _get_input_tensor |
|
|
|
// var inp = _get_input_tensor(i); |
|
|
|
// var mask_t = mask_list[i]; |
|
|
|
// // TODO |
|
|
|
// var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants }); |
|
|
|
|
|
|
|
// var tiled_mask_t = _expand_mask(mask_t, output); |
|
|
|
|
|
|
|
// Tensors prev_output; |
|
|
|
// if (successive_outputs == null) |
|
|
|
// { |
|
|
|
// prev_output = tf.zeros_like(output); |
|
|
|
// } |
|
|
|
// else |
|
|
|
// { |
|
|
|
// prev_output = successive_outputs[successive_outputs.Length - 1]; |
|
|
|
// } |
|
|
|
|
|
|
|
// output = tf.where(tiled_mask_t, output, prev_output); |
|
|
|
|
|
|
|
// //var flat_states = nest.flatten(states); |
|
|
|
// //var flat_new_states = nest.flatten(newStates); |
|
|
|
// var flat_states = states.ToList(); |
|
|
|
// var flat_new_states = newStates.ToList(); |
|
|
|
|
|
|
|
// var tiledMaskT = flat_states |
|
|
|
// .Select(s => _expand_mask(mask_t, s)) |
|
|
|
// .ToArray(); |
|
|
|
// var tuple = Tuple.Create(tiledMaskT); |
|
|
|
|
|
|
|
// List<Tensor> flat_final_states = new List<Tensor>(); |
|
|
|
// foreach (var (m, s, ps) in Enumerable.Zip(tiled_mask_t, flat_new_states, flat_states)) |
|
|
|
// { |
|
|
|
// flat_final_states.Add(tf.where(m, s, ps)); |
|
|
|
// } |
|
|
|
|
|
|
|
// states = (Tensors)nest.pack_sequence_as(states, flat_final_states); |
|
|
|
// if (return_all_outputs) |
|
|
|
// { |
|
|
|
// successive_outputs.Add(output); |
|
|
|
// successive_states.Add(states); |
|
|
|
// } |
|
|
|
// else |
|
|
|
// { |
|
|
|
// successive_outputs = new Tensors { output }; |
|
|
|
// successive_states = new Tensors { states }; |
|
|
|
// } |
|
|
|
|
|
|
|
// } |
|
|
|
// last_output = successive_outputs[successive_outputs.Length - 1]; |
|
|
|
// new_states = successive_states[successive_states.Length - 1]; |
|
|
|
// outputs = tf.stack(successive_outputs); |
|
|
|
|
|
|
|
// if (zero_output_for_mask) |
|
|
|
// { |
|
|
|
// last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output)); |
|
|
|
// outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs)); |
|
|
|
// } |
|
|
|
// else // mask is null |
|
|
|
// { |
|
|
|
// for (int i = 0; i < time_steps; i++) |
|
|
|
// { |
|
|
|
// var inp = _get_input_tensor(i); |
|
|
|
// var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants }); |
|
|
|
// states = newStates; |
|
|
|
|
|
|
|
// if (return_all_outputs) |
|
|
|
// { |
|
|
|
// successive_outputs.Add(output); |
|
|
|
// successive_states.Add(newStates); |
|
|
|
// } |
|
|
|
// else |
|
|
|
// { |
|
|
|
// successive_outputs = new Tensors { output }; |
|
|
|
// successive_states = new Tensors { newStates }; |
|
|
|
// } |
|
|
|
// } |
|
|
|
// last_output = successive_outputs[successive_outputs.Length - 1]; |
|
|
|
// new_states = successive_states[successive_states.Length - 1]; |
|
|
|
// outputs = tf.stack(successive_outputs); |
|
|
|
// } |
|
|
|
//} |
|
|
|
if (mask != null) |
|
|
|
{ |
|
|
|
var mask_list = tf.unstack(mask); |
|
|
|
if (go_backwards) |
|
|
|
{ |
|
|
|
mask_list.Reverse(); |
|
|
|
} |
|
|
|
|
|
|
|
for (int i = 0; i < time_steps; i++) |
|
|
|
{ |
|
|
|
// TODO(Wanglongzhi2001),deal with _get_input_tensor |
|
|
|
var inp = _get_input_tensor(i); |
|
|
|
var mask_t = mask_list[i]; |
|
|
|
// TODO |
|
|
|
var (output, newStates) = step_function((Tensors)inp, states.MergeWith(constants)); |
|
|
|
|
|
|
|
var tiled_mask_t = _expand_mask(mask_t, output); |
|
|
|
|
|
|
|
Tensors prev_output; |
|
|
|
if (successive_outputs == null) |
|
|
|
{ |
|
|
|
prev_output = tf.zeros_like(output); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
prev_output = successive_outputs[successive_outputs.Length - 1]; |
|
|
|
} |
|
|
|
|
|
|
|
output = tf.where(tiled_mask_t, output, prev_output); |
|
|
|
|
|
|
|
var flat_states = Nest.Flatten(states).ToList(); |
|
|
|
var flat_new_states = Nest.Flatten(newStates).ToList(); |
|
|
|
|
|
|
|
var tiledMaskT = flat_states |
|
|
|
.Select(s => _expand_mask(mask_t, s)) |
|
|
|
.ToArray(); |
|
|
|
var tuple = Tuple.Create(tiledMaskT); |
|
|
|
|
|
|
|
List<Tensor> flat_final_states = new List<Tensor>(); |
|
|
|
foreach (var (m, s, ps) in zip(tiled_mask_t.ToList(), flat_new_states, flat_states)) |
|
|
|
{ |
|
|
|
flat_final_states.Add(tf.where(m, s, ps)); |
|
|
|
} |
|
|
|
|
|
|
|
states = Nest.PackSequenceAs(states, flat_final_states).ToTensors(); |
|
|
|
if (return_all_outputs) |
|
|
|
{ |
|
|
|
successive_outputs.Add(output); |
|
|
|
successive_states.Add(states); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
successive_outputs = new Tensors { output }; |
|
|
|
successive_states = new Tensors { states }; |
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
last_output = successive_outputs[successive_outputs.Length - 1]; |
|
|
|
new_states = successive_states[successive_states.Length - 1]; |
|
|
|
outputs = tf.stack(successive_outputs); |
|
|
|
|
|
|
|
if (zero_output_for_mask) |
|
|
|
{ |
|
|
|
last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output)); |
|
|
|
outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs)); |
|
|
|
} |
|
|
|
else // mask is null |
|
|
|
{ |
|
|
|
for (int i = 0; i < time_steps; i++) |
|
|
|
{ |
|
|
|
var inp = _get_input_tensor(i); |
|
|
|
var (output, newStates) = step_function((Tensors)inp, states.MergeWith(constants)); |
|
|
|
states = newStates; |
|
|
|
|
|
|
|
if (return_all_outputs) |
|
|
|
{ |
|
|
|
successive_outputs.Add(output); |
|
|
|
successive_states.Add(newStates); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
successive_outputs = new Tensors { output }; |
|
|
|
successive_states = new Tensors { newStates }; |
|
|
|
} |
|
|
|
} |
|
|
|
last_output = successive_outputs[successive_outputs.Length - 1]; |
|
|
|
new_states = successive_states[successive_states.Length - 1]; |
|
|
|
outputs = tf.stack(successive_outputs); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
else // unroll == false |
|
|
|
{ |
|
|
|
var states = initial_states; |
|
|
|
// Create input tensor array, if the inputs is nested tensors, then it |
|
|
|
// will be flattened first, and tensor array will be created one per |
|
|
|
// flattened tensor. |
|
|
|
var input_ta = new List<TensorArray>(); |
|
|
|
for (int i = 0; i < flatted_inptus.Count; i++) |
|
|
|
{ |
|
|
|
input_ta.Add(tf.TensorArray(dtype: flatted_inptus[i].dtype, size: time_steps_t)); |
|
|
|
} |
|
|
|
|
|
|
|
foreach(var (ta, input_) in zip(input_ta, flatted_inptus)) |
|
|
|
{ |
|
|
|
if (!go_backwards) |
|
|
|
{ |
|
|
|
ta.unstack(input_); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
ta.unstack(reverse(input_, 0)); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Get the time(0) input and compute the output for that, the output will |
|
|
|
// be used to determine the dtype of output tensor array. Don't read from |
|
|
|
// input_ta due to TensorArray clear_after_read default to True. |
|
|
|
var inps = new Tensors(); |
|
|
|
foreach (var inp in flatted_inptus) |
|
|
|
{ |
|
|
|
inps.Add(inp[0]); |
|
|
|
} |
|
|
|
var input_time_zero = Nest.PackSequenceAs(inputs, inps).ToTensors(); |
|
|
|
|
|
|
|
// output_time_zero is used to determine the cell output shape and its |
|
|
|
// dtype. the value is discarded. |
|
|
|
(output_time_zero, _) = step_function((Tensor)input_time_zero, |
|
|
|
constants is null ? initial_states : initial_states.MergeWith(constants)); |
|
|
|
|
|
|
|
int output_ta_size = return_all_outputs ? time_steps_t : 1; |
|
|
|
var output_ta = new List<TensorArray>(); |
|
|
|
for (int i = 0; i < output_time_zero.ToList().Count; i++) |
|
|
|
{ |
|
|
|
var Out = output_time_zero.ToList()[i]; |
|
|
|
output_ta.Add(tf.TensorArray(dtype: Out.dtype, size: output_ta_size, element_shape: Out.shape)); |
|
|
|
} |
|
|
|
|
|
|
|
var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time"); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Func<Tensor, Tensor>? masking_fn; |
|
|
|
Func<Tensors, Tensors, Tensors, Tensors>? compute_masked_output = null; |
|
|
|
if (mask != null) |
|
|
|
{ |
|
|
|
if (go_backwards) |
|
|
|
{ |
|
|
|
mask = tf.reverse(mask, axis: new[] { 0 }); |
|
|
|
} |
|
|
|
var mask_ta = tf.TensorArray(dtype: TF_DataType.TF_BOOL, size: time_steps_t); |
|
|
|
mask_ta = mask_ta.unstack(mask); |
|
|
|
|
|
|
|
masking_fn = (time) => |
|
|
|
{ |
|
|
|
return mask_ta.read(time); |
|
|
|
}; |
|
|
|
|
|
|
|
compute_masked_output = (mask_t, flat_out, flat_mask) => |
|
|
|
{ |
|
|
|
var tiled_mask_t = new Tensors(); |
|
|
|
foreach (var o in flat_out) |
|
|
|
{ |
|
|
|
tiled_mask_t.Add(_expand_mask(mask_t, o, fixed_dim: mask_t.rank)); |
|
|
|
} |
|
|
|
|
|
|
|
Tensors res = new Tensors(); |
|
|
|
foreach (var (m, o, fm) in zip(tiled_mask_t.ToList(), flat_out.ToList(), flat_mask.ToList())) |
|
|
|
{ |
|
|
|
res.Add(tf.where(m, o, fm)); |
|
|
|
} |
|
|
|
return res; |
|
|
|
}; |
|
|
|
} |
|
|
|
// TODO(Wanglongzhi2001): decide what type input_length should be (an integer or a single tensor). |
|
|
|
else if (input_length is Tensor) |
|
|
|
{ |
|
|
|
if (go_backwards) |
|
|
|
{ |
|
|
|
var max_len = tf.reduce_max(input_length, axis: 0); |
|
|
|
var rev_input_length = tf.subtract(max_len - 1, input_length); |
|
|
|
|
|
|
|
masking_fn = (time) => |
|
|
|
{ |
|
|
|
return tf.less(rev_input_length, time); |
|
|
|
}; |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
masking_fn = (time) => |
|
|
|
{ |
|
|
|
return tf.greater(input_length, time); |
|
|
|
}; |
|
|
|
} |
|
|
|
|
|
|
|
compute_masked_output = (mask_t, flat_out, flat_mask) => |
|
|
|
{ |
|
|
|
var res = new List<Tensor>(); |
|
|
|
foreach (var (o, zo) in zip(flat_out, flat_mask)) |
|
|
|
{ |
|
|
|
res.Add(tf.where(mask_t, o, zo)); |
|
|
|
} |
|
|
|
return res; |
|
|
|
}; |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
masking_fn = null; |
|
|
|
} |
|
|
|
|
|
|
|
Func<Tensor, Tensor> cond = (time) => (time < time_steps_t); |
|
|
|
int parallel_iterations = 32; |
|
|
|
if (masking_fn != null) |
|
|
|
{ |
|
|
|
// Mask for the T output will be based on the output of T - 1. In the |
|
|
|
// case T = 0, a zero filled tensor will be used. |
|
|
|
var flat_zero_output = new Tensors(); |
|
|
|
foreach (var o in Nest.Flatten(output_time_zero)) |
|
|
|
{ |
|
|
|
flat_zero_output.Add(tf.zeros_like(o)); |
|
|
|
} |
|
|
|
|
|
|
|
var prev_output = flat_zero_output; |
|
|
|
var output_ta_t = output_ta; |
|
|
|
Tensor _step(Tensor time) |
|
|
|
{ |
|
|
|
/* |
|
|
|
RNN step function. |
|
|
|
Args: |
|
|
|
time: Current timestep value. |
|
|
|
output_ta_t: TensorArray. |
|
|
|
prev_output: tuple of outputs from time - 1. |
|
|
|
*states: List of states. |
|
|
|
Returns: |
|
|
|
Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)` |
|
|
|
*/ |
|
|
|
|
|
|
|
var flat_current_input = input_ta.Select(x => x.read(time)).ToList(); |
|
|
|
// maybe set shape |
|
|
|
// TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type |
|
|
|
var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors(); |
|
|
|
var mask_t = masking_fn(time); |
|
|
|
var (output, new_states_internal) = step_function(current_input, states.MergeWith(constants)); |
|
|
|
// mask output |
|
|
|
var flat_output = Nest.Flatten(output).ToList(); |
|
|
|
|
|
|
|
var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.ToList(); |
|
|
|
|
|
|
|
// TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type |
|
|
|
var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output); |
|
|
|
|
|
|
|
// mask states |
|
|
|
var flat_state = states.ToList(); |
|
|
|
var flat_new_state = new_states_internal.ToList(); |
|
|
|
|
|
|
|
foreach (var (state, new_state) in zip(flat_state, flat_new_state)) |
|
|
|
{ |
|
|
|
if (new_state is Tensor) |
|
|
|
{ |
|
|
|
new_state.shape = state.shape; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state); |
|
|
|
new_states_internal = Nest.PackSequenceAs(new_states, flat_final_state).ToTensors(); |
|
|
|
|
|
|
|
var ta_index_to_write = return_all_outputs ? time : tf.constant(0); |
|
|
|
// TODO(Wanglongzhi2001),deal with zip output_ta_t |
|
|
|
foreach (var (ta, Out) in zip(output_ta_t, flat_new_output)) |
|
|
|
{ |
|
|
|
output_ta_t.Add(ta.write(ta_index_to_write, Out)); |
|
|
|
} |
|
|
|
|
|
|
|
new_states_internal = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors(); |
|
|
|
|
|
|
|
output_ta = output_ta_t; |
|
|
|
new_states = new_states_internal; |
|
|
|
return time + 1; |
|
|
|
|
|
|
|
} |
|
|
|
var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
var output_ta_t = output_ta; |
|
|
|
new_states = states; |
|
|
|
Tensor _step(Tensor time) |
|
|
|
{ |
|
|
|
var flat_current_input = input_ta.Select(x => x.read(time)).ToList(); |
|
|
|
// maybe set shape |
|
|
|
// TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type |
|
|
|
var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors(); |
|
|
|
var (output, new_states_internal) = step_function(current_input, new_states.MergeWith(constants)); |
|
|
|
var flat_state = new_states.Flatten().ToList(); |
|
|
|
var flat_new_state = new_states_internal.Flatten().ToList(); |
|
|
|
foreach (var (state, new_state) in zip(flat_state, flat_new_state)) |
|
|
|
{ |
|
|
|
if (new_state is Tensor) |
|
|
|
{ |
|
|
|
new_state.shape = state.shape; |
|
|
|
} |
|
|
|
} |
|
|
|
var flat_output = Nest.Flatten(output); |
|
|
|
var ta_index_to_write = return_all_outputs ? time : tf.constant(0); |
|
|
|
output_ta_t = zip(output_ta_t, flat_output).Select(item => |
|
|
|
{ |
|
|
|
var (ta, out_) = item; |
|
|
|
return ta.write(ta_index_to_write, out_); |
|
|
|
}).ToList(); |
|
|
|
|
|
|
|
new_states_internal = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors(); |
|
|
|
output_ta = output_ta_t; |
|
|
|
new_states = new_states_internal; |
|
|
|
return time + 1; |
|
|
|
} |
|
|
|
var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations); |
|
|
|
} |
|
|
|
//Tensors outputs = new Tensors(); |
|
|
|
foreach (var o in output_ta) |
|
|
|
{ |
|
|
|
outputs.Add(o.stack()); |
|
|
|
} |
|
|
|
foreach (var o in outputs) |
|
|
|
{ |
|
|
|
last_output.Add(o[-1]); |
|
|
|
} |
|
|
|
outputs = Nest.PackSequenceAs(output_time_zero, outputs).ToTensors(); |
|
|
|
last_output = Nest.PackSequenceAs(output_time_zero, last_output).ToTensors(); |
|
|
|
|
|
|
|
} |
|
|
|
//else // unroll == false |
|
|
|
//{ |
|
|
|
// var states = initial_states; |
|
|
|
// // Create input tensor array, if the inputs is nested tensors, then it |
|
|
|
// // will be flattened first, and tensor array will be created one per |
|
|
|
// // flattened tensor. |
|
|
|
// var input_ta = new List<TensorArray>(); |
|
|
|
// for (int i = 0; i < flatted_inptus.Count; i++) |
|
|
|
// { |
|
|
|
// input_ta.Add(tf.TensorArray(dtype: flatted_inptus[i].dtype, size: time_step_t)); |
|
|
|
// } |
|
|
|
|
|
|
|
// // Get the time(0) input and compute the output for that, the output will |
|
|
|
// // be used to determine the dtype of output tensor array. Don't read from |
|
|
|
// // input_ta due to TensorArray clear_after_read default to True. |
|
|
|
// var inps = new Tensors(); |
|
|
|
// foreach (var inp in flatted_inptus) |
|
|
|
// { |
|
|
|
// inps.Add(inp[0]); |
|
|
|
// } |
|
|
|
// var input_time_zero = nest.pack_sequence_as(inputs, inps); |
|
|
|
|
|
|
|
// // output_time_zero is used to determine the cell output shape and its |
|
|
|
// // dtype. the value is discarded. |
|
|
|
// (output_time_zero, _) = step_function((Tensor)input_time_zero, new Tensors { initial_states, constants }); |
|
|
|
|
|
|
|
// var output_ta_size = return_all_outputs ? time_step_t : tf.constant(1); |
|
|
|
// var output_ta = new List<TensorArray>(); |
|
|
|
// for (int i = 0; i < output_time_zero.ToList().Count; i++) |
|
|
|
// { |
|
|
|
// var Out = output_time_zero.ToList()[i]; |
|
|
|
// output_ta.Add(tf.TensorArray(dtype: Out.dtype, size: output_ta_size, element_shape: Out.shape)); |
|
|
|
// } |
|
|
|
|
|
|
|
// var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time"); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Func<Tensor, Tensor>? masking_fn; |
|
|
|
// Func<Tensors, Tensors, Tensors, Tensors>? compute_masked_output = null; |
|
|
|
// if (mask != null) |
|
|
|
// { |
|
|
|
// if (go_backwards) |
|
|
|
// { |
|
|
|
// mask = tf.reverse(mask, axis: new[] { 0 }); |
|
|
|
// } |
|
|
|
// var mask_ta = tf.TensorArray(dtype: TF_DataType.TF_BOOL, size: time_step_t); |
|
|
|
// mask_ta = mask_ta.unstack(mask); |
|
|
|
|
|
|
|
// masking_fn = (time) => |
|
|
|
// { |
|
|
|
// return mask_ta.read(time); |
|
|
|
// }; |
|
|
|
|
|
|
|
// compute_masked_output = (mask_t, flat_out, flat_mask) => |
|
|
|
// { |
|
|
|
// var tiled_mask_t = new Tensors(); |
|
|
|
// foreach (var o in flat_out) |
|
|
|
// { |
|
|
|
// tiled_mask_t.Add(_expand_mask(mask_t, o, fixed_dim: mask_t.rank)); |
|
|
|
// } |
|
|
|
|
|
|
|
// Tensors res = new Tensors(); |
|
|
|
// foreach (var (m, o, fm) in Enumerable.Zip(tiled_mask_t, flat_out, flat_mask)) |
|
|
|
// { |
|
|
|
// res.Add(tf.where(m, o, fm)); |
|
|
|
// } |
|
|
|
// return res; |
|
|
|
// }; |
|
|
|
// } |
|
|
|
// // TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor)? |
|
|
|
// else if (input_length is Tensor) |
|
|
|
// { |
|
|
|
// if (go_backwards) |
|
|
|
// { |
|
|
|
// var max_len = tf.reduce_max(input_length, axis: 0); |
|
|
|
// var rev_input_length = tf.subtract(max_len - 1, input_length); |
|
|
|
|
|
|
|
// masking_fn = (time) => |
|
|
|
// { |
|
|
|
// return tf.less(rev_input_length, time); |
|
|
|
// }; |
|
|
|
// } |
|
|
|
// else |
|
|
|
// { |
|
|
|
// masking_fn = (time) => |
|
|
|
// { |
|
|
|
// return tf.greater(input_length, time); |
|
|
|
// }; |
|
|
|
// } |
|
|
|
|
|
|
|
// compute_masked_output = (mask_t, flat_out, flat_mask) => |
|
|
|
// { |
|
|
|
// var res = new List<Tensor>(); |
|
|
|
// foreach (var (o, zo) in zip(flat_out, flat_mask)) |
|
|
|
// { |
|
|
|
// res.Add(tf.where(mask_t, o, zo)); |
|
|
|
// } |
|
|
|
// return res; |
|
|
|
// }; |
|
|
|
// } |
|
|
|
// else |
|
|
|
// { |
|
|
|
// masking_fn = null; |
|
|
|
// } |
|
|
|
|
|
|
|
|
|
|
|
// if (masking_fn != null) |
|
|
|
// { |
|
|
|
// // Mask for the T output will be base on the output of T - 1. In the |
|
|
|
// // case T = 0, a zero filled tensor will be used. |
|
|
|
// var flat_zero_output = new Tensors(); |
|
|
|
// foreach (var o in nest.flatten(output_time_zero)) |
|
|
|
// { |
|
|
|
// flat_zero_output.Add(tf.zeros_like(o)); |
|
|
|
// } |
|
|
|
|
|
|
|
|
|
|
|
// (Tensor, List<TensorArray>, Tensors, Tensors) _step(Tensor time, List<TensorArray> output_ta_t, Tensors prev_output, Tensors states) |
|
|
|
// { |
|
|
|
// /* |
|
|
|
// RNN step function. |
|
|
|
// Args: |
|
|
|
// time: Current timestep value. |
|
|
|
// output_ta_t: TensorArray. |
|
|
|
// prev_output: tuple of outputs from time - 1. |
|
|
|
// *states: List of states. |
|
|
|
// Returns: |
|
|
|
// Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)` |
|
|
|
// */ |
|
|
|
|
|
|
|
// var current_input = input_ta.Select(x => x.read(time)).ToList(); |
|
|
|
// // maybe set shape |
|
|
|
// // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type |
|
|
|
// current_input = (List<Tensor>)nest.pack_sequence_as(inputs, current_input); |
|
|
|
// var mask_t = masking_fn(time); |
|
|
|
// var (output, new_states) = step_function(current_input, new Tensors { states, constants }); |
|
|
|
// // mask output |
|
|
|
// //var flat_output = nest.flatten(output); |
|
|
|
// var flat_output = output.ToList(); |
|
|
|
|
|
|
|
// var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.ToList(); |
|
|
|
|
|
|
|
// // TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type |
|
|
|
// var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output); |
|
|
|
|
|
|
|
// // mask states |
|
|
|
// var flat_state = states.ToList(); |
|
|
|
// var flat_new_state = new_states.ToList(); |
|
|
|
|
|
|
|
// foreach (var (state, new_state) in zip(flat_state, flat_new_state)) |
|
|
|
// { |
|
|
|
// if (new_state is Tensor) |
|
|
|
// { |
|
|
|
// new_state.set_shape(state.shape); |
|
|
|
// } |
|
|
|
// } |
|
|
|
|
|
|
|
// var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state); |
|
|
|
// new_states = (Tensors)nest.pack_sequence_as(new_states, flat_final_state); |
|
|
|
|
|
|
|
// var ta_index_to_write = return_all_outputs ? time : tf.constant(0); |
|
|
|
// var Output_ta_t = new List<TensorArray>(); |
|
|
|
// // TODO(Wanglongzhi2001),deal with zip output_ta_t |
|
|
|
// foreach (var (ta, Out) in zip(output_ta_t, flat_new_output)) |
|
|
|
// { |
|
|
|
// Output_ta_t.Add(ta.write(ta_index_to_write, Out)); |
|
|
|
// } |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// //new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state); |
|
|
|
|
|
|
|
|
|
|
|
// return (time + 1, Output_ta_t, flat_new_output, new_states); |
|
|
|
|
|
|
|
// } |
|
|
|
// Func<Tensor, Tensor> cond = (time) => (time < time_step_t); |
|
|
|
|
|
|
|
// var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, flat_zero_output, states)); |
|
|
|
// new_states = final_outputs.Item4; |
|
|
|
// output_ta = final_outputs.Item2; |
|
|
|
|
|
|
|
// } |
|
|
|
// else |
|
|
|
// { |
|
|
|
// (Tensor, List<TensorArray>, Tensors) _step(Tensor time, List<TensorArray> output_ta_t, Tensors states) |
|
|
|
// { |
|
|
|
// var current_input = input_ta.Select(x => x.read(time)).ToList(); |
|
|
|
// // maybe set shape |
|
|
|
// // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type |
|
|
|
// current_input = (List<Tensor>)nest.pack_sequence_as(inputs, current_input); |
|
|
|
// var (output, new_states) = step_function(current_input, new Tensors { states, constants }); |
|
|
|
// var flat_state = states.ToList(); |
|
|
|
// var flat_new_state = new_states.ToList(); |
|
|
|
// foreach (var (state, new_state) in zip(flat_state, flat_new_state)) |
|
|
|
// { |
|
|
|
// if (new_state is Tensor) |
|
|
|
// { |
|
|
|
// new_state.set_shape(state.shape); |
|
|
|
// } |
|
|
|
// } |
|
|
|
// var flat_output = output.ToList(); |
|
|
|
// var ta_index_to_write = return_all_outputs ? time : tf.constant(0); |
|
|
|
// var Output_ta_t = new List<TensorArray>(); |
|
|
|
// foreach (var (ta, out_) in zip(output_ta_t, flat_output)) |
|
|
|
// { |
|
|
|
// Output_ta_t.Add(ta.write(ta_index_to_write, out_)); |
|
|
|
// } |
|
|
|
|
|
|
|
// new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state); |
|
|
|
// return (time + 1, Output_ta_t, new_states); |
|
|
|
// } |
|
|
|
// Func<Tensor, Tensor> cond = (time) => (time < time_step_t); |
|
|
|
// var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, states)); |
|
|
|
// new_states = final_outputs.Item3; |
|
|
|
// output_ta = final_outputs.Item2; |
|
|
|
|
|
|
|
// } |
|
|
|
// //Tensors outputs = new Tensors(); |
|
|
|
// foreach (var o in output_ta) |
|
|
|
// { |
|
|
|
// outputs.Add(o.stack()); |
|
|
|
// } |
|
|
|
// foreach (var o in outputs) |
|
|
|
// { |
|
|
|
// last_output.Add(o[-1]); |
|
|
|
// } |
|
|
|
// outputs = (Tensors)nest.pack_sequence_as(output_time_zero, outputs); |
|
|
|
// last_output = (Tensors)nest.pack_sequence_as(output_time_zero, last_output); |
|
|
|
|
|
|
|
//} |
|
|
|
|
|
|
|
Func<Tensor, Tensor> set_shape; |
|
|
|
set_shape = (output_) => |
|
|
@@ -947,18 +950,38 @@ namespace Tensorflow.Keras |
|
|
|
shape[0] = 1; |
|
|
|
} |
|
|
|
shape[1] = (int)batch; |
|
|
|
output_.set_shape(new Tensor(shape)); |
|
|
|
output_.shape = shape; |
|
|
|
} |
|
|
|
return output_; |
|
|
|
}; |
|
|
|
|
|
|
|
var Outputs = (Tensors)nest.map_structure(set_shape, outputs); |
|
|
|
outputs = Nest.MapStructure(set_shape, outputs).ToTensors(); |
|
|
|
if (!time_major) |
|
|
|
{ |
|
|
|
Outputs = nest.map_structure(swap_batch_timestep, outputs); |
|
|
|
outputs = Nest.MapStructure(swap_batch_timestep, outputs).ToTensors(); |
|
|
|
} |
|
|
|
return (last_output, outputs, new_states); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
/// <summary>
/// Single-axis convenience overload: wraps <paramref name="axis"/> in a
/// one-element array and delegates to the <c>int[]</c> overload of
/// <c>reverse</c>.
/// </summary>
/// <param name="input">Tensor to be reversed.</param>
/// <param name="axis">The one axis index forwarded to the array overload.</param>
/// <returns>Whatever the <c>int[]</c> overload returns for that single axis.</returns>
public Tensor reverse(Tensor input, int axis)
    => reverse(input, new[] { axis });
|
|
|
|
|
|
|
/// <summary>
/// Thin wrapper that forwards both arguments, unchanged, to <c>tf.reverse</c>.
/// </summary>
/// <param name="input">Tensor handed to <c>tf.reverse</c>.</param>
/// <param name="axes">Axis indices handed to <c>tf.reverse</c>.</param>
/// <returns>The tensor produced by <c>tf.reverse</c>.</returns>
public Tensor reverse(Tensor input, int[] axes)
    => tf.reverse(input, axes);
|
|
|
|
|
|
|
/// <summary>
/// Returns <paramref name="output"/> untouched when no ragged conversion is
/// requested. Conversion to a ragged tensor is not implemented, so requesting
/// it always throws.
/// </summary>
/// <param name="is_ragged_output">
/// When <c>false</c>, <paramref name="output"/> is returned as-is; when
/// <c>true</c>, the method throws.
/// </param>
/// <param name="output">Tensor returned unchanged on the non-ragged path.</param>
/// <param name="nested_row_lengths">Not read by the current implementation.</param>
/// <param name="go_backwards">Not read by the current implementation.</param>
/// <returns><paramref name="output"/>, unmodified.</returns>
/// <exception cref="NotImplementedException">
/// Thrown whenever <paramref name="is_ragged_output"/> is <c>true</c>.
/// </exception>
public Tensor maybe_convert_to_ragged(bool is_ragged_output, Tensor output, int nested_row_lengths, bool go_backwards = false)
{
    if (!is_ragged_output)
    {
        return output;
    }

    // A stray `return (last_output, Outputs, new_states);` used to sit here. It belonged to
    // rnn(), referenced locals that do not exist in this method, returned the wrong type,
    // and made the throw below unreachable.
    throw new NotImplementedException("Not implemented currently, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
}
|
|
|
} |
|
|
|
} |