diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index 5586840c..fb65d31b 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -189,6 +189,9 @@ namespace Tensorflow public static Tensor log1p(Tensor x, string name = null) => gen_math_ops.log1p(x, name); + public static Tensor logical_and(Tensor x, Tensor y, string name = null) + => gen_math_ops.logical_and(x, y, name); + /// /// Clips tensor values to a specified min and max. /// diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 8ec7e253..a8b9ac49 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -350,6 +350,13 @@ namespace Tensorflow return _op.outputs[0]; } + public static Tensor logical_and(Tensor x, Tensor y, string name = null) + { + var _op = _op_def_lib._apply_op_helper("LogicalAnd", name, args: new { x, y }); + + return _op.outputs[0]; + } + public static Tensor squared_difference(Tensor x, Tensor y, string name = null) { var _op = _op_def_lib._apply_op_helper("SquaredDifference", name, args: new { x, y, name }); diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index a0a3b5e4..10046f0c 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -130,6 +130,22 @@ namespace TensorFlowNET.UnitTest } } + [TestMethod] + public void logicalAndTest() + { + var a = tf.constant(new[] {1f, 2f, 3f, 4f, -4f, -3f, -2f, -1f}); + var b = tf.less(a, 0f); + var c = tf.greater(a, 0f); + var d = tf.cast(tf.logical_and(b, c), tf.int32); + var check = np.array(new[] { 0, 0, 0, 0, 0, 0, 0, 0 }); + + using (var sess = tf.Session()) + { + var o = sess.run(d); + Assert.IsTrue(o.array_equal(check)); + } + } + [TestMethod] public void addOpTests() {