Browse Source

Merge pull request #299 from acifonelli/master

Add missing `operator -`s
tags/v0.10
Haiping GitHub 6 years ago
parent
commit
1fc38386e6
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 188 additions and 1 deletions
  1. +4
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
  2. +184
    -0
      test/TensorFlowNET.UnitTest/OperationsTest.cs

+ 4
- 1
src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs View File

@@ -33,10 +33,13 @@ namespace Tensorflow

public static Tensor operator -(Tensor t1) => gen_math_ops.neg(t1);

public static Tensor operator -(double x, Tensor y) => BinaryOpWrapper("sub", x, y);
public static Tensor operator -(float x, Tensor y) => BinaryOpWrapper("sub", x, y);
public static Tensor operator -(int x, Tensor y) => BinaryOpWrapper("sub", x, y);
public static Tensor operator -(Tensor x, Tensor y) => BinaryOpWrapper("sub", x, y);
public static Tensor operator -(Tensor x, int y) => BinaryOpWrapper("sub", x, y);
public static Tensor operator -(Tensor x, float y) => BinaryOpWrapper("sub", x, y);
public static Tensor operator -(Tensor x, double y) => BinaryOpWrapper("sub", x, y);
public static Tensor operator -(float x, Tensor y) => BinaryOpWrapper("sub", x, y);

public static Tensor operator *(float x, Tensor y) => BinaryOpWrapper("mul", x, y);
public static Tensor operator *(double x, Tensor y) => BinaryOpWrapper("mul", x, y);


+ 184
- 0
test/TensorFlowNET.UnitTest/OperationsTest.cs View File

@@ -215,5 +215,189 @@ namespace TensorFlowNET.UnitTest
}
#endregion
}

[TestMethod]
public void subOpTests()
{
const int rows = 2; // to avoid broadcasting effect
const int cols = 10;

#region intTest
const int firstIntVal = -2;
const int secondIntVal = 3;

var firstIntFeed = Enumerable.Repeat(firstIntVal, rows * cols).ToArray();
var secondIntFeed = Enumerable.Repeat(secondIntVal, rows * cols).ToArray();
var intResult = firstIntFeed.Sum() - secondIntFeed.Sum();
var intResultTwo = -firstIntFeed.Sum();

var a = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols));
var b = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols));
var c = tf.reduce_sum(tf.reduce_sum(tf.sub(a, b), 1));

using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
Assert.AreEqual((int)o, intResult);
}

// Testing `operator -(Tensor x, Tensor y)
c = tf.reduce_sum(tf.reduce_sum(a - b, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
Assert.AreEqual((int)o, intResult);
}

// Testing `operator -(Tensor x, int y)
c = tf.reduce_sum(tf.reduce_sum(a - secondIntVal, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
Assert.AreEqual((int)o, intResult);
}

// Testing `operator -(int x, Tensor y)
c = tf.reduce_sum(tf.reduce_sum(secondIntVal - a, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
Assert.AreEqual((int)o, Math.Abs(intResult));
}

// Testing `operator -(Tensor x)
c = tf.reduce_sum(tf.reduce_sum(-a, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
Assert.AreEqual((int)o, intResultTwo);
}
#endregion

#region floatTest
const float firstFloatVal = -2.0f;
const float secondFloatVal = 3.0f;

var firstFloatFeed = Enumerable.Repeat(firstFloatVal, rows * cols).ToArray();
var secondFloatFeed = Enumerable.Repeat(secondFloatVal, rows * cols).ToArray();
var floatResult = firstFloatFeed.Sum() - secondFloatFeed.Sum();
var floatResultTwo = -firstFloatFeed.Sum();

a = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols));
b = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols));
c = tf.reduce_sum(tf.reduce_sum(tf.sub(a, b), 1));

using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
Assert.AreEqual((float)o, floatResult);
}

// Testing `operator -(Tensor x, Tensor y)
c = tf.reduce_sum(tf.reduce_sum(a - b, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
Assert.AreEqual((float)o, floatResult);
}

// Testing `operator -(Tensor x, float y)
c = tf.reduce_sum(tf.reduce_sum(a - secondFloatVal, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
Assert.AreEqual((float)o, floatResult);
}

// Testing `operator -(float x, Tensor y)
c = tf.reduce_sum(tf.reduce_sum(secondFloatVal - a, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
Assert.AreEqual((float)o, Math.Abs(floatResult));
}

// Testing `operator -(Tensor x)
c = tf.reduce_sum(tf.reduce_sum(-a, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
Assert.AreEqual((float)o, floatResultTwo);
}
#endregion

#region doubleTest
const double firstDoubleVal = -2.0;
const double secondDoubleVal = 3.0;

var firstDoubleFeed = Enumerable.Repeat(firstDoubleVal, rows * cols).ToArray();
var secondDoubleFeed = Enumerable.Repeat(secondDoubleVal, rows * cols).ToArray();
var doubleResult = firstDoubleFeed.Sum() - secondDoubleFeed.Sum();
var doubleResultTwo = -firstDoubleFeed.Sum();

a = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols));
b = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols));
c = tf.reduce_sum(tf.reduce_sum(tf.sub(a, b), 1));

using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
Assert.AreEqual((double)o, doubleResult);
}

// Testing `operator -(Tensor x, Tensor y)
c = tf.reduce_sum(tf.reduce_sum(a - b, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
Assert.AreEqual((double)o, doubleResult);
}

// Testing `operator -(Tensor x, double y)
c = tf.reduce_sum(tf.reduce_sum(a - secondFloatVal, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
Assert.AreEqual((double)o, doubleResult);
}

// Testing `operator -(double x, Tensor y)
c = tf.reduce_sum(tf.reduce_sum(secondFloatVal - a, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
Assert.AreEqual((double)o, Math.Abs(doubleResult));
}

// Testing `operator -(Tensor x)
c = tf.reduce_sum(tf.reduce_sum(-a, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
Assert.AreEqual((double)o, doubleResultTwo);
}
#endregion
}
}
}

Loading…
Cancel
Save