diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs index 1faaa647..c00fc2c7 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs @@ -42,7 +42,7 @@ namespace Tensorflow.Operations public override GradLoopState grad_state => _grad_state; public override bool back_prop => _back_prop; - public WhileContext(int? maximum_iterations = null, + public WhileContext(Tensor maximum_iterations = null, int parallel_iterations = 10, bool back_prop = true, bool swap_memory = false, @@ -64,7 +64,7 @@ namespace Tensorflow.Operations _grad_state = grad_state; } - private void _init_from_args(int? maximum_iterations, + private void _init_from_args(Tensor maximum_iterations, int parallel_iterations, bool back_prop, bool swap_memory, @@ -107,9 +107,9 @@ namespace Tensorflow.Operations /// /// Add the loop termination condition and body to the graph. /// - public Tensor[] BuildLoop(Func pred, - Func body, - Tensor[] loop_vars, + internal Tensor[] BuildLoop(Func pred, + Func> body, + TItem loop_vars, TensorShape shape_invariants, bool return_same_structure) { @@ -131,88 +131,107 @@ namespace Tensorflow.Operations return packed_exit_vars as Tensor[]; } - private (Tensor[], Tensor[]) _BuildLoop(Func pred, - Func body, - Tensor[] original_loop_vars, - Tensor[] loop_vars, + private Tensor _convert_tensorarray_to_flow(TItem tensor_or_tensor_array) + { + if (tensor_or_tensor_array is TensorArray tensor_array) + return tensor_array.flow; + else if (tensor_or_tensor_array is Tensor tensor) + return tensor; + + throw new NotImplementedException("_convert_tensorarray_to_flow"); + } + + private (Tensor[], Tensor[]) _BuildLoop(Func pred, + Func> body, + TItem original_loop_vars, + TItem loop_vars, TensorShape shape_invariants) { var flat_loop_vars = original_loop_vars; + // Convert TensorArrays to their flow variables + var loop_vars_tensor = nest.map_structure( + _convert_tensorarray_to_flow, + nest.flatten(loop_vars)); + // Let the context know the loop variables so the loop variables // would be added in the outer contexts properly. - _InitializeValues(loop_vars); - var real_vars = loop_vars; - Tensor[] enter_vars = null; - tf_with(ops.control_dependencies(null), delegate + if (loop_vars is Tensor[] real_vars) { - enter_vars = real_vars.Select(x => _Enter(x, - _name, - is_constant: false, - parallel_iterations: _parallel_iterations, - use_input_shape: shape_invariants == null)) - .ToArray(); - - foreach(var x in enter_vars) + _InitializeValues(real_vars); + Tensor[] enter_vars = null; + tf_with(ops.control_dependencies(null), delegate + { + enter_vars = real_vars.Select(x => _Enter(x, + _name, + is_constant: false, + parallel_iterations: _parallel_iterations, + use_input_shape: shape_invariants == null)) + .ToArray(); + + foreach (var x in enter_vars) + { + x.graph.prevent_feeding(x); + if (_outer_context != null) + _outer_context.AddInnerOp(x.op); + } + }); + + // Finds the closest enclosing non-None control pivot. + var outer_context = _outer_context; + while (outer_context != null) { - x.graph.prevent_feeding(x); - if (_outer_context != null) - _outer_context.AddInnerOp(x.op); + } - }); - // Finds the closest enclosing non-None control pivot. - var outer_context = _outer_context; - while (outer_context != null) - { + _SetShapeInvariants(real_vars, enter_vars, shape_invariants); + + // Fix the control inputs and control flow context of these enter ops. + _FixControlInputsAndContext(enter_vars); + _InitializeValues(enter_vars); + _loop_enters = enter_vars.ToList(); + + var merge_vars = enter_vars + .Select(x => merge(new[] { x, x })) + .ToArray(); + + _pivot_for_pred = merge_vars[0]; + + // Build the graph for pred. + var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars); + // var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays); + var c = ops.convert_to_tensor(pred(merge_vars_with_tensor_arrays[0], default(TItem))); + _pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond"); + var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot)) + .ToArray(); + // Build the graph for body. + var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray(); + // Convert TensorArray flow variables inside the context back into + // their associated TensorArrays for calling the body. + var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); + /*var body_result = body(packed_vars_for_body[0]); + var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); + + // Store body_result to keep track of TensorArrays returned by body + var original_body_result = new[] { body_result }; + // Convert TensorArrays returned by body into their flow variables + var result = new[] { body_result }; + + var next_vars = new List(); + foreach (var (m, v) in zip(merge_vars, result)) + next_vars.Add(_AddNextAndBackEdge(m, v)); + + // Add the exit ops. + var exit_vars = switch_vars.Select(x => exit(x[0])).ToList(); + _loop_exits = exit_vars; + + // Exit the loop. + // ExitResult(exit_vars); + return (original_body_result, exit_vars.ToArray());*/ } - _SetShapeInvariants(real_vars, enter_vars, shape_invariants); - - // Fix the control inputs and control flow context of these enter ops. - _FixControlInputsAndContext(enter_vars); - _InitializeValues(enter_vars); - _loop_enters = enter_vars.ToList(); - - var merge_vars = enter_vars - .Select(x => merge(new[] { x, x })) - .ToArray(); - - _pivot_for_pred = merge_vars[0]; - - // Build the graph for pred. - var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars); - // var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays); - var c = ops.convert_to_tensor(pred(merge_vars_with_tensor_arrays[0])); - _pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond"); - var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot)) - .ToArray(); - - // Build the graph for body. - var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray(); - // Convert TensorArray flow variables inside the context back into - // their associated TensorArrays for calling the body. - var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); - var body_result = body(packed_vars_for_body[0]); - var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); - - // Store body_result to keep track of TensorArrays returned by body - var original_body_result = new[] { body_result }; - // Convert TensorArrays returned by body into their flow variables - var result = new[] { body_result }; - - var next_vars = new List(); - foreach (var (m, v) in zip(merge_vars, result)) - next_vars.Add(_AddNextAndBackEdge(m, v)); - - // Add the exit ops. - var exit_vars = switch_vars.Select(x => exit(x[0])).ToList(); - _loop_exits = exit_vars; - - // Exit the loop. - // ExitResult(exit_vars); - return (original_body_result, exit_vars.ToArray()); + throw new NotImplementedException(""); } private void _FixControlInputsAndContext(Tensor[] enters)