Browse Source

Graph.control_dependencies: added overload and updated implementation which was far from the original functionality.

tags/v0.9
Meinrad Recheis 6 years ago
parent
commit
5a2d265d72
3 changed files with 131 additions and 70 deletions
  1. +36
    -6
      src/TensorFlowNET.Core/Graphs/Graph.Control.cs
  2. +4
    -1
      src/TensorFlowNET.Core/ops.py.cs
  3. +91
    -63
      test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs

+ 36
- 6
src/TensorFlowNET.Core/Graphs/Graph.Control.cs View File

@@ -32,13 +32,13 @@ namespace Tensorflow
{
var ret = new List<ITensorOrOperation>();

foreach(var controller in _control_dependencies_stack)
foreach (var controller in _control_dependencies_stack)
{
bool dominated = false;
// If any of the input_ops already depends on the inputs from controller,
// we say that the new op is dominated (by that input), and we therefore
// do not need to add control dependencies for this controller's inputs.
foreach(var op in input_ops)
foreach (var op in input_ops)
{
if (controller.op_in_group(op))
{
@@ -48,12 +48,22 @@ namespace Tensorflow
}

if (!dominated)
ret.AddRange( controller.control_inputs.Where(x => !input_ops.Contains(x)));
ret.AddRange(controller.control_inputs.Where(x => !input_ops.Contains(x)));
}

return ret.ToArray();
}

/// <summary>
/// Returns a context manager that specifies control dependencies.
///
/// Use with the `with` keyword to specify that all operations constructed
/// within the context should have control dependencies on
/// `control_inputs`.
/// </summary>
public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs)
=> control_dependencies(control_inputs == null ? null : control_inputs.OfType<object>().ToArray());

/// <summary>
/// Returns a context manager that specifies control dependencies.
///
@@ -61,7 +71,7 @@ namespace Tensorflow
/// within the context should have control dependencies on
/// `control_inputs`.
/// </summary>
public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs)
public _ControlDependenciesController control_dependencies(object[] control_inputs)
{
if (control_inputs == null)
return new _ControlDependenciesController(this, null);
@@ -69,9 +79,26 @@ namespace Tensorflow
var control_ops = new List<ITensorOrOperation>();
foreach (var c in control_inputs)
{
control_ops.Add(c);
switch (c)
{
// TODO: implement IndexedSlices
//case IndexedSlices islice:
// control_ops.Add(islice.op);
// break;
case Tensor t:
control_ops.Add(t.op);
break;
case Operation op:
control_ops.Add(op);
break;
default:
var t1 = _as_graph_element(c);
if (t1 == null)
throw new TypeError($"Control input must be Operation or Tensor:{c}");
control_ops.Add(t1.op);
break;
}
}

return new _ControlDependenciesController(this, control_ops);
}

@@ -103,6 +130,9 @@ namespace Tensorflow
_control_dependencies_stack.Dequeue();
}

/// <summary>
/// Record that the given op depends on all registered control dependencies.
/// </summary>
public void _record_op_seen_by_control_dependencies(Operation op)
{
foreach (var controller in _control_dependencies_stack)


+ 4
- 1
src/TensorFlowNET.Core/ops.py.cs View File

@@ -119,11 +119,14 @@ namespace Tensorflow
/// A context manager that specifies control dependencies for all
/// operations constructed within the context.
/// </returns>
public static _ControlDependenciesController control_dependencies(Operation[] control_inputs)
public static _ControlDependenciesController control_dependencies(object[] control_inputs)
{
return get_default_graph().control_dependencies(control_inputs);
}

public static _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs)
=> control_dependencies(control_inputs == null ? null : control_inputs.OfType<object>().ToArray());

/// <summary>
/// Creates a TF_Operation.
/// </summary>


+ 91
- 63
test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs View File

@@ -23,7 +23,7 @@ namespace TensorFlowNET.UnitTest
{
a = constant_op.constant(1.0);
b = constant_op.constant(1.0);
with(g.control_dependencies(new ITensorOrOperation[] { a }), x =>
with(g.control_dependencies(new[] { a }), x =>
{
c = constant_op.constant(1.0);
d = array_ops.identity(b);
@@ -36,15 +36,15 @@ namespace TensorFlowNET.UnitTest
Assert.AreEqual(0, e.op.control_inputs.Length);
}
[Ignore("Part of this test is not compiling")]
[Ignore("Future is not supported yet")]
[TestMethod]
public void TestEager()
{
Tensor a = null, b = null, c = null, d = null, e = null;
Tensor a = null, c = null, d = null, e = null;
object b = null;
var calls = 0;
Func<Tensor> future = () =>
{
calls += 1;
return constant_op.constant(2.0);
};
@@ -55,13 +55,13 @@ namespace TensorFlowNET.UnitTest
if (context.executing_eagerly())
{
// TODO: make this compile (see original Python code below)
//a = constant_op.constant(1.0);
//b = future; // <--- {henon} obviously, this doesn't compile, looks like control_dependencies needs to be able to take callables as well.
//with(ops.control_dependencies(new Operation[] {a, b}), ctrl =>
//{
// return c = constant_op.constant(3.0);
//});
//Assert.AreEqual(calls, 1);
a = constant_op.constant(1.0);
b = future; // <--- {henon} obviously, this doesn't compile, looks like control_dependencies needs to be able to take callables as well.
with(ops.control_dependencies(new object[] { a, b }), ctrl =>
{
return c = constant_op.constant(3.0);
});
Assert.AreEqual(calls, 1);
}
else
{
@@ -69,12 +69,12 @@ namespace TensorFlowNET.UnitTest
with<Graph>(graph, g =>
{
a = constant_op.constant(1.0);
b = future();
with(g.control_dependencies(new ITensorOrOperation[] { a, b }), ctrl =>
{
c = constant_op.constant(3.0);
});
Assert.IsTrue(Enumerable.SequenceEqual(c.op.control_inputs, new[] { a.op, b.op }));
var b1 = future();
with(g.control_dependencies(new [] { a, b}), ctrl =>
{
c = constant_op.constant(3.0);
});
Assert.IsTrue(Enumerable.SequenceEqual(c.op.control_inputs, new[] { a.op, b1.op }));
Assert.AreEqual(1, calls);
});
@@ -106,19 +106,7 @@ namespace TensorFlowNET.UnitTest
}
// Note: {henon}, all tests below use the function _apply_op which is not really portable in C#, see original source below
// but I think _apply_op(...) can just be replaced by g.create_op(...).
/*
def _apply_op(g, *args, **kwargs):
op = g.create_op(*args, **kwargs)
if len(op.outputs) == 1:
return op.outputs[0]
else:
return op.outputs
*/
[Ignore("")]
[Ignore("How to port the ConvertibleObj?")]
[TestMethod]
public void TestBasicWithConversion()
{
@@ -127,58 +115,98 @@ def _apply_op(g, *args, **kwargs):
var a = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT });
// TODO: ConvertibleObj, see original source below
/*
def testBasicWithConversion(self):
g = ops.Graph()
a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
def testBasicWithConversion(self):
g = ops.Graph()
a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
class ConvertibleObj(object):
class ConvertibleObj(object):
def _as_graph_element(self):
return a
def _as_graph_element(self):
return a
with g.control_dependencies([ConvertibleObj()]):
c = _apply_op(g, "FloatOutput", [], [dtypes.float32])
with g.control_dependencies([ConvertibleObj()]):
c = _apply_op(g, "FloatOutput", [], [dtypes.float32])
self.assertEqual(c.op.control_inputs, [a.op])
self.assertEqual(c.op.control_inputs, [a.op])
*/
}
//[Ignore]
[TestMethod()]
[TestMethod]
public void TestNested()
{
var g = ops.get_default_graph();
var g = tf.Graph().as_default();
var a_1 = constant_op.constant(1.0);
var a_2 = constant_op.constant(3.0);
var a_3 = constant_op.constant(4.0);
var a_4 = constant_op.constant(5.0);
Operation b_1 = null, b_2 = null;
with(g.control_dependencies(new ITensorOrOperation[] { a_1, a_2, a_3, a_4 }), ctrl =>
{
b_1 = constant_op.constant(6.0);
});
with(g.control_dependencies(new ITensorOrOperation[] { a_1 }), ctrl1 =>
{
with(g.control_dependencies(new ITensorOrOperation[] { a_2 }), ctrl2 =>
{
with(g.control_dependencies(new ITensorOrOperation[] { a_3 }), ctrl3 =>
{
with(g.control_dependencies(new ITensorOrOperation[] { a_4 }), ctrl4 =>
{
b_2 = constant_op.constant(7.0);
});
});
});
});
AssertItemsEqual(new[] {a_1.op, a_2.op, a_3.op, a_4.op}, b_1.op.control_inputs);
with(g.control_dependencies(new[] { a_1, a_2, a_3, a_4 }), ctrl =>
{
b_1 = constant_op.constant(6.0);
});
with(g.control_dependencies(new[] { a_1 }), ctrl1 =>
{
with(g.control_dependencies(new[] { a_2 }), ctrl2 =>
{
with(g.control_dependencies(new[] { a_3 }), ctrl3 =>
{
with(g.control_dependencies(new[] { a_4 }), ctrl4 =>
{
b_2 = constant_op.constant(7.0);
});
});
});
});
AssertItemsEqual(new[] { a_1.op, a_2.op, a_3.op, a_4.op }, b_1.op.control_inputs);
AssertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs);
}
[Ignore("will fail due to unsupported op 'FloatOutput'")]
[Ignore("Fails")]
[TestMethod]
public void TestClear()
{
var g = tf.Graph().as_default();
var a_1 = constant_op.constant(1.0);
var a_2 = constant_op.constant(3.0);
var a_3 = constant_op.constant(4.0);
var a_4 = constant_op.constant(5.0);
Operation b_3_4 = null, b_3 = null, b_none = null, b_1 = null, b_1_2 = null, b_none2 = null;
with(g.control_dependencies(new[] { a_1 }), ctrl1 =>
{
with(g.control_dependencies(new[] { a_2 }), ctrl2 =>
{
with(g.control_dependencies(null), ctrl3 =>
{
with(g.control_dependencies(new[] { a_3 }), ctrl4 =>
{
with(g.control_dependencies(new[] { a_4 }), ctrl5 =>
{
// deps [a_3, a_4]
b_3_4 = constant_op.constant(7.0);
});
// deps = [a_3]
b_3 = constant_op.constant(8.0);
});
// deps back to None
b_none = constant_op.constant(9.0);
});
// deps back to [a_1, a_2]
b_1_2 = constant_op.constant(10.0);
});
// deps back to [a_1]
b_1 = constant_op.constant(11.0);
with(g.control_dependencies(null), ctrl6 =>
{
// deps are None again
b_none2 = constant_op.constant(12.0);
});
});
AssertItemsEqual(new[] {a_3.op, a_4.op}, b_3_4.op.control_inputs);
AssertItemsEqual(new[] {a_3.op}, b_3.op.control_inputs);
AssertItemsEqual(new object[0], b_none.op.control_inputs);
AssertItemsEqual(new[] {a_1.op, a_2.op}, b_1_2.op.control_inputs);
AssertItemsEqual(new[] {a_1.op}, b_1.op.control_inputs);
AssertItemsEqual(new object[0], b_none2.op.control_inputs);
/*
def testClear(self):
g = ops.Graph()


Loading…
Cancel
Save