From bad9aba49ee7a0c39cfc76723cd6473a318895ca Mon Sep 17 00:00:00 2001 From: Antonio Cifonelli Date: Mon, 5 Aug 2019 23:07:26 +0200 Subject: [PATCH] Adding `logical_not` operator (#343) Relative unit test in `OperationTest`. --- src/TensorFlowNET.Core/APIs/tf.math.cs | 3 +++ src/TensorFlowNET.Core/Operations/gen_math_ops.cs | 7 +++++++ test/TensorFlowNET.UnitTest/OperationsTest.cs | 11 ++++++++++- 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index fb65d31b..b787bf1d 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -192,6 +192,9 @@ namespace Tensorflow public static Tensor logical_and(Tensor x, Tensor y, string name = null) => gen_math_ops.logical_and(x, y, name); + public static Tensor logical_not(Tensor x, string name = null) + => gen_math_ops.logical_not(x, name); + /// /// Clips tensor values to a specified min and max. /// diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index a8b9ac49..c3b30d8f 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -357,6 +357,13 @@ namespace Tensorflow return _op.outputs[0]; } + public static Tensor logical_not(Tensor x, string name = null) + { + var _op = _op_def_lib._apply_op_helper("LogicalNot", name, args: new { x }); + + return _op.outputs[0]; + } + public static Tensor squared_difference(Tensor x, Tensor y, string name = null) { var _op = _op_def_lib._apply_op_helper("SquaredDifference", name, args: new { x, y, name }); diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index 10046f0c..68c44831 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -131,7 +131,7 @@ namespace TensorFlowNET.UnitTest } [TestMethod] - public void logicalAndTest() + public void logicalOpsTest() { var a = tf.constant(new[] {1f, 2f, 3f, 4f, -4f, -3f, -2f, -1f}); var b = tf.less(a, 0f); @@ -144,6 +144,15 @@ namespace TensorFlowNET.UnitTest var o = sess.run(d); Assert.IsTrue(o.array_equal(check)); } + + d = tf.cast(tf.logical_not(b), tf.int32); + check = np.array(new[] { 1, 1, 1, 1, 0, 0, 0, 0 }); + + using (var sess = tf.Session()) + { + var o = sess.run(d); + Assert.IsTrue(o.array_equal(check)); + } } [TestMethod]