diff --git a/src/TensorFlowHub/MnistDataSet.cs b/src/TensorFlowHub/MnistDataSet.cs index accc57e1..f1a73349 100644 --- a/src/TensorFlowHub/MnistDataSet.cs +++ b/src/TensorFlowHub/MnistDataSet.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Text; using NumSharp; using Tensorflow; @@ -21,7 +22,12 @@ namespace Tensorflow.Hub images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]); images.astype(dataType); + // for debug np.multiply performance + var sw = new Stopwatch(); + sw.Start(); images = np.multiply(images, 1.0f / 255.0f); + sw.Stop(); + Console.WriteLine($"{sw.ElapsedMilliseconds}ms"); Data = images; labels.astype(dataType); diff --git a/src/TensorFlowNET.Core/APIs/tf.control.cs b/src/TensorFlowNET.Core/APIs/tf.control.cs index ce5da031..d744a141 100644 --- a/src/TensorFlowNET.Core/APIs/tf.control.cs +++ b/src/TensorFlowNET.Core/APIs/tf.control.cs @@ -14,10 +14,29 @@ limitations under the License. ******************************************************************************/ +using System; + namespace Tensorflow { public static partial class tf { + public static Tensor while_loop(Func cond, Func body, Tensor[] loop_vars, + TensorShape shape_invariants = null, + int parallel_iterations = 10, + bool back_prop = true, + bool swap_memory = false, + string name = null, + int? maximum_iterations = null, + bool return_same_structure = false) + => control_flow_ops.while_loop(cond, body, loop_vars, + shape_invariants: shape_invariants, + parallel_iterations: parallel_iterations, + back_prop: back_prop, + swap_memory: swap_memory, + name: name, + maximum_iterations: maximum_iterations, + return_same_structure: return_same_structure); + public static _ControlDependenciesController control_dependencies(Operation[] control_inputs) => ops.control_dependencies(control_inputs); } diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index a8604483..c2686812 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -39,8 +39,8 @@ namespace Tensorflow public static Tensor asin(Tensor x, string name = null) => gen_math_ops.asin(x, name); - public static Tensor add(Tx a, Ty b) - => gen_math_ops.add(a, b); + public static Tensor add(Tx a, Ty b, string name = null) + => gen_math_ops.add(a, b, name: name); /// /// Computes atan of x element-wise. diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs index 70057113..4a3ac793 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs @@ -33,7 +33,7 @@ namespace Tensorflow /// /// The data input ops for an op to be created. /// A list of control inputs for the op to be created. - private ITensorOrOperation[] _control_dependencies_for_inputs(ITensorOrOperation[] input_ops) + public ITensorOrOperation[] _control_dependencies_for_inputs(ITensorOrOperation[] input_ops) { var ret = new List(); diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs index d2a1f628..5cfecc49 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs @@ -53,6 +53,11 @@ namespace Tensorflow.Operations protected Stack _context_stack; protected ControlFlowContext _outer_context; + /// + /// The keys are the names of tensors referenced by but external to this + /// context. Each value is the Tensor that should be used by this context to + /// access the key value (e.g. a switch output guarding a cond input value). + /// protected Dictionary _external_values; public ControlFlowContext() @@ -68,6 +73,12 @@ namespace Tensorflow.Operations _outer_context = ops.get_default_graph()._get_control_flow_context(); if (values_def != null) _init_values_from_proto(values_def, import_scope: import_scope); + else + { + _values = new HashSet(); + _external_values = new Dictionary(); + } + } public void __enter__() @@ -114,6 +125,27 @@ namespace Tensorflow.Operations graph._set_control_flow_context(this); } + protected virtual Tensor _Enter(Tensor data, string frame_name, + bool is_constant = false, + int parallel_iterations = 10, + bool use_ref = true, + bool use_input_shape = true, + string name = null) + { + Tensor result; + data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref: true); + if (data.dtype.is_ref_dtype() && use_ref) + throw new NotImplementedException("_Enter"); + else + result = gen_control_flow_ops.enter( + data, frame_name, is_constant, parallel_iterations, name: name); + + if (use_input_shape) + result.SetShape(data.TensorShape); + + return result; + } + /// /// Exit this control flow context. /// @@ -184,6 +216,10 @@ namespace Tensorflow.Operations return true; } + protected virtual bool _IsInOuterContext(Operation op) + { + throw new NotImplementedException("_IsInOuterContext"); + } protected virtual void _RemoveExternalControlEdges(Operation op) { diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs index 69329b21..c5e7121c 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs @@ -15,8 +15,12 @@ ******************************************************************************/ using System; +using System.Collections.Generic; +using System.Linq; using Tensorflow.Operations.ControlFlows; +using Tensorflow.Util; using static Tensorflow.Python; +using static Tensorflow.control_flow_ops; namespace Tensorflow.Operations { @@ -32,10 +36,14 @@ namespace Tensorflow.Operations bool _swap_memory; Tensor _pivot_for_pred; Tensor _pivot_for_body; - Tensor[] _loop_exits; - Tensor[] _loop_enters; + List _loop_exits; + List _loop_enters; + Graph _graph; + public override GradLoopState grad_state => _grad_state; + public override bool back_prop => _back_prop; - public WhileContext(int parallel_iterations = 10, + public WhileContext(int? maximum_iterations = null, + int parallel_iterations = 10, bool back_prop = true, bool swap_memory = false, string name = "while_context", @@ -49,12 +57,27 @@ namespace Tensorflow.Operations } else { - + __init__(); + _init_from_args(maximum_iterations, parallel_iterations, back_prop, swap_memory, name); } _grad_state = grad_state; } + private void _init_from_args(int? maximum_iterations, + int parallel_iterations, + bool back_prop, + bool swap_memory, + string name) + { + _name = ops.get_default_graph().unique_name(name); + _back_prop = back_prop; + _swap_memory = swap_memory; + _loop_exits = new List(); + _loop_enters = new List(); + _graph = ops.get_default_graph(); + } + private void _init_from_proto(WhileContextDef context_def, string import_scope = null) { var g = ops.get_default_graph(); @@ -70,26 +93,156 @@ namespace Tensorflow.Operations // The boolean tensor for loop termination condition. _pivot = g.as_graph_element(ops.prepend_name_scope(context_def.PivotName, import_scope)) as Tensor; // The list of exit tensors for loop variables. - _loop_exits = new Tensor[context_def.LoopExitNames.Count]; + _loop_exits = new List(); foreach (var (i, exit_name) in enumerate(context_def.LoopExitNames)) - _loop_exits[i] = g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope)) as Tensor; + _loop_exits.Add(g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope)) as Tensor); // The list of enter tensors for loop variables. - _loop_enters = new Tensor[context_def.LoopEnterNames.Count]; + _loop_enters = new List(); foreach (var (i, enter_name) in enumerate(context_def.LoopEnterNames)) - _loop_enters[i] = g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope)) as Tensor; + _loop_enters.Add(g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope)) as Tensor); __init__(values_def: context_def.ValuesDef, import_scope: import_scope); } - public override WhileContext GetWhileContext() + /// + /// Add the loop termination condition and body to the graph. + /// + public Tensor[] BuildLoop(Func pred, + Func body, + Tensor[] loop_vars, + TensorShape shape_invariants, + bool return_same_structure) { - return this; + // Keep original_loop_vars to identify which are TensorArrays + var original_loop_vars = loop_vars; + // Convert TensorArrays to their flow variables + Enter(); + var(original_body_result, exit_vars) = _BuildLoop( + pred, body, original_loop_vars, loop_vars, shape_invariants); + Exit(); + + var flat_result = original_body_result; + + var exit_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_result, exit_vars); + var packed_exit_vars = nest.pack_sequence_as( + structure: original_body_result, + flat_sequence: exit_vars_with_tensor_arrays); + + return packed_exit_vars as Tensor[]; } + private (Tensor[], Tensor[]) _BuildLoop(Func pred, + Func body, + Tensor[] original_loop_vars, + Tensor[] loop_vars, + TensorShape shape_invariants) + { + var flat_loop_vars = original_loop_vars; - public override GradLoopState grad_state => _grad_state; + // Let the context know the loop variables so the loop variables + // would be added in the outer contexts properly. + _InitializeValues(loop_vars); + var real_vars = loop_vars; + Tensor[] enter_vars = null; + tf_with(ops.control_dependencies(null), delegate + { + enter_vars = real_vars.Select(x => _Enter(x, + _name, + is_constant: false, + parallel_iterations: _parallel_iterations, + use_input_shape: shape_invariants == null)) + .ToArray(); - public override bool back_prop => _back_prop; + foreach(var x in enter_vars) + { + x.graph.prevent_feeding(x); + if (_outer_context != null) + _outer_context.AddInnerOp(x.op); + } + }); + + // Finds the closest enclosing non-None control pivot. + var outer_context = _outer_context; + while (outer_context != null) + { + + } + + _SetShapeInvariants(real_vars, enter_vars, shape_invariants); + + // Fix the control inputs and control flow context of these enter ops. + _FixControlInputsAndContext(enter_vars); + _InitializeValues(enter_vars); + _loop_enters = enter_vars.ToList(); + + var merge_vars = enter_vars + .Select(x => merge(new[] { x, x })) + .ToArray(); + + _pivot_for_pred = merge_vars[0]; + + // Build the graph for pred. + var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars); + // var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays); + var c = ops.convert_to_tensor(pred(merge_vars_with_tensor_arrays[0])); + _pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond"); + var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot)) + .ToArray(); + + // Build the graph for body. + var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray(); + // Convert TensorArray flow variables inside the context back into + // their associated TensorArrays for calling the body. + var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); + var body_result = body(packed_vars_for_body[0]); + var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); + + // Store body_result to keep track of TensorArrays returned by body + var original_body_result = new[] { body_result }; + // Convert TensorArrays returned by body into their flow variables + var result = new[] { body_result }; + + var next_vars = new List(); + foreach (var (m, v) in zip(merge_vars, result)) + next_vars.Add(_AddNextAndBackEdge(m, v)); + + // Add the exit ops. + var exit_vars = switch_vars.Select(x => exit(x[0])).ToList(); + _loop_exits = exit_vars; + + // Exit the loop. + // ExitResult(exit_vars); + return (original_body_result, exit_vars.ToArray()); + } + + private void _FixControlInputsAndContext(Tensor[] enters) + { + var graph = ops.get_default_graph(); + foreach(var e in enters) + { + var inp_op = e.op.inputs[0].op; + var control_inputs = graph._control_dependencies_for_inputs(new[] { inp_op }); + // op for op in control_inputs if self._IsInOuterContext(op) + var outer_control_inputs = control_inputs.Where(x => _IsInOuterContext(x.op)) + .Select(x => x.op) + .ToArray(); + e.op._set_control_flow_context(this); + e.op._add_control_inputs(outer_control_inputs); + graph._record_op_seen_by_control_dependencies(e.op); + } + } + + private void _InitializeValues(Tensor[] values) + { + _values = new HashSet(); + foreach(var x in values) + _values.Add(x.name); + } + + public override WhileContext GetWhileContext() + { + return this; + } public WhileContext from_proto(WhileContextDef proto, string import_scope) { diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs index 5b820b3a..1c7466f7 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs @@ -141,30 +141,57 @@ namespace Tensorflow.Operations string base_name = null; tf_with(ops.name_scope("dynamic_rnn"), scope => base_name = scope); - Func _create_ta = (name, element_shape, dtype_) => + Func _create_ta = (name, element_shape, dtype_) => { - new TensorArray(dtype: dtype_, + var ta = new TensorArray(dtype: dtype_, size: time_steps, element_shape: element_shape, tensor_array_name: base_name + name); - throw new NotImplementedException(""); + return ta; }; bool in_graph_mode = true; + var output_ta = new List(); + var input_ta = new List(); if (in_graph_mode) { - foreach(var (i, out_size) in enumerate(flat_output_size)) + foreach (var (i, out_size) in enumerate(flat_output_size)) { - _create_ta($"output_{i}", + output_ta.Add(_create_ta($"output_{i}", new TensorShape(const_batch_size).concatenate( _maybe_tensor_shape_from_tensor(out_size)), - _infer_state_dtype(dtype, state)); + _infer_state_dtype(dtype, state))); + } + foreach (var (i, flat_input_i) in enumerate(flat_input)) + { + input_ta.Add(_create_ta($"input_{i}", + new TensorShape(flat_input_i.dims.Skip(1).ToArray()), + flat_input_i.dtype)); + } - + for (int i = 0; i < input_ta.Count; i++) + { + var (ta, input_) = (input_ta[0], flat_input[0]); } } + // Make sure that we run at least 1 step, if necessary, to ensure + // the TensorArrays pick up the dynamic shape. + Tensor loop_bound; + if (in_graph_mode) + loop_bound = math_ops.minimum( + time_steps, math_ops.maximum(1, max_sequence_length)); + + /*Func cond = (ctime) => + { + return null; + }; + + control_flow_ops.while_loop( + cond: cond, + body = );*/ + throw new NotImplementedException(""); } diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs index 2717fd3e..c8c711e7 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs @@ -26,6 +26,44 @@ namespace Tensorflow { public class control_flow_ops { + public static Tensor _AddNextAndBackEdge(Tensor m, Tensor v, bool enforce_shape_invariant = true) + { + v = ops.convert_to_tensor(v); + v = _NextIteration(v); + if (enforce_shape_invariant) + _EnforceShapeInvariant(m, v); + m.op._update_input(1, v); + return v; + } + + /// + /// Check if the shapes of the loops variables are invariants. + /// + /// + /// + public static void _EnforceShapeInvariant(Tensor merge_var, Tensor next_var) + { + + } + + public static Tensor exit(Tensor data, string name = null) + { + data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref: true); + if (data.dtype.is_ref_dtype()) + return gen_control_flow_ops.ref_exit(data, name: name); + else + return gen_control_flow_ops._exit(data, name: name); + } + + public static Tensor _NextIteration(Tensor data, string name = null) + { + data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref: true); + if (data.dtype.is_ref_dtype()) + return gen_control_flow_ops.ref_next_iteration(data, name: name); + else + return gen_control_flow_ops.next_iteration(data, name: name); + } + public static Operation Assert(Tensor condition, object[] data, int? summarize = null, string name = null) { return tf_with(ops.name_scope(name, "Assert", new { condition, data }), scope => @@ -213,6 +251,14 @@ namespace Tensorflow return gen_array_ops.identity(data, name: name); } + public static void _SetShapeInvariants(Tensor[] input_vars, Tensor[] enter_vars, TensorShape shapes = null) + { + if (shapes == null) + return; + + throw new NotImplementedException("_SetShapeInvariants"); + } + /// /// Forwards `data` to an output determined by `pred`. /// If `pred` is false, the `data` input is forwarded to the first output. @@ -516,10 +562,52 @@ namespace Tensorflow throw new NotImplementedException("ZerosLikeOutsideLoop"); } - // TODO - public static void while_loop(Func func, Func func1, Tensor[] tensors, int? i) + /// + /// Repeat `body` while the condition `cond` is true. + /// + /// + /// + /// + /// + public static Tensor while_loop(Func cond, Func body, Tensor[] loop_vars, + TensorShape shape_invariants = null, + int parallel_iterations = 10, + bool back_prop = true, + bool swap_memory = false, + string name = null, + int? maximum_iterations = null, + bool return_same_structure = false) { - throw new NotImplementedException(); + tf_with(ops.name_scope(name, "while", loop_vars), scope => + { + if (loop_vars == null || loop_vars.Length == 0) + throw new ValueError("No loop variables provided"); + if (cond == null) + throw new ValueError("cond must be callable."); + if (body == null) + throw new ValueError("body must be callable."); + if (parallel_iterations < 1) + throw new ValueError("parallel_iterations must be a positive integer."); + + var loop_context = new WhileContext( + maximum_iterations: maximum_iterations, + parallel_iterations: parallel_iterations, + back_prop: back_prop, + swap_memory: swap_memory); + + if (loop_context.outer_context == null) + ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context); + + var results = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants, + return_same_structure); + + if (maximum_iterations != null) + return results[1]; + else + return results[0]; + }); + + throw new NotImplementedException("while_loop"); } } diff --git a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs index 580da2b7..bfbf3413 100644 --- a/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/gen_control_flow_ops.py.cs @@ -20,6 +20,93 @@ namespace Tensorflow { public static OpDefLibrary _op_def_lib = new OpDefLibrary(); + /// + /// Creates or finds a child frame, and makes `data` available to the child frame. + /// + /// + /// + /// + /// + /// + /// + public static Tensor enter(Tensor data, string frame_name = "frame_name", bool is_constant = false, int parallel_iterations = 10, string name = null) + { + var _op = _op_def_lib._apply_op_helper("Enter", name, new + { + data, + frame_name, + is_constant, + parallel_iterations + }); + + return _op.output; + } + + /// + /// Forwards the input to the output. + /// + /// + /// + /// + public static Tensor loop_cond(Tensor input, string name = null) + { + var _op = _op_def_lib._apply_op_helper("LoopCond", name, new { input }); + + return _op.output; + } + + /// + /// Makes its input available to the next iteration. + /// + /// + /// + /// + public static Tensor ref_next_iteration(Tensor data, string name = null) + { + var _op = _op_def_lib._apply_op_helper("RefNextIteration", name, new { data }); + + return _op; + } + + /// + /// Makes its input available to the next iteration. + /// + /// + /// + /// + public static Tensor next_iteration(Tensor data, string name = null) + { + var _op = _op_def_lib._apply_op_helper("NextIteration", name, new { data }); + + return _op; + } + + /// + /// Exits the current frame to its parent frame. + /// + /// + /// + /// + public static Tensor ref_exit(Tensor data, string name = null) + { + var _op = _op_def_lib._apply_op_helper("RefExit", name, new { data }); + + return _op; + } + + /// + /// Exits the current frame to its parent frame. + /// + /// + /// + /// + public static Tensor _exit(Tensor data, string name = null) + { + var _op = _op_def_lib._apply_op_helper("Exit", name, new { data }); + + return _op; + } + public static Operation no_op(string name = null) { var _op = _op_def_lib._apply_op_helper("NoOp", name, null); diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index a5d26b23..1e2363e4 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -516,6 +516,9 @@ namespace Tensorflow }); } + public static Tensor minimum(Tx x, Ty y, string name = null) + => gen_math_ops.minimum(x, y, name: name); + public static Tensor maximum(Tx x, Ty y, string name = null) => gen_math_ops.maximum(x, y, name: name); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 801ab233..d17e1f59 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -416,5 +416,6 @@ namespace Tensorflow } } + public int tensor_int_val { get; set; } } } diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs b/test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs index 682b826f..31109f0a 100644 --- a/test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs +++ b/test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs @@ -8,6 +8,18 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test [TestClass] public class WhileContextTestCase : PythonTest { + /// + /// https://www.tensorflow.org/api_docs/python/tf/while_loop + /// + [TestMethod] + public void SimpleWhileLoop() + { + var i = constant_op.constant(0, name: "i"); + var c = new Func(x => tf.less(x, 10, name: "c")); + var b = new Func(x => tf.add(x, 1, name: "c")); + var r = control_flow_ops.while_loop(c, b, new[] { i }); + } + private void _testWhileContextHelper(int? maximum_iterations = null) { // TODO: implement missing code dependencies @@ -17,7 +29,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test var c = new Func(x => gen_math_ops.less(x, 10, name: "c")); var b = new Func(x => gen_math_ops.add(x, 1, name: "c")); control_flow_ops.while_loop( - c, b, new[] { i }, maximum_iterations); + c, b, new[] { i }, maximum_iterations: maximum_iterations); foreach (Operation op in sess.graph.get_operations()) { var control_flow_context = op._get_control_flow_context();