@@ -290,33 +290,11 @@ namespace Tensorflow | |||||
{ | { | ||||
// TODO: here a chunk of original code is missing | // TODO: here a chunk of original code is missing | ||||
/* | /* | ||||
if fn1 is not None: | |||||
if true_fn is not None: | |||||
raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.") | |||||
true_fn = fn1 | |||||
elif true_fn is None: | |||||
raise TypeError("cond(): true_fn argument required") | |||||
if fn2 is not None: | |||||
if false_fn is not None: | |||||
raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.") | |||||
false_fn = fn2 | |||||
elif false_fn is None: | |||||
raise TypeError("cond(): false_fn argument required") | |||||
if not callable(true_fn): | |||||
raise TypeError("true_fn must be callable.") | |||||
if not callable(false_fn): | |||||
raise TypeError("false_fn must be callable.") | |||||
with ops.name_scope(name, "cond", [pred]): | with ops.name_scope(name, "cond", [pred]): | ||||
if context.executing_eagerly(): | if context.executing_eagerly(): | ||||
if pred: | if pred: | ||||
return _UnpackIfSingleton(true_fn()) | return _UnpackIfSingleton(true_fn()) | ||||
return _UnpackIfSingleton(false_fn()) | return _UnpackIfSingleton(false_fn()) | ||||
# Add the Switch to the graph. | |||||
if isinstance(pred, bool): | |||||
raise TypeError("pred must not be a Python bool") | |||||
*/ | */ | ||||
// Add the Switch to the graph. | // Add the Switch to the graph. | ||||
@@ -19,7 +19,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||||
var y = tf.constant(5); | var y = tf.constant(5); | ||||
var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)), | var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)), | ||||
() => tf.add(y, tf.constant(23))); | () => tf.add(y, tf.constant(23))); | ||||
tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false); | |||||
//tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false); | |||||
self.assertEquals(eval_scalar(z), 34); | self.assertEquals(eval_scalar(z), 34); | ||||
}); | }); | ||||
} | } | ||||
@@ -139,7 +139,7 @@ namespace TensorFlowNET.UnitTest.ops_test | |||||
var a_2 = constant_op.constant(3.0); | var a_2 = constant_op.constant(3.0); | ||||
var a_3 = constant_op.constant(4.0); | var a_3 = constant_op.constant(4.0); | ||||
var a_4 = constant_op.constant(5.0); | var a_4 = constant_op.constant(5.0); | ||||
Operation b_1 = null, b_2 = null; | |||||
Tensor b_1 = null, b_2 = null; | |||||
with(g.control_dependencies(new[] { a_1, a_2, a_3, a_4 }), ctrl => | with(g.control_dependencies(new[] { a_1, a_2, a_3, a_4 }), ctrl => | ||||
{ | { | ||||
b_1 = constant_op.constant(6.0); | b_1 = constant_op.constant(6.0); | ||||
@@ -157,6 +157,12 @@ namespace TensorFlowNET.UnitTest.ops_test | |||||
}); | }); | ||||
}); | }); | ||||
}); | }); | ||||
var z=tf.add(a_1, tf.multiply(b_2, b_1)); | |||||
with(g.control_dependencies(new[] {z}), ctrl => | |||||
{ | |||||
var z1 = tf.add(a_3, tf.multiply(a_4, a_2)); | |||||
}); | |||||
tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false); | |||||
assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op, a_2.op, a_3.op, a_4.op }); | assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op, a_2.op, a_3.op, a_4.op }); | ||||
assertItemsEqual(b_2.op.control_inputs, b_1.op.control_inputs); | assertItemsEqual(b_2.op.control_inputs, b_1.op.control_inputs); | ||||
} | } | ||||