diff --git a/src/TensorFlowNET.Core/APIs/tf.control_flow.cs b/src/TensorFlowNET.Core/APIs/tf.control_flow.cs index 6ed475a9..b2b5574a 100644 --- a/src/TensorFlowNET.Core/APIs/tf.control_flow.cs +++ b/src/TensorFlowNET.Core/APIs/tf.control_flow.cs @@ -37,7 +37,7 @@ namespace Tensorflow public Operation group(T[] inputs, string name = null) where T : ITensorOrOperation => control_flow_ops.group(inputs, name: name); - public Tensor while_loop(Func cond, Func body, Tensor[] loop_vars, + /*public Tensor while_loop(Func cond, Func body, Tensor[] loop_vars, TensorShape shape_invariants = null, int parallel_iterations = 10, bool back_prop = true, @@ -52,7 +52,7 @@ namespace Tensorflow swap_memory: swap_memory, name: name, maximum_iterations: maximum_iterations, - return_same_structure: return_same_structure); + return_same_structure: return_same_structure);*/ public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) => ops.control_dependencies(control_inputs); diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs new file mode 100644 index 00000000..fa2fe9d3 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs @@ -0,0 +1,25 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Operations +{ + internal class LoopVar + { + public Tensor Counter { get; } + public TItem[] Items { get; } + public TItem Item { get; } + + public LoopVar(Tensor counter, TItem[] items) + { + Counter = counter; + Items = items; + } + + public LoopVar(Tensor counter, TItem item) + { + Counter = counter; + Item = item; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs b/src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs new file mode 100644 index 00000000..f0086793 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs @@ -0,0 +1,32 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Operations +{ + internal class BodyItemInRnnWhileLoop + { + /// + /// int32 scalar Tensor. + /// + public Tensor time { get; set; } + /// + /// List of `TensorArray`s that represent the output. + /// + public TensorArray[] output_ta_t { get; set; } + /// + /// nested tuple of vector tensors that represent the state. + /// + public Tensor state { get; set; } + + public BodyItemInRnnWhileLoop(Tensor time, TensorArray[] output_ta_t, Tensor state) + { + this.time = time; + this.output_ta_t = output_ta_t; + this.state = state; + } + + public static implicit operator (Tensor, TensorArray[], Tensor)(BodyItemInRnnWhileLoop item) + => (item.time, item.output_ta_t, item.state); + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs index 8e7425e5..e058c077 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs @@ -145,7 +145,7 @@ namespace Tensorflow.Operations { var ta = new TensorArray(dtype: dtype_, size: time_steps, - element_shape: new[] { element_shape }, + element_shape: element_shape, tensor_array_name: base_name + name); return ta; }; @@ -178,19 +178,29 @@ namespace Tensorflow.Operations // Make sure that we run at least 1 step, if necessary, to ensure // the TensorArrays pick up the dynamic shape. - Tensor loop_bound; + Tensor loop_bound = null; if (in_graph_mode) loop_bound = math_ops.minimum( time_steps, math_ops.maximum(1, max_sequence_length)); - /*Func cond = (ctime) => + Func cond = (item) => { - return null; + return time < loop_bound; }; - control_flow_ops.while_loop( + // Take a time step of the dynamic RNN. + Func _time_step = (item) => + { + return item; + }; + + control_flow_ops.while_loop( cond: cond, - body = );*/ + body: _time_step, + loop_vars: new BodyItemInRnnWhileLoop(time, output_ta.ToArray(), state), + parallel_iterations: parallel_iterations, + maximum_iterations: time_steps, + swap_memory: swap_memory); throw new NotImplementedException(""); } diff --git a/src/TensorFlowNET.Core/Operations/TensorArray.cs b/src/TensorFlowNET.Core/Operations/TensorArray.cs index 7251bf85..60e1bde5 100644 --- a/src/TensorFlowNET.Core/Operations/TensorArray.cs +++ b/src/TensorFlowNET.Core/Operations/TensorArray.cs @@ -39,7 +39,7 @@ namespace Tensorflow.Operations public TensorArray(TF_DataType dtype, Tensor size = default, bool? clear_after_read = null, bool? dynamic_size = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, - bool infer_shape = true, TensorShape[] element_shape = null, + bool infer_shape = true, TensorShape element_shape = null, bool colocate_with_first_write_call = true, string name = null) { _implementation = new _GraphTensorArray(dtype, diff --git a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs index bd919ad8..5a667560 100644 --- a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs +++ b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs @@ -44,7 +44,7 @@ namespace Tensorflow.Operations public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null, bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, - bool infer_shape = true, TensorShape[] element_shape = null, + bool infer_shape = true, TensorShape element_shape = null, bool colocate_with_first_write_call = true, string name = null) { clear_after_read = clear_after_read ?? true; @@ -68,7 +68,7 @@ namespace Tensorflow.Operations else { _infer_shape = true; - _element_shape = new List { }; + _element_shape = new List { element_shape }; } tf_with(ops.name_scope(name, "TensorArray", new { handle, size, flow }), scope => @@ -135,7 +135,7 @@ namespace Tensorflow.Operations var ta = new TensorArray(_dtype, infer_shape:_infer_shape, - element_shape: _element_shape.ToArray(), + element_shape: _element_shape[0], dynamic_size: _dynamic_size, handle: _handle, flow: flow_out, diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs index e8b5f0eb..27e43153 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs @@ -485,7 +485,7 @@ namespace Tensorflow }); } - public static Tensor[] _convert_flows_to_tensorarrays(T[] tensors_or_tensorarrays, Tensor[] tensors_or_flows) + public static Tensor[] _convert_flows_to_tensorarrays(T tensors_or_tensorarrays, Tensor[] tensors_or_flows) { // zip(tensors_or_tensorarrays, tensors_or_flows).Select((ta, t_or_flow) => ta).ToArray(); return tensors_or_flows; @@ -591,18 +591,18 @@ namespace Tensorflow /// /// /// - public static Tensor while_loop(Func cond, Func body, Tensor[] loop_vars, + public static Tensor while_loop(Func cond, Func body, TItem loop_vars, TensorShape shape_invariants = null, int parallel_iterations = 10, bool back_prop = true, bool swap_memory = false, string name = null, - int? maximum_iterations = null, + Tensor maximum_iterations = null, bool return_same_structure = false) { tf_with(ops.name_scope(name, "while", loop_vars), scope => { - if (loop_vars == null || loop_vars.Length == 0) + if (loop_vars == null) throw new ValueError("No loop variables provided"); if (cond == null) throw new ValueError("cond must be callable."); @@ -611,6 +611,28 @@ namespace Tensorflow if (parallel_iterations < 1) throw new ValueError("parallel_iterations must be a positive integer."); + var try_to_pack = loop_vars is Tensor && !return_same_structure; + var counter = constant_op.constant(0, dtype: maximum_iterations.dtype, name: "iteration_counter"); + var orig_cond = cond; + var orig_body = body; + + LoopVar loop_vars_1 = null; + Func> body_buildloop = null; + Func cond_buildloop = null; + + if (try_to_pack) + { + + } + else + { + loop_vars_1 = new LoopVar(counter, loop_vars); + cond_buildloop = (i, lv) => + math_ops.logical_and(i < maximum_iterations, orig_cond(lv)); + body_buildloop = (i, lv) => new LoopVar(i + 1, orig_body(lv)); + } + try_to_pack = false; + var loop_context = new WhileContext( maximum_iterations: maximum_iterations, parallel_iterations: parallel_iterations, @@ -620,7 +642,7 @@ namespace Tensorflow if (loop_context.outer_context == null) ops.add_to_collection(tf.GraphKeys.WHILE_CONTEXT, loop_context); - var results = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants, + var results = loop_context.BuildLoop(cond_buildloop, body_buildloop, loop_vars, shape_invariants, return_same_structure); if (maximum_iterations != null) diff --git a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs index fa194934..71e9bbab 100644 --- a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs @@ -28,12 +28,9 @@ namespace Tensorflow } public static (Tensor, Tensor) tensor_array_v3(T size, TF_DataType dtype = TF_DataType.DtInvalid, - TensorShape[] element_shape = null, bool dynamic_size = false, bool clear_after_read = true, - bool identical_element_shapes = false, string tensor_array_name = "tensor_array_name", string name = null) + TensorShape element_shape = null, bool dynamic_size = false, bool clear_after_read = true, + bool identical_element_shapes = false, string tensor_array_name = "", string name = null) { - if (tensor_array_name == null) - tensor_array_name = string.Empty; - var _op = _op_def_lib._apply_op_helper("TensorArrayV3", name, new { size, diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs index b3ae594f..54ff358a 100644 --- a/src/TensorFlowNET.Core/Util/nest.py.cs +++ b/src/TensorFlowNET.Core/Util/nest.py.cs @@ -223,7 +223,6 @@ namespace Tensorflow.Util private static void _flatten_recursive(T obj, List list) { - switch(obj) { case IDictionary dict: diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs index 39fd2ac9..a512663e 100644 --- a/src/TensorFlowNET.Core/tensorflow.cs +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -93,14 +93,14 @@ namespace Tensorflow return new Session().as_default(); } - public Session Session(Graph graph, SessionOptions opts = null) + public Session Session(Graph graph, ConfigProto config = null) { - return new Session(graph, opts: opts).as_default(); + return new Session(graph, config: config).as_default(); } - public Session Session(SessionOptions opts) + public Session Session(ConfigProto config) { - return new Session(null, opts).as_default(); + return new Session(null, config).as_default(); } public void __init__() diff --git a/test/TensorFlowNET.UnitTest/CSession.cs b/test/TensorFlowNET.UnitTest/CSession.cs index fa293288..e9ed7784 100644 --- a/test/TensorFlowNET.UnitTest/CSession.cs +++ b/test/TensorFlowNET.UnitTest/CSession.cs @@ -25,9 +25,8 @@ namespace TensorFlowNET.UnitTest { lock (Locks.ProcessWide) { - var opts = new SessionOptions(); - opts.SetConfig(new ConfigProto {InterOpParallelismThreads = 4}); - session_ = new Session(graph, opts, s); + var config = new ConfigProto {InterOpParallelismThreads = 4}; + session_ = new Session(graph, config, s); } } diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs b/test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs index 72dd83ea..80ff71db 100644 --- a/test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs +++ b/test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs @@ -18,10 +18,10 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test var i = constant_op.constant(0, name: "i"); var c = new Func(x => tf.less(x, 10, name: "c")); var b = new Func(x => tf.add(x, 1, name: "c")); - var r = control_flow_ops.while_loop(c, b, new[] { i }); + var r = control_flow_ops.while_loop(c, b, i); } - private void _testWhileContextHelper(int? maximum_iterations = null) + private void _testWhileContextHelper(int maximum_iterations) { // TODO: implement missing code dependencies using (var sess = this.cached_session()) @@ -30,7 +30,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test var c = new Func(x => gen_math_ops.less(x, 10, name: "c")); var b = new Func(x => gen_math_ops.add(x, 1, name: "c")); control_flow_ops.while_loop( - c, b, new[] { i }, maximum_iterations: maximum_iterations); + c, b, i , maximum_iterations: tf.constant(maximum_iterations)); foreach (Operation op in sess.graph.get_operations()) { var control_flow_context = op._get_control_flow_context(); @@ -42,13 +42,6 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test } } - [Ignore("TODO")] - [TestMethod] - public void testWhileContext() - { - _testWhileContextHelper(); - } - [Ignore("TODO")] [TestMethod] public void testWhileContextWithMaximumIterations()