|
|
@@ -11,21 +11,19 @@ namespace TensorFlowNET.UnitTest |
|
|
|
[TestClass] |
|
|
|
public class ConstantTest |
|
|
|
{ |
|
|
|
Tensor tensor; |
|
|
|
|
|
|
|
[TestMethod] |
|
|
|
public void ScalarConst() |
|
|
|
{ |
|
|
|
tensor = tf.constant(8); // int |
|
|
|
tensor = tf.constant(6.0f); // float |
|
|
|
tensor = tf.constant(6.0); // double |
|
|
|
var tensor1 = tf.constant(8); // int |
|
|
|
var tensor2 = tf.constant(6.0f); // float |
|
|
|
var tensor3 = tf.constant(6.0); // double |
|
|
|
} |
|
|
|
|
|
|
|
[TestMethod] |
|
|
|
public void StringConst() |
|
|
|
{ |
|
|
|
string str = "Hello, TensorFlow.NET!"; |
|
|
|
tensor = tf.constant(str); |
|
|
|
var tensor = tf.constant(str); |
|
|
|
Python.with<Session>(tf.Session(), sess => |
|
|
|
{ |
|
|
|
var result = sess.run(tensor); |
|
|
@@ -37,7 +35,7 @@ namespace TensorFlowNET.UnitTest |
|
|
|
public void ZerosConst() |
|
|
|
{ |
|
|
|
// small size |
|
|
|
tensor = tf.zeros(new Shape(3, 2), TF_DataType.TF_INT32, "small"); |
|
|
|
var tensor = tf.zeros(new Shape(3, 2), TF_DataType.TF_INT32, "small"); |
|
|
|
Python.with<Session>(tf.Session(), sess => |
|
|
|
{ |
|
|
|
var result = sess.run(tensor); |
|
|
@@ -67,11 +65,34 @@ namespace TensorFlowNET.UnitTest |
|
|
|
{ |
|
|
|
var nd = np.array(new int[][] |
|
|
|
{ |
|
|
|
new int[]{ 1, 2, 3 }, |
|
|
|
new int[]{ 4, 5, 6 } |
|
|
|
new int[]{ 3, 1, 1 }, |
|
|
|
new int[]{ 2, 1, 3 } |
|
|
|
}); |
|
|
|
|
|
|
|
tensor = tf.constant(nd); |
|
|
|
var tensor = tf.constant(nd); |
|
|
|
Python.with<Session>(tf.Session(), sess => |
|
|
|
{ |
|
|
|
var result = sess.run(tensor); |
|
|
|
var data = result.Data<int>(); |
|
|
|
|
|
|
|
Assert.AreEqual(result.shape[0], 2); |
|
|
|
Assert.AreEqual(result.shape[1], 3); |
|
|
|
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 1, 2, 1, 1, 3 }, data)); |
|
|
|
}); |
|
|
|
} |
|
|
|
|
|
|
|
[TestMethod] |
|
|
|
public void Multiply() |
|
|
|
{ |
|
|
|
var a = tf.constant(3.0); |
|
|
|
var b = tf.constant(2.0); |
|
|
|
var c = a * b; |
|
|
|
|
|
|
|
var sess = tf.Session(); |
|
|
|
double result = sess.run(c); |
|
|
|
sess.close(); |
|
|
|
|
|
|
|
Assert.AreEqual(6.0, result); |
|
|
|
} |
|
|
|
} |
|
|
|
} |