From 5a2d265d721b2bb78e3a7da127fda256af30d8ec Mon Sep 17 00:00:00 2001
From: Meinrad Recheis
Date: Tue, 9 Apr 2019 01:16:03 +0200
Subject: [PATCH] Graph.control_dependencies: added overload and updated
 implementation which was far from the original functionality.

---
 .../Graphs/Graph.Control.cs                   |  42 ++++-
 src/TensorFlowNET.Core/ops.py.cs              |   5 +-
 .../ControlDependenciesTest.cs                | 154 +++++++++++-------
 3 files changed, 131 insertions(+), 70 deletions(-)

diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs
index bc1e15d5..9c2881e1 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs
@@ -32,13 +32,13 @@ namespace Tensorflow
         {
             var ret = new List<Operation>();
 
-            foreach(var controller in _control_dependencies_stack)
+            foreach (var controller in _control_dependencies_stack)
             {
                 bool dominated = false;
                 // If any of the input_ops already depends on the inputs from controller,
                 // we say that the new op is dominated (by that input), and we therefore
                 // do not need to add control dependencies for this controller's inputs.
-                foreach(var op in input_ops)
+                foreach (var op in input_ops)
                 {
                     if (controller.op_in_group(op))
                     {
@@ -48,12 +48,22 @@ namespace Tensorflow
                 }
 
                 if (!dominated)
-                    ret.AddRange( controller.control_inputs.Where(x => !input_ops.Contains(x)));
+                    ret.AddRange(controller.control_inputs.Where(x => !input_ops.Contains(x)));
             }
 
             return ret.ToArray();
         }
 
+        /// <summary>
+        /// Returns a context manager that specifies control dependencies.
+        ///
+        /// Use with the `with` keyword to specify that all operations constructed
+        /// within the context should have control dependencies on
+        /// `control_inputs`.
+        /// </summary>
+        public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs)
+            => control_dependencies(control_inputs == null ? null : control_inputs.OfType<object>().ToArray());
+
         /// <summary>
         /// Returns a context manager that specifies control dependencies.
         ///
@@ -61,7 +71,7 @@ namespace Tensorflow
         /// within the context should have control dependencies on
        /// `control_inputs`.
         /// </summary>
-        public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs)
+        public _ControlDependenciesController control_dependencies(object[] control_inputs)
         {
             if (control_inputs == null)
                 return new _ControlDependenciesController(this, null);
@@ -69,9 +79,26 @@ namespace Tensorflow
             var control_ops = new List<Operation>();
             foreach (var c in control_inputs)
             {
-                control_ops.Add(c);
+                switch (c)
+                {
+                    // TODO: implement IndexedSlices
+                    //case IndexedSlices islice:
+                    //    control_ops.Add(islice.op);
+                    //    break;
+                    case Tensor t:
+                        control_ops.Add(t.op);
+                        break;
+                    case Operation op:
+                        control_ops.Add(op);
+                        break;
+                    default:
+                        var t1 = _as_graph_element(c);
+                        if (t1 == null)
+                            throw new TypeError($"Control input must be Operation or Tensor:{c}");
+                        control_ops.Add(t1.op);
+                        break;
+                }
             }
-
             return new _ControlDependenciesController(this, control_ops);
         }
 
@@ -103,6 +130,9 @@ namespace Tensorflow
             _control_dependencies_stack.Dequeue();
         }
 
+        /// <summary>
+        /// Record that the given op depends on all registered control dependencies.
+        /// </summary>
         public void _record_op_seen_by_control_dependencies(Operation op)
         {
             foreach (var controller in _control_dependencies_stack)
diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs
index ff41e261..add752ea 100644
--- a/src/TensorFlowNET.Core/ops.py.cs
+++ b/src/TensorFlowNET.Core/ops.py.cs
@@ -119,11 +119,14 @@ namespace Tensorflow
         /// A context manager that specifies control dependencies for all
         /// operations constructed within the context.
         /// </returns>
-        public static _ControlDependenciesController control_dependencies(Operation[] control_inputs)
+        public static _ControlDependenciesController control_dependencies(object[] control_inputs)
         {
             return get_default_graph().control_dependencies(control_inputs);
         }
 
+        public static _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs)
+            => control_dependencies(control_inputs == null ? null : control_inputs.OfType<object>().ToArray());
+
         /// <summary>
         /// Creates a TF_Operation.
         /// </summary>
diff --git a/test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs b/test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs
index 3187b37e..3be4e80e 100644
--- a/test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs
+++ b/test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs
@@ -23,7 +23,7 @@ namespace TensorFlowNET.UnitTest
             {
                 a = constant_op.constant(1.0);
                 b = constant_op.constant(1.0);
-                with(g.control_dependencies(new ITensorOrOperation[] { a }), x =>
+                with(g.control_dependencies(new[] { a }), x =>
                 {
                     c = constant_op.constant(1.0);
                     d = array_ops.identity(b);
@@ -36,15 +36,15 @@ namespace TensorFlowNET.UnitTest
             Assert.AreEqual(0, e.op.control_inputs.Length);
         }
 
-        [Ignore("Part of this test is not compiling")]
+        [Ignore("Future is not supported yet")]
         [TestMethod]
         public void TestEager()
        {
-            Tensor a = null, b = null, c = null, d = null, e = null;
+            Tensor a = null, c = null, d = null, e = null;
+            object b = null;
             var calls = 0;
             Func<Tensor> future = () =>
             {
-                calls += 1;
                 return constant_op.constant(2.0);
             };
 
@@ -55,13 +55,13 @@ namespace TensorFlowNET.UnitTest
             if (context.executing_eagerly())
             {
                 // TODO: make this compile (see original Python code below)
-                //a = constant_op.constant(1.0);
-                //b = future; // <--- {henon} obviously, this doesn't compile, looks like control_dependencies needs to be able to take callables as well.
-                //with(ops.control_dependencies(new Operation[] {a, b}), ctrl =>
-                //{
-                //    return c = constant_op.constant(3.0);
-                //});
-                //Assert.AreEqual(calls, 1);
+                a = constant_op.constant(1.0);
+                b = future; // <--- {henon} obviously, this doesn't compile, looks like control_dependencies needs to be able to take callables as well.
+                with(ops.control_dependencies(new object[] { a, b }), ctrl =>
+                {
+                    return c = constant_op.constant(3.0);
+                });
+                Assert.AreEqual(calls, 1);
             }
             else
             {
@@ -69,12 +69,12 @@ namespace TensorFlowNET.UnitTest
                 with(graph, g =>
                 {
                     a = constant_op.constant(1.0);
-                    b = future();
-                    with(g.control_dependencies(new ITensorOrOperation[] { a, b }), ctrl =>
-                    {
-                        c = constant_op.constant(3.0);
-                    });
-                    Assert.IsTrue(Enumerable.SequenceEqual(c.op.control_inputs, new[] { a.op, b.op }));
+                    var b1 = future();
+                    with(g.control_dependencies(new[] { a, b1 }), ctrl =>
+                    {
+                        c = constant_op.constant(3.0);
+                    });
+                    Assert.IsTrue(Enumerable.SequenceEqual(c.op.control_inputs, new[] { a.op, b1.op }));
                     Assert.AreEqual(1, calls);
                 });
 
@@ -106,19 +106,7 @@ namespace TensorFlowNET.UnitTest
 
         }
 
-        // Note: {henon}, all tests below use the function _apply_op which is not really portable in C#, see original source below
-        // but I think _apply_op(...) can just be replaced by g.create_op(...).
-        /*
-def _apply_op(g, *args, **kwargs):
-  op = g.create_op(*args, **kwargs)
-  if len(op.outputs) == 1:
-    return op.outputs[0]
-  else:
-    return op.outputs
-        */
-
-
-        [Ignore("")]
+        [Ignore("How to port the ConvertibleObj?")]
         [TestMethod]
         public void TestBasicWithConversion()
         {
@@ -127,58 +115,98 @@ namespace TensorFlowNET.UnitTest
             var a = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT });
             // TODO: ConvertibleObj, see original source below
             /*
-    def testBasicWithConversion(self):
-      g = ops.Graph()
-      a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
-
-      class ConvertibleObj(object):
-
-        def _as_graph_element(self):
-          return a
-
-      with g.control_dependencies([ConvertibleObj()]):
-        c = _apply_op(g, "FloatOutput", [], [dtypes.float32])
-
-      self.assertEqual(c.op.control_inputs, [a.op])
+            def testBasicWithConversion(self):
+              g = ops.Graph()
+              a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+
+              class ConvertibleObj(object):
+
+                def _as_graph_element(self):
+                  return a
+
+              with g.control_dependencies([ConvertibleObj()]):
+                c = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+
+              self.assertEqual(c.op.control_inputs, [a.op])
             */
         }
 
-
-        //[Ignore]
-        [TestMethod()]
+
+        [TestMethod]
         public void TestNested()
         {
-            var g = ops.get_default_graph();
+            var g = tf.Graph().as_default();
             var a_1 = constant_op.constant(1.0);
             var a_2 = constant_op.constant(3.0);
             var a_3 = constant_op.constant(4.0);
             var a_4 = constant_op.constant(5.0);
             Operation b_1 = null, b_2 = null;
-            with(g.control_dependencies(new ITensorOrOperation[] { a_1, a_2, a_3, a_4 }), ctrl =>
-            {
-                b_1 = constant_op.constant(6.0);
-            });
-            with(g.control_dependencies(new ITensorOrOperation[] { a_1 }), ctrl1 =>
-            {
-                with(g.control_dependencies(new ITensorOrOperation[] { a_2 }), ctrl2 =>
-                {
-                    with(g.control_dependencies(new ITensorOrOperation[] { a_3 }), ctrl3 =>
-                    {
-                        with(g.control_dependencies(new ITensorOrOperation[] { a_4 }), ctrl4 =>
-                        {
-                            b_2 = constant_op.constant(7.0);
-                        });
-                    });
-                });
-            });
-            AssertItemsEqual(new[] {a_1.op, a_2.op, a_3.op, a_4.op}, b_1.op.control_inputs);
+            with(g.control_dependencies(new[] { a_1, a_2, a_3, a_4 }), ctrl =>
+            {
+                b_1 = constant_op.constant(6.0);
+            });
+            with(g.control_dependencies(new[] { a_1 }), ctrl1 =>
+            {
+                with(g.control_dependencies(new[] { a_2 }), ctrl2 =>
+                {
+                    with(g.control_dependencies(new[] { a_3 }), ctrl3 =>
+                    {
+                        with(g.control_dependencies(new[] { a_4 }), ctrl4 =>
+                        {
+                            b_2 = constant_op.constant(7.0);
+                        });
+                    });
+                });
+            });
+            AssertItemsEqual(new[] { a_1.op, a_2.op, a_3.op, a_4.op }, b_1.op.control_inputs);
             AssertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs);
         }
 
-        [Ignore("will fail due to unsupported op 'FloatOutput'")]
+        [Ignore("Fails")]
         [TestMethod]
         public void TestClear()
         {
+            var g = tf.Graph().as_default();
+            var a_1 = constant_op.constant(1.0);
+            var a_2 = constant_op.constant(3.0);
+            var a_3 = constant_op.constant(4.0);
+            var a_4 = constant_op.constant(5.0);
+            Operation b_3_4 = null, b_3 = null, b_none = null, b_1 = null, b_1_2 = null, b_none2 = null;
+            with(g.control_dependencies(new[] { a_1 }), ctrl1 =>
+            {
+                with(g.control_dependencies(new[] { a_2 }), ctrl2 =>
+                {
+                    with(g.control_dependencies(null), ctrl3 =>
+                    {
+                        with(g.control_dependencies(new[] { a_3 }), ctrl4 =>
+                        {
+                            with(g.control_dependencies(new[] { a_4 }), ctrl5 =>
+                            {
+                                // deps [a_3, a_4]
+                                b_3_4 = constant_op.constant(7.0);
+                            });
+                            // deps = [a_3]
+                            b_3 = constant_op.constant(8.0);
+                        });
+                        // deps back to None
+                        b_none = constant_op.constant(9.0);
+                    });
+                    // deps back to [a_1, a_2]
+                    b_1_2 = constant_op.constant(10.0);
+                });
+                // deps back to [a_1]
+                b_1 = constant_op.constant(11.0);
+                with(g.control_dependencies(null), ctrl6 =>
+                {
+                    // deps are None again
+                    b_none2 = constant_op.constant(12.0);
+                });
+            });
+            AssertItemsEqual(new[] {a_3.op, a_4.op}, b_3_4.op.control_inputs);
+            AssertItemsEqual(new[] {a_3.op}, b_3.op.control_inputs);
+            AssertItemsEqual(new object[0], b_none.op.control_inputs);
+            AssertItemsEqual(new[] {a_1.op, a_2.op}, b_1_2.op.control_inputs);
+            AssertItemsEqual(new[] {a_1.op}, b_1.op.control_inputs);
+            AssertItemsEqual(new object[0], b_none2.op.control_inputs);
             /*
 def testClear(self):
   g = ops.Graph()