Browse Source

Adding `logical_xor` operator (#346)

Relative unit test in `OperationTest`.
tags/v0.12
Antonio Cifonelli Haiping 6 years ago
parent
commit
924e1592af
3 changed files with 20 additions and 0 deletions
  1. +3
    -0
      src/TensorFlowNET.Core/APIs/tf.math.cs
  2. +8
    -0
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  3. +9
    -0
      test/TensorFlowNET.UnitTest/OperationsTest.cs

+ 3
- 0
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -198,6 +198,9 @@ namespace Tensorflow
public static Tensor logical_or(Tensor x, Tensor y, string name = null)
=> gen_math_ops.logical_or(x, y, name);

public static Tensor logical_xor(Tensor x, Tensor y, string name = "LogicalXor")
=> gen_math_ops.logical_xor(x, y, name);

/// <summary>
/// Clips tensor values to a specified min and max.
/// </summary>


+ 8
- 0
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -371,6 +371,14 @@ namespace Tensorflow
return _op.outputs[0];
}
public static Tensor logical_xor(Tensor x, Tensor y, string name = "LogicalXor")
{
return logical_and(
logical_or(x, y),
logical_not(logical_and(x, y)),
name);
}
public static Tensor squared_difference(Tensor x, Tensor y, string name = null)
{
var _op = _op_def_lib._apply_op_helper("SquaredDifference", name, args: new { x, y, name });


+ 9
- 0
test/TensorFlowNET.UnitTest/OperationsTest.cs View File

@@ -162,6 +162,15 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(d);
Assert.IsTrue(o.array_equal(check));
}

d = tf.cast(tf.logical_xor(b, c), tf.int32);
check = np.array(new[] { 1, 1, 1, 1, 1, 1, 1, 1 });

using (var sess = tf.Session())
{
var o = sess.run(d);
Assert.IsTrue(o.array_equal(check));
}
}

[TestMethod]


Loading…
Cancel
Save