|
|
@@ -651,13 +651,13 @@ namespace Tensorflow.Keras |
|
|
|
states = Nest.PackSequenceAs(states, flat_final_states).ToTensors(); |
|
|
|
if (return_all_outputs) |
|
|
|
{ |
|
|
|
successive_outputs.Add(output); |
|
|
|
successive_states.Add(states); |
|
|
|
successive_outputs = successive_outputs.MergeWith(output); |
|
|
|
successive_outputs = successive_states.MergeWith(states); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
successive_outputs = new Tensors { output }; |
|
|
|
successive_states = new Tensors { states }; |
|
|
|
successive_outputs = new Tensors(output); |
|
|
|
successive_states = new Tensors(states); |
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
@@ -722,16 +722,11 @@ namespace Tensorflow.Keras |
|
|
|
// Get the time(0) input and compute the output for that, the output will |
|
|
|
// be used to determine the dtype of output tensor array. Don't read from |
|
|
|
// input_ta due to TensorArray clear_after_read default to True. |
|
|
|
var inps = new Tensors(); |
|
|
|
foreach (var inp in flatted_inptus) |
|
|
|
{ |
|
|
|
inps.Add(inp[0]); |
|
|
|
} |
|
|
|
var input_time_zero = Nest.PackSequenceAs(inputs, inps).ToTensors(); |
|
|
|
var input_time_zero = Nest.PackSequenceAs(inputs, flatted_inptus.Select(x => x[0]).ToArray()).ToTensors(); |
|
|
|
|
|
|
|
// output_time_zero is used to determine the cell output shape and its |
|
|
|
// dtype. the value is discarded. |
|
|
|
(output_time_zero, _) = step_function((Tensor)input_time_zero, |
|
|
|
(output_time_zero, _) = step_function(input_time_zero, |
|
|
|
constants is null ? initial_states : initial_states.MergeWith(constants)); |
|
|
|
|
|
|
|
int output_ta_size = return_all_outputs ? time_steps_t : 1; |
|
|
@@ -816,6 +811,7 @@ namespace Tensorflow.Keras |
|
|
|
|
|
|
|
Func<Tensor, Tensor> cond = (time) => (time < time_steps_t); |
|
|
|
int parallel_iterations = 32; |
|
|
|
new_states = states; |
|
|
|
if (masking_fn != null) |
|
|
|
{ |
|
|
|
// Mask for the T output will be base on the output of T - 1. In the |
|
|
@@ -846,7 +842,7 @@ namespace Tensorflow.Keras |
|
|
|
// TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type |
|
|
|
var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors(); |
|
|
|
var mask_t = masking_fn(time); |
|
|
|
var (output, new_states_internal) = step_function(current_input, states.MergeWith(constants)); |
|
|
|
var (output, new_states_internal) = step_function(current_input, new_states.MergeWith(constants)); |
|
|
|
// mask output |
|
|
|
var flat_output = Nest.Flatten(output).ToList(); |
|
|
|
|
|
|
@@ -871,11 +867,12 @@ namespace Tensorflow.Keras |
|
|
|
new_states_internal = Nest.PackSequenceAs(new_states, flat_final_state).ToTensors(); |
|
|
|
|
|
|
|
var ta_index_to_write = return_all_outputs ? time : tf.constant(0); |
|
|
|
// TODO(Wanglongzhi2001),deal with zip output_ta_t |
|
|
|
foreach (var (ta, Out) in zip(output_ta_t, flat_new_output)) |
|
|
|
output_ta_t = zip(output_ta_t, flat_new_output).Select(item => |
|
|
|
{ |
|
|
|
output_ta_t.Add(ta.write(ta_index_to_write, Out)); |
|
|
|
} |
|
|
|
var (ta, out_) = item; |
|
|
|
return ta.write(ta_index_to_write, out_); |
|
|
|
}).ToList(); |
|
|
|
|
|
|
|
|
|
|
|
new_states_internal = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors(); |
|
|
|
|
|
|
@@ -921,15 +918,8 @@ namespace Tensorflow.Keras |
|
|
|
} |
|
|
|
var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations); |
|
|
|
} |
|
|
|
//Tensors outputs = new Tensors(); |
|
|
|
foreach (var o in output_ta) |
|
|
|
{ |
|
|
|
outputs.Add(o.stack()); |
|
|
|
} |
|
|
|
foreach (var o in outputs) |
|
|
|
{ |
|
|
|
last_output.Add(o[-1]); |
|
|
|
} |
|
|
|
outputs = outputs.MergeWith(output_ta.Select(o => o.stack()).ToTensors()); |
|
|
|
last_output = last_output.MergeWith(outputs.Select(o => o[-1]).ToTensors()); |
|
|
|
outputs = Nest.PackSequenceAs(output_time_zero, outputs).ToTensors(); |
|
|
|
last_output = Nest.PackSequenceAs(output_time_zero, last_output).ToTensors(); |
|
|
|
|
|
|
|