diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
index 1faaa647..c00fc2c7 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
@@ -42,7 +42,7 @@ namespace Tensorflow.Operations
public override GradLoopState grad_state => _grad_state;
public override bool back_prop => _back_prop;
- public WhileContext(int? maximum_iterations = null,
+ public WhileContext(Tensor maximum_iterations = null,
int parallel_iterations = 10,
bool back_prop = true,
bool swap_memory = false,
@@ -64,7 +64,7 @@ namespace Tensorflow.Operations
_grad_state = grad_state;
}
- private void _init_from_args(int? maximum_iterations,
+ private void _init_from_args(Tensor maximum_iterations,
int parallel_iterations,
bool back_prop,
bool swap_memory,
@@ -107,9 +107,9 @@ namespace Tensorflow.Operations
///
/// Add the loop termination condition and body to the graph.
///
- public Tensor[] BuildLoop(Func pred,
- Func body,
- Tensor[] loop_vars,
+ internal Tensor[] BuildLoop(Func pred,
+ Func> body,
+ TItem loop_vars,
TensorShape shape_invariants,
bool return_same_structure)
{
@@ -131,88 +131,107 @@ namespace Tensorflow.Operations
return packed_exit_vars as Tensor[];
}
- private (Tensor[], Tensor[]) _BuildLoop(Func pred,
- Func body,
- Tensor[] original_loop_vars,
- Tensor[] loop_vars,
+ private Tensor _convert_tensorarray_to_flow(TItem tensor_or_tensor_array)
+ {
+ if (tensor_or_tensor_array is TensorArray tensor_array)
+ return tensor_array.flow;
+ else if (tensor_or_tensor_array is Tensor tensor)
+ return tensor;
+
+ throw new NotImplementedException("_convert_tensorarray_to_flow");
+ }
+
+ private (Tensor[], Tensor[]) _BuildLoop(Func pred,
+ Func> body,
+ TItem original_loop_vars,
+ TItem loop_vars,
TensorShape shape_invariants)
{
var flat_loop_vars = original_loop_vars;
+ // Convert TensorArrays to their flow variables
+ var loop_vars_tensor = nest.map_structure(
+ _convert_tensorarray_to_flow,
+ nest.flatten(loop_vars));
+
// Let the context know the loop variables so the loop variables
// would be added in the outer contexts properly.
- _InitializeValues(loop_vars);
- var real_vars = loop_vars;
- Tensor[] enter_vars = null;
- tf_with(ops.control_dependencies(null), delegate
+ if (loop_vars is Tensor[] real_vars)
{
- enter_vars = real_vars.Select(x => _Enter(x,
- _name,
- is_constant: false,
- parallel_iterations: _parallel_iterations,
- use_input_shape: shape_invariants == null))
- .ToArray();
-
- foreach(var x in enter_vars)
+ _InitializeValues(real_vars);
+ Tensor[] enter_vars = null;
+ tf_with(ops.control_dependencies(null), delegate
+ {
+ enter_vars = real_vars.Select(x => _Enter(x,
+ _name,
+ is_constant: false,
+ parallel_iterations: _parallel_iterations,
+ use_input_shape: shape_invariants == null))
+ .ToArray();
+
+ foreach (var x in enter_vars)
+ {
+ x.graph.prevent_feeding(x);
+ if (_outer_context != null)
+ _outer_context.AddInnerOp(x.op);
+ }
+ });
+
+ // Finds the closest enclosing non-None control pivot.
+ var outer_context = _outer_context;
+ while (outer_context != null)
{
- x.graph.prevent_feeding(x);
- if (_outer_context != null)
- _outer_context.AddInnerOp(x.op);
+
}
- });
- // Finds the closest enclosing non-None control pivot.
- var outer_context = _outer_context;
- while (outer_context != null)
- {
+ _SetShapeInvariants(real_vars, enter_vars, shape_invariants);
+
+ // Fix the control inputs and control flow context of these enter ops.
+ _FixControlInputsAndContext(enter_vars);
+ _InitializeValues(enter_vars);
+ _loop_enters = enter_vars.ToList();
+
+ var merge_vars = enter_vars
+ .Select(x => merge(new[] { x, x }))
+ .ToArray();
+
+ _pivot_for_pred = merge_vars[0];
+
+ // Build the graph for pred.
+ var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars);
+ // var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays);
+ var c = ops.convert_to_tensor(pred(merge_vars_with_tensor_arrays[0], default(TItem)));
+ _pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond");
+ var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot))
+ .ToArray();
+ // Build the graph for body.
+ var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray();
+ // Convert TensorArray flow variables inside the context back into
+ // their associated TensorArrays for calling the body.
+ var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body);
+ /*var body_result = body(packed_vars_for_body[0]);
+ var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);
+
+ // Store body_result to keep track of TensorArrays returned by body
+ var original_body_result = new[] { body_result };
+ // Convert TensorArrays returned by body into their flow variables
+ var result = new[] { body_result };
+
+ var next_vars = new List();
+ foreach (var (m, v) in zip(merge_vars, result))
+ next_vars.Add(_AddNextAndBackEdge(m, v));
+
+ // Add the exit ops.
+ var exit_vars = switch_vars.Select(x => exit(x[0])).ToList();
+ _loop_exits = exit_vars;
+
+ // Exit the loop.
+ // ExitResult(exit_vars);
+ return (original_body_result, exit_vars.ToArray());*/
}
- _SetShapeInvariants(real_vars, enter_vars, shape_invariants);
-
- // Fix the control inputs and control flow context of these enter ops.
- _FixControlInputsAndContext(enter_vars);
- _InitializeValues(enter_vars);
- _loop_enters = enter_vars.ToList();
-
- var merge_vars = enter_vars
- .Select(x => merge(new[] { x, x }))
- .ToArray();
-
- _pivot_for_pred = merge_vars[0];
-
- // Build the graph for pred.
- var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars);
- // var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays);
- var c = ops.convert_to_tensor(pred(merge_vars_with_tensor_arrays[0]));
- _pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond");
- var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot))
- .ToArray();
-
- // Build the graph for body.
- var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray();
- // Convert TensorArray flow variables inside the context back into
- // their associated TensorArrays for calling the body.
- var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body);
- var body_result = body(packed_vars_for_body[0]);
- var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);
-
- // Store body_result to keep track of TensorArrays returned by body
- var original_body_result = new[] { body_result };
- // Convert TensorArrays returned by body into their flow variables
- var result = new[] { body_result };
-
- var next_vars = new List();
- foreach (var (m, v) in zip(merge_vars, result))
- next_vars.Add(_AddNextAndBackEdge(m, v));
-
- // Add the exit ops.
- var exit_vars = switch_vars.Select(x => exit(x[0])).ToList();
- _loop_exits = exit_vars;
-
- // Exit the loop.
- // ExitResult(exit_vars);
- return (original_body_result, exit_vars.ToArray());
+ throw new NotImplementedException("");
}
private void _FixControlInputsAndContext(Tensor[] enters)