Browse Source

Added unit-tests for autocasting mechanism.

tags/v0.12
Eli Belash 6 years ago
parent
commit
4d8ae9a396
1 changed files with 57 additions and 0 deletions
  1. +57
    -0
      test/TensorFlowNET.UnitTest/SessionTest.cs

+ 57
- 0
test/TensorFlowNET.UnitTest/SessionTest.cs View File

@@ -7,6 +7,7 @@ using System.Runtime.CompilerServices;
using System.Text;
using FluentAssertions;
using Google.Protobuf;
using NumSharp.Backends;
using Tensorflow;
using Tensorflow.Util;
using static Tensorflow.Binding;
@@ -131,5 +132,61 @@ namespace TensorFlowNET.UnitTest
}
}
}

[TestMethod]
public void Autocast_Case1()
{
var sess = tf.Session().as_default();
var input = tf.placeholder(tf.float64, shape: new TensorShape(6));
var op = tf.reshape(input, new int[] {2, 3});
sess.run(tf.global_variables_initializer());
var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6)));

ret.Should().BeOfType<double>().And.BeShaped(2, 3).And.BeOfValues(1, 2, 3, 4, 5, 6);
print(ret.dtype);
print(ret);
}

[TestMethod]
public void Autocast_Case2()
{
var sess = tf.Session().as_default();
var input = tf.placeholder(tf.float64, shape: new TensorShape(6));
var op = tf.reshape(input, new int[] {2, 3});
sess.run(tf.global_variables_initializer());
var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(NPTypeCode.Single) + 0.1f));

ret.Should().BeOfType<double>().And.BeShaped(2, 3).And.BeOfValuesApproximately(0.001d, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1);
print(ret.dtype);
print(ret);
}

[TestMethod]
public void Autocast_Case3()
{
var sess = tf.Session().as_default();
var input = tf.placeholder(tf.int16, shape: new TensorShape(6));
var op = tf.reshape(input, new int[] {2, 3});
sess.run(tf.global_variables_initializer());
var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(NPTypeCode.Single) + 0.1f));

ret.Should().BeOfType<short>().And.BeShaped(2, 3).And.BeOfValues(1, 2, 3, 4, 5, 6);
print(ret.dtype);
print(ret);
}

[TestMethod]
public void Autocast_Case4()
{
var sess = tf.Session().as_default();
var input = tf.placeholder(tf.@byte, shape: new TensorShape(6));
var op = tf.reshape(input, new int[] {2, 3});
sess.run(tf.global_variables_initializer());
var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(NPTypeCode.Single) + 0.1f));

ret.Should().BeOfType<byte>().And.BeShaped(2, 3).And.BeOfValues(1, 2, 3, 4, 5, 6);
print(ret.dtype);
print(ret);
}
}
}

Loading…
Cancel
Save