Browse Source

merge for CondContext.

tags/v0.9
Oceania2018 6 years ago
parent
commit
805c074d38
7 changed files with 319 additions and 130 deletions
  1. +1
    -0
      src/TensorFlowNET.Core/Graphs/Graph.Control.cs
  2. +19
    -0
      src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
  3. +37
    -12
      src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
  4. +110
    -7
      src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
  5. +3
    -2
      src/TensorFlowNET.Core/Python.cs
  6. +144
    -109
      test/TensorFlowNET.UnitTest/CreateOpFromTfOperationTest.cs
  7. +5
    -0
      test/TensorFlowNET.UnitTest/PythonTest.cs

+ 1
- 0
src/TensorFlowNET.Core/Graphs/Graph.Control.cs View File

@@ -8,6 +8,7 @@ namespace Tensorflow
{
public partial class Graph
{
// Current control flow context. It could be either CondContext or WhileContext
public IControlFlowContext _control_flow_context;

// represents the nested with(...) statements


+ 19
- 0
src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs View File

@@ -64,6 +64,9 @@ namespace Tensorflow.Operations
}
}

/// <summary>
/// Add the subgraph defined by fn() to the graph.
/// </summary>
public (T, Tensor) BuildCondBranch<T>(Func<T> fn)
{
// Add the subgraph defined by fn() to the graph.
@@ -71,6 +74,22 @@ namespace Tensorflow.Operations
var original_result = fn();
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);

//TODO: port this chunck of missing code:
/*
if len(post_summaries) > len(pre_summaries):
new_summaries = post_summaries[len(pre_summaries):]
summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
summary_ref[:] = pre_summaries
with ops.control_dependencies(new_summaries):
if original_result is None:
return no_op(), None
else:
original_result = nest.map_structure(array_ops.identity,
original_result)
*/
if (original_result == null)
return (original_result, null);

switch (original_result)
{
case Operation[] results:


+ 37
- 12
src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs View File

@@ -3,7 +3,24 @@ using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Operations
{
{
/// <summary>
/// The base class for control flow context.
///
/// The usage pattern is a sequence of(Enter, Exit) followed by a final
/// ExitResult.
///
/// We maintain the following state for control flow contexts during graph
/// construction:
/// 1. graph has _control_flow_context: the current context used to
/// construct new nodes.Changed by ctxt.Enter() and ctxt.Exit()
/// 2. op has _control_flow_context: the context to which the op belongs.
/// Set at the time the op is created.Immutable.
/// 3. A ControlFlowContext has _outer_context: the context in which this
/// context is created.Set at the time a context is created.Immutable.
/// 4. A ControlFlowContext has _context_stack.
/// Pushed and popped by ctxt.Enter() and ctxt.Exit()
/// </summary>
public abstract class ControlFlowContext : IPython, IControlFlowContext
{
/// <summary>
@@ -17,6 +34,8 @@ namespace Tensorflow.Operations
_context_stack = new Stack<IControlFlowContext>();
}

public string name { get; set; }

public void __init__()
{

@@ -26,6 +45,13 @@ namespace Tensorflow.Operations
{
}

public void __exit__()
{
}

/// <summary>
/// Enter this control flow context.
/// </summary>
public virtual void Enter()
{
var graph = ops.get_default_graph();
@@ -33,6 +59,16 @@ namespace Tensorflow.Operations
graph._set_control_flow_context(this);
}

/// <summary>
/// Exit this control flow context.
/// </summary>
public virtual void Exit()
{
var graph = ops.get_default_graph();
var last_context = _context_stack.Pop();
graph._set_control_flow_context(last_context);
}

public void AddOp(Operation op)
{
_AddOpInternal(op);
@@ -56,17 +92,6 @@ namespace Tensorflow.Operations
var internal_control_inputs = op.control_inputs;
}

public void Exit()
{
var graph = ops.get_default_graph();
var last_context = _context_stack.Pop();
graph._set_control_flow_context(last_context);
}

public void __exit__()
{
}

public void Dispose()
{
}


+ 110
- 7
src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs View File

@@ -187,6 +187,48 @@ namespace Tensorflow
return @switch(data, pred, name: name);
}

/// <summary>
/// Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
///
/// `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
/// `false_fn` must have the same non-zero number and type of outputs.
///
/// **WARNING**: Any Tensors or Operations created outside of `true_fn` and
/// `false_fn` will be executed regardless of which branch is selected at runtime.
///
/// Although this behavior is consistent with the dataflow model of TensorFlow,
/// it has frequently surprised users who expected a lazier semantics.
/// Consider the following simple program:
///
/// z = tf.multiply(a, b)
/// result = tf.cond(x &lt; y, ()=> tf.add(x, z), ()=> tf.square(y))
///
/// If `x&lt;y`, the `tf.add` operation will be executed and `tf.square`
/// operation will not be executed.Since `z` is needed for at least one
/// branch of the `cond`, the `tf.multiply` operation is always executed,
/// unconditionally.
///
/// Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
/// call to `cond`, and not at all during `Session.run()`). `cond`
/// stitches together the graph fragments created during the `true_fn` and
/// `false_fn` calls with some additional graph nodes to ensure that the right
/// branch gets executed depending on the value of `pred`.
///
/// `tf.cond` supports nested structures as implemented in
/// `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
/// same(possibly nested) value structure of lists, tuples, and/or named tuples.
/// Singleton lists and tuples form the only exceptions to this: when returned by
/// `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
/// This behavior is disabled by passing `strict= True`.
/// </summary>
/// <param name="pred"> A scalar determining whether to return the result of `true_fn` or
/// `false_fn`.</param>
/// <param name="true_fn">The callable to be performed if pred is true.</param>
/// <param name="false_fn">The callable to be performed if pred is false.</param>
/// <param name="strict"> A boolean that enables/disables 'strict' mode; see above.</param>
/// <param name="name">Optional name prefix for the returned tensors.</param>
/// <returns>Tensors returned by the call to either `true_fn` or `false_fn`. If the
/// callables return a singleton list, the element is extracted from the list.</returns>
public static Tensor cond(Tensor pred,
Func<ITensorOrOperation> true_fn = null,
Func<ITensorOrOperation> false_fn = null,
@@ -195,6 +237,37 @@ namespace Tensorflow
{
return with(ops.name_scope(name, "cond", new { pred }), delegate
{
// TODO: here a chunk of original code is missing
/*
if fn1 is not None:
if true_fn is not None:
raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.")
true_fn = fn1
elif true_fn is None:
raise TypeError("cond(): true_fn argument required")
if fn2 is not None:
if false_fn is not None:
raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.")
false_fn = fn2
elif false_fn is None:
raise TypeError("cond(): false_fn argument required")

if not callable(true_fn):
raise TypeError("true_fn must be callable.")
if not callable(false_fn):
raise TypeError("false_fn must be callable.")

with ops.name_scope(name, "cond", [pred]):
if context.executing_eagerly():
if pred:
return _UnpackIfSingleton(true_fn())
return _UnpackIfSingleton(false_fn())

# Add the Switch to the graph.
if isinstance(pred, bool):
raise TypeError("pred must not be a Python bool")
*/

// Add the Switch to the graph.
var (p_2, p_1) = @switch(pred, pred);
var pivot_1 = array_ops.identity(p_1, name: "switch_t");
@@ -207,15 +280,45 @@ namespace Tensorflow

// Build the graph for the true branch in a new context.
var context_t = new CondContext(pred, pivot_1, branch: 1);
context_t.Enter();
var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn);
context_t.Exit();

ITensorOrOperation orig_res_t;
Tensor res_t;
try
{
context_t.Enter();
(orig_res_t, res_t) = context_t.BuildCondBranch(true_fn);
}
finally
{
context_t.Exit();
}
// Build the graph for the false branch in a new context.
var context_f = new CondContext(pred, pivot_2, branch: 0);
context_f.Enter();
var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn);
context_f.Exit();
ITensorOrOperation orig_res_f;
Tensor res_f;
try
{
context_f.Enter();
(orig_res_f, res_f) = context_f.BuildCondBranch(false_fn);
}
finally
{
context_f.Exit();
}

//TODO: missing original code
//if not strict:
// orig_res_t = _UnpackIfSingleton(orig_res_t)
// orig_res_f = _UnpackIfSingleton(orig_res_f)
/*
# Check that the return values of the two branches have the same structure.
try:
nest.assert_same_structure(orig_res_t, orig_res_f)
except TypeError as e:
raise TypeError(
"Incompatible return types of true_fn and false_fn: {}".format(e))
except ValueError as e:
raise ValueError(
"Incompatible return values of true_fn and false_fn: {}".format(e))

var res_t_flat = new Tensor[] { res_t };
var res_f_flat = new Tensor[] { res_f };


+ 3
- 2
src/TensorFlowNET.Core/Python.cs View File

@@ -1,5 +1,6 @@
using NumSharp;
using System;
using System.Collections;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
@@ -17,8 +18,8 @@ namespace Tensorflow
Console.WriteLine(obj.ToString());
}

protected int len(Array a)
=> a.Length;
protected int len<T>(IEnumerable<T> a)
=> a.Count();

protected IEnumerable<int> range(int end)
{


+ 144
- 109
test/TensorFlowNET.UnitTest/CreateOpFromTfOperationTest.cs View File

@@ -61,115 +61,150 @@ namespace TensorFlowNET.UnitTest
self.assertEqual(op4.name, "myop_1_1");
});
}
[Ignore("Something is not right, Switch gets not inserted correctly?")]
[TestMethod]
public void TestCond()
{
var graph = tf.Graph().as_default();
with<Graph>(graph, g =>
{
var x = constant_op.constant(10);
var true_fn = new Func<Tensor>(() =>
{
var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "cond/myop"), new[] { x }, new Operation[0]);
var new_ops = g._add_new_tf_operations();
self.assertEqual(len(new_ops), 1);
return x;
});
control_flow_ops.cond(x < 10, true_fn, () => x);
var op = g.get_operation_by_name("cond/myop");
self.assertIsNotNone(op);
self.assertEqual(op.name, "cond/myop");
self.assertEqual(op.type, "Identity");
//self.assertEqual(op.outputs, new object[0]);
var op_input = op.inputs[0].op;
self.assertEqual(op_input.type, "Switch");
self.assertEqual(op_input.inputs[0], x);
self.assertEqual(op.graph, g);
self.assertIsNotNone(op._get_control_flow_context());
// TODO: op._get_control_flow_context().name not implemented
//self.assertEqual(op._get_control_flow_context().name, "cond/cond_text");
});
/*
@test_util.run_v1_only("b/120545219")
def testCond(self):
g = ops.Graph()
with g.as_default():
x = test_ops.int_output()
def true_fn():
ops._create_c_op(ops.get_default_graph(),
ops._NodeDef("IntInput", "cond/myop"), [x], [])
new_ops = g._add_new_tf_operations()
self.assertEqual(len(new_ops), 1)
return x
control_flow_ops.cond(x < 10, true_fn, lambda: x)
op = g.get_operation_by_name("cond/myop")
self.assertIsNotNone(op)
self.assertEqual(op.name, "cond/myop")
self.assertEqual(op.type, "IntInput")
self.assertEqual(op.outputs, [])
op_input = op.inputs[0].op
self.assertEqual(op_input.type, "Switch")
self.assertEqual(op_input.inputs[0], x)
self.assertEqual(op.graph, g)
# pylint: disable=protected-access
self.assertIsNotNone(op._get_control_flow_context())
self.assertEqual(op._get_control_flow_context().name,
"cond/cond_text")
# pylint: enable=protected-access
*/
}
/*
@test_util.run_v1_only("b/120545219")
def testCond(self):
g = ops.Graph()
with g.as_default():
x = test_ops.int_output()
def true_fn():
ops._create_c_op(ops.get_default_graph(),
ops._NodeDef("IntInput", "cond/myop"), [x], [])
new_ops = g._add_new_tf_operations()
self.assertEqual(len(new_ops), 1)
return x
control_flow_ops.cond(x < 10, true_fn, lambda: x)
op = g.get_operation_by_name("cond/myop")
self.assertIsNotNone(op)
self.assertEqual(op.name, "cond/myop")
self.assertEqual(op.type, "IntInput")
self.assertEqual(op.outputs, [])
op_input = op.inputs[0].op
self.assertEqual(op_input.type, "Switch")
self.assertEqual(op_input.inputs[0], x)
self.assertEqual(op.graph, g)
# pylint: disable=protected-access
self.assertIsNotNone(op._get_control_flow_context())
self.assertEqual(op._get_control_flow_context().name,
"cond/cond_text")
# pylint: enable=protected-access
@test_util.run_v1_only("b/120545219")
def testWhileLoop(self):
g = ops.Graph()
with g.as_default():
x = test_ops.int_output()
def body(i):
ops._create_c_op(ops.get_default_graph(),
ops._NodeDef("IntInput", "myloop/myop"), [x], [])
new_ops = g._add_new_tf_operations()
self.assertEqual(len(new_ops), 1)
return i
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
op = g.get_operation_by_name("myloop/myop")
self.assertIsNotNone(op)
self.assertEqual(op.name, "myloop/myop")
self.assertEqual(op.type, "IntInput")
self.assertEqual(op.outputs, [])
op_input = op.inputs[0].op
self.assertEqual(op_input.type, "Enter")
self.assertEqual(list(op_input.inputs), [x])
self.assertEqual(op.graph, g)
# pylint: disable=protected-access
self.assertIsNotNone(op._get_control_flow_context())
self.assertEqual(op._get_control_flow_context().name,
"myloop/while_context")
# pylint: enable=protected-access
@test_util.run_v1_only("b/120545219")
def testWhileLoopWithInternalControlDep(self):
g = ops.Graph()
with g.as_default():
x = test_ops.int_output()
def body(i):
c = constant_op.constant(1.0, name="c")
ops._create_c_op(ops.get_default_graph(),
ops._NodeDef("IntInput", "myloop/myop"), [x], [])
with ops.control_dependencies([c]):
new_ops = g._add_new_tf_operations()
self.assertEqual(len(new_ops), 1)
return i
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
op = g.get_operation_by_name("myloop/myop")
self.assertIsNotNone(op)
c = g.get_operation_by_name("myloop/c")
self.assertIsNotNone(c)
# Internal control dep is preserved
self.assertEqual(op.control_inputs, [c])
@test_util.run_v1_only("b/120545219")
def testWhileLoopWithExternalControlDep(self):
g = ops.Graph()
with g.as_default():
x = test_ops.int_output()
c = constant_op.constant(1.0)
def body(i):
ops._create_c_op(ops.get_default_graph(),
ops._NodeDef("IntInput", "myloop/myop"), [x], [])
with ops.control_dependencies([c]):
new_ops = g._add_new_tf_operations()
self.assertEqual(len(new_ops), 1)
return i
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
op = g.get_operation_by_name("myloop/myop")
self.assertIsNotNone(op)
# External control dep is removed and replaced with internal control dep
self.assertNotEqual(op.control_inputs[0], c.op)
self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context())
*/
@test_util.run_v1_only("b/120545219")
def testWhileLoop(self):
g = ops.Graph()
with g.as_default():
x = test_ops.int_output()
def body(i):
ops._create_c_op(ops.get_default_graph(),
ops._NodeDef("IntInput", "myloop/myop"), [x], [])
new_ops = g._add_new_tf_operations()
self.assertEqual(len(new_ops), 1)
return i
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
op = g.get_operation_by_name("myloop/myop")
self.assertIsNotNone(op)
self.assertEqual(op.name, "myloop/myop")
self.assertEqual(op.type, "IntInput")
self.assertEqual(op.outputs, [])
op_input = op.inputs[0].op
self.assertEqual(op_input.type, "Enter")
self.assertEqual(list(op_input.inputs), [x])
self.assertEqual(op.graph, g)
# pylint: disable=protected-access
self.assertIsNotNone(op._get_control_flow_context())
self.assertEqual(op._get_control_flow_context().name,
"myloop/while_context")
# pylint: enable=protected-access
@test_util.run_v1_only("b/120545219")
def testWhileLoopWithInternalControlDep(self):
g = ops.Graph()
with g.as_default():
x = test_ops.int_output()
def body(i):
c = constant_op.constant(1.0, name="c")
ops._create_c_op(ops.get_default_graph(),
ops._NodeDef("IntInput", "myloop/myop"), [x], [])
with ops.control_dependencies([c]):
new_ops = g._add_new_tf_operations()
self.assertEqual(len(new_ops), 1)
return i
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
op = g.get_operation_by_name("myloop/myop")
self.assertIsNotNone(op)
c = g.get_operation_by_name("myloop/c")
self.assertIsNotNone(c)
# Internal control dep is preserved
self.assertEqual(op.control_inputs, [c])
@test_util.run_v1_only("b/120545219")
def testWhileLoopWithExternalControlDep(self):
g = ops.Graph()
with g.as_default():
x = test_ops.int_output()
c = constant_op.constant(1.0)
def body(i):
ops._create_c_op(ops.get_default_graph(),
ops._NodeDef("IntInput", "myloop/myop"), [x], [])
with ops.control_dependencies([c]):
new_ops = g._add_new_tf_operations()
self.assertEqual(len(new_ops), 1)
return i
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
op = g.get_operation_by_name("myloop/myop")
self.assertIsNotNone(op)
# External control dep is removed and replaced with internal control dep
self.assertNotEqual(op.control_inputs[0], c.op)
self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context())
*/
}
}

+ 5
- 0
test/TensorFlowNET.UnitTest/PythonTest.cs View File

@@ -29,6 +29,11 @@ namespace TensorFlowNET.UnitTest
Assert.AreEqual(expected, given);
}
public void assertIsNotNone(object given)
{
Assert.IsNotNull(given);
}
protected PythonTest self { get => this; }
}
}

Loading…
Cancel
Save