Browse Source

Unit testing `operator *`s. Add missing `operator /`s (#324)

Also unit testing the new operators.
tags/v0.12
Antonio Haiping 6 years ago
parent
commit
ebf4c9f018
2 changed files with 383 additions and 2 deletions
  1. +17
    -2
      src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
  2. +366
    -0
      test/TensorFlowNET.UnitTest/OperationsTest.cs

+ 17
- 2
src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs View File

@@ -15,6 +15,7 @@
******************************************************************************/

using System;
using System.Linq;
using static Tensorflow.Python;

namespace Tensorflow
@@ -63,9 +64,20 @@ namespace Tensorflow
public static Tensor operator *(long constant, Tensor tensor) => BinaryOpWrapper("mul", constant, tensor);
public static Tensor operator *(ulong constant, Tensor tensor) => BinaryOpWrapper("mul", constant, tensor);

public static Tensor operator /(Tensor x, Tensor y) => BinaryOpWrapper("truediv", x, y);
public static Tensor operator /(Tensor x, float y) => BinaryOpWrapper("truediv", x, y);
private static readonly TF_DataType[] _intTfDataTypes = {
TF_DataType.TF_INT8, TF_DataType.TF_INT16, TF_DataType.TF_INT32, TF_DataType.TF_INT64,
TF_DataType.TF_QINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QINT32,
TF_DataType.TF_UINT8, TF_DataType.TF_UINT16, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64
};
public static Tensor operator /(double x, Tensor y) => BinaryOpWrapper("truediv", x, y);
public static Tensor operator /(float x, Tensor y) => BinaryOpWrapper("truediv", x, y);
public static Tensor operator /(int x, Tensor y) => BinaryOpWrapper("floordiv", x, y);
public static Tensor operator /(Tensor x, Tensor y) =>
_intTfDataTypes.Contains(x._dtype)
? BinaryOpWrapper("floordiv", x, y)
: BinaryOpWrapper("truediv", x, y);
public static Tensor operator /(Tensor x, int y) => BinaryOpWrapper("floordiv", x, y);
public static Tensor operator /(Tensor x, float y) => BinaryOpWrapper("truediv", x, y);
public static Tensor operator /(Tensor x, double y) => BinaryOpWrapper("truediv", x, y);

public static Tensor operator %(Tensor x, Tensor y) => BinaryOpWrapper("mod", x, y);
@@ -99,6 +111,9 @@ namespace Tensorflow
case "add":
result = gen_math_ops.add(x1, y1, name: scope);
break;
case "floordiv":
result = gen_math_ops.floor_div(x1, y1, name: scope);
break;
case "truediv":
result = gen_math_ops.real_div(x1, y1, name: scope);
break;


+ 366
- 0
test/TensorFlowNET.UnitTest/OperationsTest.cs View File

@@ -467,5 +467,371 @@ namespace TensorFlowNET.UnitTest
}
#endregion
}

private IEnumerable<int> MultiplyArray(IReadOnlyCollection<int> first, IReadOnlyCollection<int> second)
{
if(first.Count != second.Count)
throw new ArgumentException("Arrays should be of equal size!");

var firstEnumerator = first.GetEnumerator();
var secondEnumerator = second.GetEnumerator();
var result = new List<int>();
while (firstEnumerator.MoveNext())
{
secondEnumerator.MoveNext();
result.Add(firstEnumerator.Current * secondEnumerator.Current);
}

firstEnumerator.Dispose();
secondEnumerator.Dispose();

return result;
}
private IEnumerable<float> MultiplyArray(IReadOnlyCollection<float> first, IReadOnlyCollection<float> second)
{
if(first.Count != second.Count)
throw new ArgumentException("Arrays should be of equal size!");

var firstEnumerator = first.GetEnumerator();
var secondEnumerator = second.GetEnumerator();
var result = new List<float>();
while (firstEnumerator.MoveNext())
{
secondEnumerator.MoveNext();
result.Add(firstEnumerator.Current * secondEnumerator.Current);
}

firstEnumerator.Dispose();
secondEnumerator.Dispose();

return result;
}
private IEnumerable<double> MultiplyArray(IReadOnlyCollection<double> first, IReadOnlyCollection<double> second)
{
if(first.Count != second.Count)
throw new ArgumentException("Arrays should be of equal size!");

var firstEnumerator = first.GetEnumerator();
var secondEnumerator = second.GetEnumerator();
var result = new List<double>();
while (firstEnumerator.MoveNext())
{
secondEnumerator.MoveNext();
result.Add(firstEnumerator.Current * secondEnumerator.Current);
}

firstEnumerator.Dispose();
secondEnumerator.Dispose();

return result;
}

[TestMethod]
public void mulOpTests()
{
const int rows = 2; // to avoid broadcasting effect
const int cols = 10;

#region intTest
const int firstIntVal = 2;
const int secondIntVal = 3;

var firstIntFeed = Enumerable.Repeat(firstIntVal, rows * cols).ToArray();
var secondIntFeed = Enumerable.Repeat(secondIntVal, rows * cols).ToArray();
var intResult = MultiplyArray(firstIntFeed, secondIntFeed).Sum();

var a = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols));
var b = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols));
var c = tf.reduce_sum(tf.reduce_sum(tf.multiply(a, b), 1));

using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
Assert.AreEqual((int)o, intResult);
}

// Testing `operator *(Tensor x, Tensor y)
c = tf.reduce_sum(tf.reduce_sum(a * b, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
Assert.AreEqual((int)o, intResult);
}

// Testing `operator *(Tensor x, int y)
c = tf.reduce_sum(tf.reduce_sum(a * secondIntVal, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
Assert.AreEqual((int)o, intResult);
}

// Testing `operator *(int x, Tensor y)
c = tf.reduce_sum(tf.reduce_sum(firstIntVal * b, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
Assert.AreEqual((int)o, intResult);
}
#endregion

#region floatTest
const float firstFloatVal = 2.0f;
const float secondFloatVal = 3.0f;

var firstFloatFeed = Enumerable.Repeat(firstFloatVal, rows * cols).ToArray();
var secondFloatFeed = Enumerable.Repeat(secondFloatVal, rows * cols).ToArray();
var floatResult = MultiplyArray(firstFloatFeed, secondFloatFeed).Sum();

a = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols));
b = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols));
c = tf.reduce_sum(tf.reduce_sum(tf.multiply(a, b), 1));

using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
Assert.AreEqual((float)o, floatResult);
}

// Testing `operator *(Tensor x, Tensor y)
c = tf.reduce_sum(tf.reduce_sum(a * b, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
Assert.AreEqual((float)o, floatResult);
}

// Testing `operator *(Tensor x, float y)
c = tf.reduce_sum(tf.reduce_sum(a * secondFloatVal, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
Assert.AreEqual((float)o, floatResult);
}

// Testing `operator *(float x, Tensor y)
c = tf.reduce_sum(tf.reduce_sum(firstFloatVal * b, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
Assert.AreEqual((float)o, floatResult);
}
#endregion

#region doubleTest
const double firstDoubleVal = 2.0;
const double secondDoubleVal = 3.0;

var firstDoubleFeed = Enumerable.Repeat(firstDoubleVal, rows * cols).ToArray();
var secondDoubleFeed = Enumerable.Repeat(secondDoubleVal, rows * cols).ToArray();
var doubleResult = MultiplyArray(firstDoubleFeed, secondDoubleFeed).Sum();

a = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols));
b = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols));
c = tf.reduce_sum(tf.reduce_sum(tf.multiply(a, b), 1));

using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
Assert.AreEqual((double)o, doubleResult);
}

// Testing `operator *(Tensor x, Tensor y)
c = tf.reduce_sum(tf.reduce_sum(a * b, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
Assert.AreEqual((double)o, doubleResult);
}

// Testing `operator *(Tensor x, double y)
c = tf.reduce_sum(tf.reduce_sum(a * secondFloatVal, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
Assert.AreEqual((double)o, doubleResult);
}

// Testing `operator *(double x, Tensor y)
c = tf.reduce_sum(tf.reduce_sum(firstFloatVal * b, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
Assert.AreEqual((double) o, doubleResult);
}
#endregion
}

[TestMethod]
public void divOpTests()
{
const int rows = 2; // to avoid broadcasting effect
const int cols = 10;

#region intTest
const int firstIntVal = 6;
const int secondIntVal = 3;

var firstIntFeed = Enumerable.Repeat(firstIntVal, rows * cols).ToArray();
var secondIntFeed = Enumerable.Repeat(secondIntVal, rows * cols).ToArray();
var intResult = (int)(firstIntFeed.Sum() / (float)secondIntVal);

var a = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols));
var b = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols));
var c = tf.reduce_sum(tf.reduce_sum(gen_math_ops.floor_div(a, b), 1));

using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
Assert.AreEqual((int)o, intResult);
}

// Testing `operator /(Tensor x, Tensor y)
c = tf.reduce_sum(tf.reduce_sum(a / b, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
Assert.AreEqual((int)o, intResult);
}

// Testing `operator /(Tensor x, int y)
c = tf.reduce_sum(tf.reduce_sum(a / secondIntVal, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
Assert.AreEqual((int)o, intResult);
}

// Testing `operator /(int x, Tensor y)
c = tf.reduce_sum(tf.reduce_sum(firstIntVal / b, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
Assert.AreEqual((int)o, intResult);
}
#endregion

#region floatTest
const float firstFloatVal = 6.0f;
const float secondFloatVal = 3.0f;

var firstFloatFeed = Enumerable.Repeat(firstFloatVal, rows * cols).ToArray();
var secondFloatFeed = Enumerable.Repeat(secondFloatVal, rows * cols).ToArray();
var floatResult = MultiplyArray(firstFloatFeed, secondFloatFeed.Select(x => 1/x).ToArray()).Sum();

a = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols));
b = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols));
c = tf.reduce_sum(tf.reduce_sum(tf.divide(a, b), 1));

using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
Assert.AreEqual((float)o, floatResult);
}

// Testing `operator /(Tensor x, Tensor y)
c = tf.reduce_sum(tf.reduce_sum(a / b, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
Assert.AreEqual((float)o, floatResult);
}

// Testing `operator /(Tensor x, float y)
c = tf.reduce_sum(tf.reduce_sum(a / secondFloatVal, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
Assert.AreEqual((float)o, floatResult);
}

// Testing `operator /(float x, Tensor y)
c = tf.reduce_sum(tf.reduce_sum(firstFloatVal / b, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
Assert.AreEqual((float)o, floatResult);
}
#endregion

#region doubleTest
const double firstDoubleVal = 6.0;
const double secondDoubleVal = 3.0;

var firstDoubleFeed = Enumerable.Repeat(firstDoubleVal, rows * cols).ToArray();
var secondDoubleFeed = Enumerable.Repeat(secondDoubleVal, rows * cols).ToArray();
var doubleResult = MultiplyArray(firstDoubleFeed, secondDoubleFeed.Select(x => 1/x).ToArray()).Sum();

a = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols));
b = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols));
c = tf.reduce_sum(tf.reduce_sum(tf.divide(a, b), 1));

using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
Assert.AreEqual((double)o, doubleResult);
}

// Testing `operator /(Tensor x, Tensor y)
c = tf.reduce_sum(tf.reduce_sum(a / b, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
Assert.AreEqual((double)o, doubleResult);
}

// Testing `operator /(Tensor x, double y)
c = tf.reduce_sum(tf.reduce_sum(a / secondFloatVal, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
Assert.AreEqual((double)o, doubleResult);
}

// Testing `operator /(double x, Tensor y)
c = tf.reduce_sum(tf.reduce_sum(firstFloatVal / b, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
Assert.AreEqual((double)o, doubleResult);
}
#endregion
}
}
}

Loading…
Cancel
Save