From 9dba68004140fe34cd9efae2c92f78db2441f73e Mon Sep 17 00:00:00 2001 From: Meinrad Recheis Date: Sun, 14 Apr 2019 11:23:33 +0200 Subject: [PATCH] minor changes --- .../Operations/control_flow_ops.py.cs | 22 ------------------- .../control_flow_ops_test/CondTestCases.cs | 2 +- .../ops_test/ControlDependenciesTest.cs | 8 ++++++- 3 files changed, 8 insertions(+), 24 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs index aebcfaef..f1c799e9 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs @@ -290,33 +290,11 @@ namespace Tensorflow { // TODO: here a chunk of original code is missing /* - if fn1 is not None: - if true_fn is not None: - raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.") - true_fn = fn1 - elif true_fn is None: - raise TypeError("cond(): true_fn argument required") - if fn2 is not None: - if false_fn is not None: - raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.") - false_fn = fn2 - elif false_fn is None: - raise TypeError("cond(): false_fn argument required") - - if not callable(true_fn): - raise TypeError("true_fn must be callable.") - if not callable(false_fn): - raise TypeError("false_fn must be callable.") - with ops.name_scope(name, "cond", [pred]): if context.executing_eagerly(): if pred: return _UnpackIfSingleton(true_fn()) return _UnpackIfSingleton(false_fn()) - - # Add the Switch to the graph. - if isinstance(pred, bool): - raise TypeError("pred must not be a Python bool") */ // Add the Switch to the graph. diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs b/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs index 338d0388..8fb9d9bb 100644 --- a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs +++ b/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs @@ -19,7 +19,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test var y = tf.constant(5); var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)), () => tf.add(y, tf.constant(23))); - tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false); + //tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false); self.assertEquals(eval_scalar(z), 34); }); } diff --git a/test/TensorFlowNET.UnitTest/ops_test/ControlDependenciesTest.cs b/test/TensorFlowNET.UnitTest/ops_test/ControlDependenciesTest.cs index 315313c6..74935c76 100644 --- a/test/TensorFlowNET.UnitTest/ops_test/ControlDependenciesTest.cs +++ b/test/TensorFlowNET.UnitTest/ops_test/ControlDependenciesTest.cs @@ -139,7 +139,7 @@ namespace TensorFlowNET.UnitTest.ops_test var a_2 = constant_op.constant(3.0); var a_3 = constant_op.constant(4.0); var a_4 = constant_op.constant(5.0); - Operation b_1 = null, b_2 = null; + Tensor b_1 = null, b_2 = null; with(g.control_dependencies(new[] { a_1, a_2, a_3, a_4 }), ctrl => { b_1 = constant_op.constant(6.0); @@ -157,6 +157,12 @@ namespace TensorFlowNET.UnitTest.ops_test }); }); }); + var z=tf.add(a_1, tf.multiply(b_2, b_1)); + with(g.control_dependencies(new[] {z}), ctrl => + { + var z1 = tf.add(a_3, tf.multiply(a_4, a_2)); + }); + tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false); assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op, a_2.op, a_3.op, a_4.op }); assertItemsEqual(b_2.op.control_inputs, b_1.op.control_inputs); }