@@ -8,6 +8,7 @@ namespace Tensorflow | |||
{ | |||
public partial class Graph | |||
{ | |||
// Current control flow context. It could be either CondContext or WhileContext | |||
public IControlFlowContext _control_flow_context; | |||
// represents the nested with(...) statements | |||
@@ -64,6 +64,9 @@ namespace Tensorflow.Operations | |||
} | |||
} | |||
/// <summary> | |||
/// Add the subgraph defined by fn() to the graph. | |||
/// </summary> | |||
public (T, Tensor) BuildCondBranch<T>(Func<T> fn) | |||
{ | |||
// Add the subgraph defined by fn() to the graph. | |||
@@ -71,6 +74,22 @@ namespace Tensorflow.Operations | |||
var original_result = fn(); | |||
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||
//TODO: port this chunck of missing code: | |||
/* | |||
if len(post_summaries) > len(pre_summaries): | |||
new_summaries = post_summaries[len(pre_summaries):] | |||
summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access | |||
summary_ref[:] = pre_summaries | |||
with ops.control_dependencies(new_summaries): | |||
if original_result is None: | |||
return no_op(), None | |||
else: | |||
original_result = nest.map_structure(array_ops.identity, | |||
original_result) | |||
*/ | |||
if (original_result == null) | |||
return (original_result, null); | |||
switch (original_result) | |||
{ | |||
case Operation[] results: | |||
@@ -3,7 +3,24 @@ using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Operations | |||
{ | |||
{ | |||
/// <summary> | |||
/// The base class for control flow context. | |||
/// | |||
/// The usage pattern is a sequence of(Enter, Exit) followed by a final | |||
/// ExitResult. | |||
/// | |||
/// We maintain the following state for control flow contexts during graph | |||
/// construction: | |||
/// 1. graph has _control_flow_context: the current context used to | |||
/// construct new nodes.Changed by ctxt.Enter() and ctxt.Exit() | |||
/// 2. op has _control_flow_context: the context to which the op belongs. | |||
/// Set at the time the op is created.Immutable. | |||
/// 3. A ControlFlowContext has _outer_context: the context in which this | |||
/// context is created.Set at the time a context is created.Immutable. | |||
/// 4. A ControlFlowContext has _context_stack. | |||
/// Pushed and popped by ctxt.Enter() and ctxt.Exit() | |||
/// </summary> | |||
public abstract class ControlFlowContext : IPython, IControlFlowContext | |||
{ | |||
/// <summary> | |||
@@ -17,6 +34,8 @@ namespace Tensorflow.Operations | |||
_context_stack = new Stack<IControlFlowContext>(); | |||
} | |||
public string name { get; set; } | |||
public void __init__() | |||
{ | |||
@@ -26,6 +45,13 @@ namespace Tensorflow.Operations | |||
{ | |||
} | |||
public void __exit__() | |||
{ | |||
} | |||
/// <summary> | |||
/// Enter this control flow context. | |||
/// </summary> | |||
public virtual void Enter() | |||
{ | |||
var graph = ops.get_default_graph(); | |||
@@ -33,6 +59,16 @@ namespace Tensorflow.Operations | |||
graph._set_control_flow_context(this); | |||
} | |||
/// <summary> | |||
/// Exit this control flow context. | |||
/// </summary> | |||
public virtual void Exit() | |||
{ | |||
var graph = ops.get_default_graph(); | |||
var last_context = _context_stack.Pop(); | |||
graph._set_control_flow_context(last_context); | |||
} | |||
public void AddOp(Operation op) | |||
{ | |||
_AddOpInternal(op); | |||
@@ -56,17 +92,6 @@ namespace Tensorflow.Operations | |||
var internal_control_inputs = op.control_inputs; | |||
} | |||
public void Exit() | |||
{ | |||
var graph = ops.get_default_graph(); | |||
var last_context = _context_stack.Pop(); | |||
graph._set_control_flow_context(last_context); | |||
} | |||
public void __exit__() | |||
{ | |||
} | |||
public void Dispose() | |||
{ | |||
} | |||
@@ -187,6 +187,48 @@ namespace Tensorflow | |||
return @switch(data, pred, name: name); | |||
} | |||
/// <summary> | |||
/// Return `true_fn()` if the predicate `pred` is true else `false_fn()`. | |||
/// | |||
/// `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and | |||
/// `false_fn` must have the same non-zero number and type of outputs. | |||
/// | |||
/// **WARNING**: Any Tensors or Operations created outside of `true_fn` and | |||
/// `false_fn` will be executed regardless of which branch is selected at runtime. | |||
/// | |||
/// Although this behavior is consistent with the dataflow model of TensorFlow, | |||
/// it has frequently surprised users who expected a lazier semantics. | |||
/// Consider the following simple program: | |||
/// | |||
/// z = tf.multiply(a, b) | |||
/// result = tf.cond(x < y, ()=> tf.add(x, z), ()=> tf.square(y)) | |||
/// | |||
/// If `x<y`, the `tf.add` operation will be executed and `tf.square` | |||
/// operation will not be executed.Since `z` is needed for at least one | |||
/// branch of the `cond`, the `tf.multiply` operation is always executed, | |||
/// unconditionally. | |||
/// | |||
/// Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the | |||
/// call to `cond`, and not at all during `Session.run()`). `cond` | |||
/// stitches together the graph fragments created during the `true_fn` and | |||
/// `false_fn` calls with some additional graph nodes to ensure that the right | |||
/// branch gets executed depending on the value of `pred`. | |||
/// | |||
/// `tf.cond` supports nested structures as implemented in | |||
/// `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the | |||
/// same(possibly nested) value structure of lists, tuples, and/or named tuples. | |||
/// Singleton lists and tuples form the only exceptions to this: when returned by | |||
/// `true_fn` and/or `false_fn`, they are implicitly unpacked to single values. | |||
/// This behavior is disabled by passing `strict= True`. | |||
/// </summary> | |||
/// <param name="pred"> A scalar determining whether to return the result of `true_fn` or | |||
/// `false_fn`.</param> | |||
/// <param name="true_fn">The callable to be performed if pred is true.</param> | |||
/// <param name="false_fn">The callable to be performed if pred is false.</param> | |||
/// <param name="strict"> A boolean that enables/disables 'strict' mode; see above.</param> | |||
/// <param name="name">Optional name prefix for the returned tensors.</param> | |||
/// <returns>Tensors returned by the call to either `true_fn` or `false_fn`. If the | |||
/// callables return a singleton list, the element is extracted from the list.</returns> | |||
public static Tensor cond(Tensor pred, | |||
Func<ITensorOrOperation> true_fn = null, | |||
Func<ITensorOrOperation> false_fn = null, | |||
@@ -195,6 +237,37 @@ namespace Tensorflow | |||
{ | |||
return with(ops.name_scope(name, "cond", new { pred }), delegate | |||
{ | |||
// TODO: here a chunk of original code is missing | |||
/* | |||
if fn1 is not None: | |||
if true_fn is not None: | |||
raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.") | |||
true_fn = fn1 | |||
elif true_fn is None: | |||
raise TypeError("cond(): true_fn argument required") | |||
if fn2 is not None: | |||
if false_fn is not None: | |||
raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.") | |||
false_fn = fn2 | |||
elif false_fn is None: | |||
raise TypeError("cond(): false_fn argument required") | |||
if not callable(true_fn): | |||
raise TypeError("true_fn must be callable.") | |||
if not callable(false_fn): | |||
raise TypeError("false_fn must be callable.") | |||
with ops.name_scope(name, "cond", [pred]): | |||
if context.executing_eagerly(): | |||
if pred: | |||
return _UnpackIfSingleton(true_fn()) | |||
return _UnpackIfSingleton(false_fn()) | |||
# Add the Switch to the graph. | |||
if isinstance(pred, bool): | |||
raise TypeError("pred must not be a Python bool") | |||
*/ | |||
// Add the Switch to the graph. | |||
var (p_2, p_1) = @switch(pred, pred); | |||
var pivot_1 = array_ops.identity(p_1, name: "switch_t"); | |||
@@ -207,15 +280,45 @@ namespace Tensorflow | |||
// Build the graph for the true branch in a new context. | |||
var context_t = new CondContext(pred, pivot_1, branch: 1); | |||
context_t.Enter(); | |||
var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); | |||
context_t.Exit(); | |||
ITensorOrOperation orig_res_t; | |||
Tensor res_t; | |||
try | |||
{ | |||
context_t.Enter(); | |||
(orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); | |||
} | |||
finally | |||
{ | |||
context_t.Exit(); | |||
} | |||
// Build the graph for the false branch in a new context. | |||
var context_f = new CondContext(pred, pivot_2, branch: 0); | |||
context_f.Enter(); | |||
var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); | |||
context_f.Exit(); | |||
ITensorOrOperation orig_res_f; | |||
Tensor res_f; | |||
try | |||
{ | |||
context_f.Enter(); | |||
(orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); | |||
} | |||
finally | |||
{ | |||
context_f.Exit(); | |||
} | |||
//TODO: missing original code | |||
//if not strict: | |||
// orig_res_t = _UnpackIfSingleton(orig_res_t) | |||
// orig_res_f = _UnpackIfSingleton(orig_res_f) | |||
/* | |||
# Check that the return values of the two branches have the same structure. | |||
try: | |||
nest.assert_same_structure(orig_res_t, orig_res_f) | |||
except TypeError as e: | |||
raise TypeError( | |||
"Incompatible return types of true_fn and false_fn: {}".format(e)) | |||
except ValueError as e: | |||
raise ValueError( | |||
"Incompatible return values of true_fn and false_fn: {}".format(e)) | |||
var res_t_flat = new Tensor[] { res_t }; | |||
var res_f_flat = new Tensor[] { res_f }; | |||
@@ -1,5 +1,6 @@ | |||
using NumSharp; | |||
using System; | |||
using System.Collections; | |||
using System.Collections.Generic; | |||
using System.ComponentModel; | |||
using System.Linq; | |||
@@ -17,8 +18,8 @@ namespace Tensorflow | |||
Console.WriteLine(obj.ToString()); | |||
} | |||
protected int len(Array a) | |||
=> a.Length; | |||
protected int len<T>(IEnumerable<T> a) | |||
=> a.Count(); | |||
protected IEnumerable<int> range(int end) | |||
{ | |||
@@ -61,115 +61,150 @@ namespace TensorFlowNET.UnitTest | |||
self.assertEqual(op4.name, "myop_1_1"); | |||
}); | |||
} | |||
[Ignore("Something is not right, Switch gets not inserted correctly?")] | |||
[TestMethod] | |||
public void TestCond() | |||
{ | |||
var graph = tf.Graph().as_default(); | |||
with<Graph>(graph, g => | |||
{ | |||
var x = constant_op.constant(10); | |||
var true_fn = new Func<Tensor>(() => | |||
{ | |||
var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "cond/myop"), new[] { x }, new Operation[0]); | |||
var new_ops = g._add_new_tf_operations(); | |||
self.assertEqual(len(new_ops), 1); | |||
return x; | |||
}); | |||
control_flow_ops.cond(x < 10, true_fn, () => x); | |||
var op = g.get_operation_by_name("cond/myop"); | |||
self.assertIsNotNone(op); | |||
self.assertEqual(op.name, "cond/myop"); | |||
self.assertEqual(op.type, "Identity"); | |||
//self.assertEqual(op.outputs, new object[0]); | |||
var op_input = op.inputs[0].op; | |||
self.assertEqual(op_input.type, "Switch"); | |||
self.assertEqual(op_input.inputs[0], x); | |||
self.assertEqual(op.graph, g); | |||
self.assertIsNotNone(op._get_control_flow_context()); | |||
// TODO: op._get_control_flow_context().name not implemented | |||
//self.assertEqual(op._get_control_flow_context().name, "cond/cond_text"); | |||
}); | |||
/* | |||
@test_util.run_v1_only("b/120545219") | |||
def testCond(self): | |||
g = ops.Graph() | |||
with g.as_default(): | |||
x = test_ops.int_output() | |||
def true_fn(): | |||
ops._create_c_op(ops.get_default_graph(), | |||
ops._NodeDef("IntInput", "cond/myop"), [x], []) | |||
new_ops = g._add_new_tf_operations() | |||
self.assertEqual(len(new_ops), 1) | |||
return x | |||
control_flow_ops.cond(x < 10, true_fn, lambda: x) | |||
op = g.get_operation_by_name("cond/myop") | |||
self.assertIsNotNone(op) | |||
self.assertEqual(op.name, "cond/myop") | |||
self.assertEqual(op.type, "IntInput") | |||
self.assertEqual(op.outputs, []) | |||
op_input = op.inputs[0].op | |||
self.assertEqual(op_input.type, "Switch") | |||
self.assertEqual(op_input.inputs[0], x) | |||
self.assertEqual(op.graph, g) | |||
# pylint: disable=protected-access | |||
self.assertIsNotNone(op._get_control_flow_context()) | |||
self.assertEqual(op._get_control_flow_context().name, | |||
"cond/cond_text") | |||
# pylint: enable=protected-access | |||
*/ | |||
} | |||
/* | |||
@test_util.run_v1_only("b/120545219") | |||
def testCond(self): | |||
g = ops.Graph() | |||
with g.as_default(): | |||
x = test_ops.int_output() | |||
def true_fn(): | |||
ops._create_c_op(ops.get_default_graph(), | |||
ops._NodeDef("IntInput", "cond/myop"), [x], []) | |||
new_ops = g._add_new_tf_operations() | |||
self.assertEqual(len(new_ops), 1) | |||
return x | |||
control_flow_ops.cond(x < 10, true_fn, lambda: x) | |||
op = g.get_operation_by_name("cond/myop") | |||
self.assertIsNotNone(op) | |||
self.assertEqual(op.name, "cond/myop") | |||
self.assertEqual(op.type, "IntInput") | |||
self.assertEqual(op.outputs, []) | |||
op_input = op.inputs[0].op | |||
self.assertEqual(op_input.type, "Switch") | |||
self.assertEqual(op_input.inputs[0], x) | |||
self.assertEqual(op.graph, g) | |||
# pylint: disable=protected-access | |||
self.assertIsNotNone(op._get_control_flow_context()) | |||
self.assertEqual(op._get_control_flow_context().name, | |||
"cond/cond_text") | |||
# pylint: enable=protected-access | |||
@test_util.run_v1_only("b/120545219") | |||
def testWhileLoop(self): | |||
g = ops.Graph() | |||
with g.as_default(): | |||
x = test_ops.int_output() | |||
def body(i): | |||
ops._create_c_op(ops.get_default_graph(), | |||
ops._NodeDef("IntInput", "myloop/myop"), [x], []) | |||
new_ops = g._add_new_tf_operations() | |||
self.assertEqual(len(new_ops), 1) | |||
return i | |||
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") | |||
op = g.get_operation_by_name("myloop/myop") | |||
self.assertIsNotNone(op) | |||
self.assertEqual(op.name, "myloop/myop") | |||
self.assertEqual(op.type, "IntInput") | |||
self.assertEqual(op.outputs, []) | |||
op_input = op.inputs[0].op | |||
self.assertEqual(op_input.type, "Enter") | |||
self.assertEqual(list(op_input.inputs), [x]) | |||
self.assertEqual(op.graph, g) | |||
# pylint: disable=protected-access | |||
self.assertIsNotNone(op._get_control_flow_context()) | |||
self.assertEqual(op._get_control_flow_context().name, | |||
"myloop/while_context") | |||
# pylint: enable=protected-access | |||
@test_util.run_v1_only("b/120545219") | |||
def testWhileLoopWithInternalControlDep(self): | |||
g = ops.Graph() | |||
with g.as_default(): | |||
x = test_ops.int_output() | |||
def body(i): | |||
c = constant_op.constant(1.0, name="c") | |||
ops._create_c_op(ops.get_default_graph(), | |||
ops._NodeDef("IntInput", "myloop/myop"), [x], []) | |||
with ops.control_dependencies([c]): | |||
new_ops = g._add_new_tf_operations() | |||
self.assertEqual(len(new_ops), 1) | |||
return i | |||
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") | |||
op = g.get_operation_by_name("myloop/myop") | |||
self.assertIsNotNone(op) | |||
c = g.get_operation_by_name("myloop/c") | |||
self.assertIsNotNone(c) | |||
# Internal control dep is preserved | |||
self.assertEqual(op.control_inputs, [c]) | |||
@test_util.run_v1_only("b/120545219") | |||
def testWhileLoopWithExternalControlDep(self): | |||
g = ops.Graph() | |||
with g.as_default(): | |||
x = test_ops.int_output() | |||
c = constant_op.constant(1.0) | |||
def body(i): | |||
ops._create_c_op(ops.get_default_graph(), | |||
ops._NodeDef("IntInput", "myloop/myop"), [x], []) | |||
with ops.control_dependencies([c]): | |||
new_ops = g._add_new_tf_operations() | |||
self.assertEqual(len(new_ops), 1) | |||
return i | |||
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") | |||
op = g.get_operation_by_name("myloop/myop") | |||
self.assertIsNotNone(op) | |||
# External control dep is removed and replaced with internal control dep | |||
self.assertNotEqual(op.control_inputs[0], c.op) | |||
self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context()) | |||
*/ | |||
@test_util.run_v1_only("b/120545219") | |||
def testWhileLoop(self): | |||
g = ops.Graph() | |||
with g.as_default(): | |||
x = test_ops.int_output() | |||
def body(i): | |||
ops._create_c_op(ops.get_default_graph(), | |||
ops._NodeDef("IntInput", "myloop/myop"), [x], []) | |||
new_ops = g._add_new_tf_operations() | |||
self.assertEqual(len(new_ops), 1) | |||
return i | |||
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") | |||
op = g.get_operation_by_name("myloop/myop") | |||
self.assertIsNotNone(op) | |||
self.assertEqual(op.name, "myloop/myop") | |||
self.assertEqual(op.type, "IntInput") | |||
self.assertEqual(op.outputs, []) | |||
op_input = op.inputs[0].op | |||
self.assertEqual(op_input.type, "Enter") | |||
self.assertEqual(list(op_input.inputs), [x]) | |||
self.assertEqual(op.graph, g) | |||
# pylint: disable=protected-access | |||
self.assertIsNotNone(op._get_control_flow_context()) | |||
self.assertEqual(op._get_control_flow_context().name, | |||
"myloop/while_context") | |||
# pylint: enable=protected-access | |||
@test_util.run_v1_only("b/120545219") | |||
def testWhileLoopWithInternalControlDep(self): | |||
g = ops.Graph() | |||
with g.as_default(): | |||
x = test_ops.int_output() | |||
def body(i): | |||
c = constant_op.constant(1.0, name="c") | |||
ops._create_c_op(ops.get_default_graph(), | |||
ops._NodeDef("IntInput", "myloop/myop"), [x], []) | |||
with ops.control_dependencies([c]): | |||
new_ops = g._add_new_tf_operations() | |||
self.assertEqual(len(new_ops), 1) | |||
return i | |||
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") | |||
op = g.get_operation_by_name("myloop/myop") | |||
self.assertIsNotNone(op) | |||
c = g.get_operation_by_name("myloop/c") | |||
self.assertIsNotNone(c) | |||
# Internal control dep is preserved | |||
self.assertEqual(op.control_inputs, [c]) | |||
@test_util.run_v1_only("b/120545219") | |||
def testWhileLoopWithExternalControlDep(self): | |||
g = ops.Graph() | |||
with g.as_default(): | |||
x = test_ops.int_output() | |||
c = constant_op.constant(1.0) | |||
def body(i): | |||
ops._create_c_op(ops.get_default_graph(), | |||
ops._NodeDef("IntInput", "myloop/myop"), [x], []) | |||
with ops.control_dependencies([c]): | |||
new_ops = g._add_new_tf_operations() | |||
self.assertEqual(len(new_ops), 1) | |||
return i | |||
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") | |||
op = g.get_operation_by_name("myloop/myop") | |||
self.assertIsNotNone(op) | |||
# External control dep is removed and replaced with internal control dep | |||
self.assertNotEqual(op.control_inputs[0], c.op) | |||
self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context()) | |||
*/ | |||
} | |||
} |
@@ -29,6 +29,11 @@ namespace TensorFlowNET.UnitTest | |||
Assert.AreEqual(expected, given); | |||
} | |||
public void assertIsNotNone(object given) | |||
{ | |||
Assert.IsNotNull(given); | |||
} | |||
protected PythonTest self { get => this; } | |||
} | |||
} |