|
|
@@ -42,7 +42,7 @@ namespace Tensorflow.Operations |
|
|
|
public override GradLoopState grad_state => _grad_state; |
|
|
|
public override bool back_prop => _back_prop; |
|
|
|
|
|
|
|
public WhileContext(int? maximum_iterations = null, |
|
|
|
public WhileContext(Tensor maximum_iterations = null, |
|
|
|
int parallel_iterations = 10, |
|
|
|
bool back_prop = true, |
|
|
|
bool swap_memory = false, |
|
|
@@ -64,7 +64,7 @@ namespace Tensorflow.Operations |
|
|
|
_grad_state = grad_state; |
|
|
|
} |
|
|
|
|
|
|
|
private void _init_from_args(int? maximum_iterations, |
|
|
|
private void _init_from_args(Tensor maximum_iterations, |
|
|
|
int parallel_iterations, |
|
|
|
bool back_prop, |
|
|
|
bool swap_memory, |
|
|
@@ -107,9 +107,9 @@ namespace Tensorflow.Operations |
|
|
|
/// <summary> |
|
|
|
/// Add the loop termination condition and body to the graph. |
|
|
|
/// </summary> |
|
|
|
public Tensor[] BuildLoop(Func<Tensor, Tensor> pred, |
|
|
|
Func<Tensor, Tensor> body, |
|
|
|
Tensor[] loop_vars, |
|
|
|
internal Tensor[] BuildLoop<TItem>(Func<Tensor, TItem, Tensor> pred, |
|
|
|
Func<Tensor, TItem, LoopVar<TItem>> body, |
|
|
|
TItem loop_vars, |
|
|
|
TensorShape shape_invariants, |
|
|
|
bool return_same_structure) |
|
|
|
{ |
|
|
@@ -131,88 +131,107 @@ namespace Tensorflow.Operations |
|
|
|
return packed_exit_vars as Tensor[]; |
|
|
|
} |
|
|
|
|
|
|
|
private (Tensor[], Tensor[]) _BuildLoop(Func<Tensor, Tensor> pred, |
|
|
|
Func<Tensor, Tensor> body, |
|
|
|
Tensor[] original_loop_vars, |
|
|
|
Tensor[] loop_vars, |
|
|
|
private Tensor _convert_tensorarray_to_flow<TItem>(TItem tensor_or_tensor_array) |
|
|
|
{ |
|
|
|
if (tensor_or_tensor_array is TensorArray tensor_array) |
|
|
|
return tensor_array.flow; |
|
|
|
else if (tensor_or_tensor_array is Tensor tensor) |
|
|
|
return tensor; |
|
|
|
|
|
|
|
throw new NotImplementedException("_convert_tensorarray_to_flow"); |
|
|
|
} |
|
|
|
|
|
|
|
private (Tensor[], Tensor[]) _BuildLoop<TItem>(Func<Tensor, TItem, Tensor> pred, |
|
|
|
Func<Tensor, TItem, LoopVar<TItem>> body, |
|
|
|
TItem original_loop_vars, |
|
|
|
TItem loop_vars, |
|
|
|
TensorShape shape_invariants) |
|
|
|
{ |
|
|
|
var flat_loop_vars = original_loop_vars; |
|
|
|
|
|
|
|
// Convert TensorArrays to their flow variables |
|
|
|
var loop_vars_tensor = nest.map_structure( |
|
|
|
_convert_tensorarray_to_flow, |
|
|
|
nest.flatten(loop_vars)); |
|
|
|
|
|
|
|
// Let the context know the loop variables so the loop variables |
|
|
|
// would be added in the outer contexts properly. |
|
|
|
_InitializeValues(loop_vars); |
|
|
|
var real_vars = loop_vars; |
|
|
|
Tensor[] enter_vars = null; |
|
|
|
tf_with(ops.control_dependencies(null), delegate |
|
|
|
if (loop_vars is Tensor[] real_vars) |
|
|
|
{ |
|
|
|
enter_vars = real_vars.Select(x => _Enter(x, |
|
|
|
_name, |
|
|
|
is_constant: false, |
|
|
|
parallel_iterations: _parallel_iterations, |
|
|
|
use_input_shape: shape_invariants == null)) |
|
|
|
.ToArray(); |
|
|
|
|
|
|
|
foreach(var x in enter_vars) |
|
|
|
_InitializeValues(real_vars); |
|
|
|
Tensor[] enter_vars = null; |
|
|
|
tf_with(ops.control_dependencies(null), delegate |
|
|
|
{ |
|
|
|
enter_vars = real_vars.Select(x => _Enter(x, |
|
|
|
_name, |
|
|
|
is_constant: false, |
|
|
|
parallel_iterations: _parallel_iterations, |
|
|
|
use_input_shape: shape_invariants == null)) |
|
|
|
.ToArray(); |
|
|
|
|
|
|
|
foreach (var x in enter_vars) |
|
|
|
{ |
|
|
|
x.graph.prevent_feeding(x); |
|
|
|
if (_outer_context != null) |
|
|
|
_outer_context.AddInnerOp(x.op); |
|
|
|
} |
|
|
|
}); |
|
|
|
|
|
|
|
// Finds the closest enclosing non-None control pivot. |
|
|
|
var outer_context = _outer_context; |
|
|
|
while (outer_context != null) |
|
|
|
{ |
|
|
|
x.graph.prevent_feeding(x); |
|
|
|
if (_outer_context != null) |
|
|
|
_outer_context.AddInnerOp(x.op); |
|
|
|
|
|
|
|
} |
|
|
|
}); |
|
|
|
|
|
|
|
// Finds the closest enclosing non-None control pivot. |
|
|
|
var outer_context = _outer_context; |
|
|
|
while (outer_context != null) |
|
|
|
{ |
|
|
|
_SetShapeInvariants(real_vars, enter_vars, shape_invariants); |
|
|
|
|
|
|
|
// Fix the control inputs and control flow context of these enter ops. |
|
|
|
_FixControlInputsAndContext(enter_vars); |
|
|
|
_InitializeValues(enter_vars); |
|
|
|
_loop_enters = enter_vars.ToList(); |
|
|
|
|
|
|
|
var merge_vars = enter_vars |
|
|
|
.Select(x => merge(new[] { x, x })) |
|
|
|
.ToArray(); |
|
|
|
|
|
|
|
_pivot_for_pred = merge_vars[0]; |
|
|
|
|
|
|
|
// Build the graph for pred. |
|
|
|
var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars); |
|
|
|
// var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays); |
|
|
|
var c = ops.convert_to_tensor(pred(merge_vars_with_tensor_arrays[0], default(TItem))); |
|
|
|
_pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond"); |
|
|
|
var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot)) |
|
|
|
.ToArray(); |
|
|
|
|
|
|
|
// Build the graph for body. |
|
|
|
var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray(); |
|
|
|
// Convert TensorArray flow variables inside the context back into |
|
|
|
// their associated TensorArrays for calling the body. |
|
|
|
var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); |
|
|
|
/*var body_result = body(packed_vars_for_body[0]); |
|
|
|
var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); |
|
|
|
|
|
|
|
// Store body_result to keep track of TensorArrays returned by body |
|
|
|
var original_body_result = new[] { body_result }; |
|
|
|
// Convert TensorArrays returned by body into their flow variables |
|
|
|
var result = new[] { body_result }; |
|
|
|
|
|
|
|
var next_vars = new List<Tensor>(); |
|
|
|
foreach (var (m, v) in zip(merge_vars, result)) |
|
|
|
next_vars.Add(_AddNextAndBackEdge(m, v)); |
|
|
|
|
|
|
|
// Add the exit ops. |
|
|
|
var exit_vars = switch_vars.Select(x => exit(x[0])).ToList(); |
|
|
|
_loop_exits = exit_vars; |
|
|
|
|
|
|
|
// Exit the loop. |
|
|
|
// ExitResult(exit_vars); |
|
|
|
return (original_body_result, exit_vars.ToArray());*/ |
|
|
|
} |
|
|
|
|
|
|
|
_SetShapeInvariants(real_vars, enter_vars, shape_invariants); |
|
|
|
|
|
|
|
// Fix the control inputs and control flow context of these enter ops. |
|
|
|
_FixControlInputsAndContext(enter_vars); |
|
|
|
_InitializeValues(enter_vars); |
|
|
|
_loop_enters = enter_vars.ToList(); |
|
|
|
|
|
|
|
var merge_vars = enter_vars |
|
|
|
.Select(x => merge(new[] { x, x })) |
|
|
|
.ToArray(); |
|
|
|
|
|
|
|
_pivot_for_pred = merge_vars[0]; |
|
|
|
|
|
|
|
// Build the graph for pred. |
|
|
|
var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars); |
|
|
|
// var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays); |
|
|
|
var c = ops.convert_to_tensor(pred(merge_vars_with_tensor_arrays[0])); |
|
|
|
_pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond"); |
|
|
|
var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot)) |
|
|
|
.ToArray(); |
|
|
|
|
|
|
|
// Build the graph for body. |
|
|
|
var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray(); |
|
|
|
// Convert TensorArray flow variables inside the context back into |
|
|
|
// their associated TensorArrays for calling the body. |
|
|
|
var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); |
|
|
|
var body_result = body(packed_vars_for_body[0]); |
|
|
|
var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); |
|
|
|
|
|
|
|
// Store body_result to keep track of TensorArrays returned by body |
|
|
|
var original_body_result = new[] { body_result }; |
|
|
|
// Convert TensorArrays returned by body into their flow variables |
|
|
|
var result = new[] { body_result }; |
|
|
|
|
|
|
|
var next_vars = new List<Tensor>(); |
|
|
|
foreach (var (m, v) in zip(merge_vars, result)) |
|
|
|
next_vars.Add(_AddNextAndBackEdge(m, v)); |
|
|
|
|
|
|
|
// Add the exit ops. |
|
|
|
var exit_vars = switch_vars.Select(x => exit(x[0])).ToList(); |
|
|
|
_loop_exits = exit_vars; |
|
|
|
|
|
|
|
// Exit the loop. |
|
|
|
// ExitResult(exit_vars); |
|
|
|
return (original_body_result, exit_vars.ToArray()); |
|
|
|
throw new NotImplementedException(""); |
|
|
|
} |
|
|
|
|
|
|
|
private void _FixControlInputsAndContext(Tensor[] enters) |
|
|
|