You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

VariableTest.cs 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using NumSharp;
  3. using System.Linq;
  4. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.Basics
  {
      /// <summary>
      /// Tests for <c>tf.Variable</c>: construction, assignment, slicing and interaction
      /// with other ops. Inherits <see cref="EagerModeTestBase"/>
      /// (NOTE(review): presumably configures eager execution — confirm in the base class).
      /// </summary>
      [TestClass]
      public class VariableTest : EagerModeTestBase
      {
          /// <summary>
          /// A variable built from the int literal 10 is rank-0 and reads back as 10.
          /// </summary>
          [TestMethod]
          public void NewVariable()
          {
              var x = tf.Variable(10, name: "x");
              Assert.AreEqual(0, x.shape.ndim);
              Assert.AreEqual(10, (int)x.numpy());
          }

          /// <summary>
          /// Smoke test: string variables can be created both with an explicit name and
          /// dtype and with the dtype inferred from the value. There are no assertions,
          /// so this only fails if construction throws.
          /// TODO(review): add value/shape assertions once the expected string read-back API is confirmed.
          /// </summary>
          [TestMethod]
          public void StringVar()
          {
              var mammal1 = tf.Variable("Elephant", name: "var1", dtype: tf.@string);
              var mammal2 = tf.Variable("Tiger");
          }

          /// <summary>
          /// A variable can be initialized from a tensor expression (constant 3 plus 1).
          /// </summary>
          [TestMethod]
          public void VarSum()
          {
              var x = tf.constant(3, name: "x");
              var y = tf.Variable(x + 1, name: "y");
              Assert.AreEqual(4, (int)y.numpy());
          }

          /// <summary>
          /// The value returned by <c>assign</c> reflects the newly assigned value.
          /// </summary>
          [TestMethod]
          public void Assign1()
          {
              var variable = tf.Variable(31, name: "tree");
              var unread = variable.assign(12);
              Assert.AreEqual(12, (int)unread.numpy());
          }

          /// <summary>
          /// Assigning an expression that reads the variable itself (v1 + 1) updates the
          /// variable to 11, and the value returned by <c>assign</c> matches it.
          /// </summary>
          [TestMethod]
          public void Assign2()
          {
              var v1 = tf.Variable(10.0f, name: "v1");
              var v2 = v1.assign(v1 + 1.0f);
              Assert.AreEqual(v1.numpy(), v2.numpy());
              Assert.AreEqual(11f, (float)v1.numpy());
          }

          /// <summary>
          /// A variable initialized from another variable starts with the same value but
          /// does not follow later assignments to the source variable.
          /// </summary>
          [TestMethod]
          public void Assign3()
          {
              var v1 = tf.Variable(10.0f, name: "v1");
              var v2 = tf.Variable(v1, name: "v2");
              Assert.AreEqual(v1.numpy(), v2.numpy());
              v1.assign(30.0f);
              Assert.AreNotEqual(v1.numpy(), v2.numpy());
          }

          /// <summary>
          /// Assign tensor to slice of other tensor.
          /// https://www.tensorflow.org/api_docs/python/tf/Variable#__getitem__
          /// </summary>
          [TestMethod]
          public void SliceAssign()
          {
              NDArray nd = new float[,]
              {
                  { 1, 2, 3 },
                  { 4, 5, 6 },
                  { 7, 8, 9 }
              };
              var x = tf.Variable(nd);

              // get slice from variable: the top-left 2x2 block
              var sliced = x[":2", ":2"];
              Assert.AreEqual(nd[0][":2"], sliced[0].numpy());
              Assert.AreEqual(nd[1][":2"], sliced[1].numpy());

              // assign to the sliced tensor; the write must land in the underlying variable
              sliced.assign(22 * tf.ones((2, 2)));

              // test assigned value: only the top-left 2x2 block changes
              NDArray expected = new float[,]
              {
                  { 22, 22, 3 },
                  { 22, 22, 6 },
                  { 7, 8, 9 }
              };
              Assert.AreEqual(expected[0], x[0].numpy());
              Assert.AreEqual(expected[1], x[1].numpy());
              Assert.AreEqual(expected[2], x[2].numpy());
          }

          /// <summary>
          /// Assigning an int array to a slice of a double variable is expected to throw.
          /// </summary>
          [TestMethod, Ignore]
          public void TypeMismatchedSliceAssign()
          {
              NDArray intNd = new int[]
              {
                  1, -2, 3
              };
              NDArray doubleNd = new double[]
              {
                  -5, 6, -7
              };
              var x = tf.Variable(doubleNd);
              var slice = x[":"];

              // Assert.ThrowsException<T> passes only when the thrown type is exactly T,
              // not a subclass. ThrowsException<System.Exception> would therefore reject
              // any concrete TensorFlow error type. Catch the base type to accept any exception.
              // NOTE(review): the original author observed that this statement exits without
              // throwing any exception, and that the test execution summary did not seem to
              // detect that — presumably why the test is marked [Ignore].
              bool threw = false;
              try
              {
                  slice.assign(intNd);
              }
              catch (System.Exception)
              {
                  threw = true;
              }
              Assert.IsTrue(threw, "Assigning an int array to a slice of a double variable should throw.");
          }

          /// <summary>
          /// Repeated <c>assign(x + 1)</c> accumulates: 10 plus five increments is 15.
          /// </summary>
          [TestMethod]
          public void Accumulation()
          {
              var x = tf.Variable(10, name: "x");
              for (int i = 0; i < 5; i++)
              {
                  x.assign(x + 1);
              }
              Assert.AreEqual(15, (int)x.numpy());
          }

          /// <summary>
          /// <c>tf.negative</c> keeps the 1x2 shape and negates each element.
          /// </summary>
          [TestMethod]
          public void ShouldReturnNegative()
          {
              var x = tf.constant(new[,] { { 1, 2 } });
              var neg_x = tf.negative(x);
              Assert.IsTrue(Enumerable.SequenceEqual(new[] { 1, 2 }, neg_x.shape));
              Assert.IsTrue(Enumerable.SequenceEqual(new[] { -1, -2 }, neg_x.numpy().ToArray<int>()));
          }

          /// <summary>
          /// The result of <c>tf.identity(a)</c> keeps the value <c>a</c> had when identity was
          /// called; a later <c>assign_add</c> on <c>a</c> changes only <c>a</c>.
          /// </summary>
          [TestMethod]
          public void IdentityOriginalTensor()
          {
              var a = tf.Variable(5);
              var a_identity = tf.identity(a);
              a.assign_add(1);
              Assert.AreEqual(5, (int)a_identity.numpy());
              Assert.AreEqual(6, (int)a.numpy());
          }
      }
  }