You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

VariableTest.cs 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using NumSharp;
  3. using System.Linq;
  4. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.Basics
  {
      /// <summary>
      /// Eager-mode tests for <c>tf.Variable</c>: creation, reading values back,
      /// whole-value and slice assignment, and a simple op on a constant.
      /// </summary>
      [TestClass]
      public class VariableTest : EagerModeTestBase
      {
          /// <summary>
          /// A variable created from a scalar int is rank 0 and reads back its initial value.
          /// </summary>
          [TestMethod]
          public void NewVariable()
          {
              var x = tf.Variable(10, name: "x");
              Assert.AreEqual(0, x.shape.ndim);
              Assert.AreEqual(10, (int)x.numpy());
          }

          /// <summary>
          /// String variables can be created with an explicit dtype or with the dtype
          /// inferred from the initial value.
          /// </summary>
          [TestMethod]
          public void StringVar()
          {
              var mammal1 = tf.Variable("Elephant", name: "var1", dtype: tf.@string);
              var mammal2 = tf.Variable("Tiger");

              // Check dtype and rank so this test can fail for a wrong result,
              // not only when variable construction throws.
              Assert.AreEqual(tf.@string, mammal1.dtype);
              Assert.AreEqual(0, mammal1.shape.ndim);
              Assert.AreEqual(tf.@string, mammal2.dtype);
              Assert.AreEqual(0, mammal2.shape.ndim);
          }

          /// <summary>
          /// A variable can be initialized from an expression over a constant tensor.
          /// </summary>
          [TestMethod]
          public void VarSum()
          {
              var x = tf.constant(3, name: "x");
              var y = tf.Variable(x + 1, name: "y");
              Assert.AreEqual(4, (int)y.numpy());
          }

          /// <summary>
          /// The value returned by <c>assign</c> holds the newly assigned value.
          /// </summary>
          [TestMethod]
          public void Assign1()
          {
              var variable = tf.Variable(31, name: "tree");
              var unread = variable.assign(12);
              Assert.AreEqual(12, (int)unread.numpy());
          }

          /// <summary>
          /// Assigning an expression that reads the variable itself updates the
          /// variable, and the returned value matches it.
          /// </summary>
          [TestMethod]
          public void Assign2()
          {
              var v1 = tf.Variable(10.0f, name: "v1");
              var v2 = v1.assign(v1 + 1.0f);
              Assert.AreEqual(v1.numpy(), v2.numpy());
              Assert.AreEqual(11f, (float)v1.numpy());
          }

          /// <summary>
          /// Assign tensor to slice of other tensor.
          /// https://www.tensorflow.org/api_docs/python/tf/Variable#__getitem__
          /// </summary>
          [TestMethod]
          public void SliceAssign()
          {
              NDArray nd = new float[,]
              {
                  { 1, 2, 3 },
                  { 4, 5, 6 },
                  { 7, 8, 9 }
              };
              var x = tf.Variable(nd);

              // Get a slice from the variable.
              var sliced = x[":2", ":2"];
              Assert.AreEqual(nd[0][":2"], sliced[0].numpy());
              Assert.AreEqual(nd[1][":2"], sliced[1].numpy());

              // Assign to the sliced tensor; the write must show up in the variable.
              sliced.assign(22 * tf.ones((2, 2)));

              // Check the assigned values.
              nd = new float[,]
              {
                  { 22, 22, 3 },
                  { 22, 22, 6 },
                  { 7, 8, 9 }
              };
              Assert.AreEqual(nd[0], x[0].numpy());
              Assert.AreEqual(nd[1], x[1].numpy());
              Assert.AreEqual(nd[2], x[2].numpy());
          }

          /// <summary>
          /// Assigning an int array to a slice of a double variable is expected to throw.
          /// </summary>
          /// <remarks>
          /// Ignored: currently the assignment exits without throwing any exception,
          /// and the test execution summary did not seem to report that. Keep this
          /// test ignored until that is resolved.
          /// </remarks>
          [TestMethod, Ignore]
          public void TypeMismatchedSliceAssign()
          {
              NDArray intNd = new int[]
              {
                  1, -2, 3
              };
              NDArray doubleNd = new double[]
              {
                  -5, 6, -7
              };
              var x = tf.Variable(doubleNd);
              var slice = x[":"];

              // Assert.ThrowsException<T> only matches the exact type T, so a derived
              // exception would fail it. Accept any exception instead.
              var threw = false;
              try
              {
                  slice.assign(intNd);
              }
              catch (System.Exception)
              {
                  threw = true;
              }
              Assert.IsTrue(threw, "Assigning an int array to a double variable slice should throw.");
          }

          /// <summary>
          /// Repeatedly assigning <c>x + 1</c> accumulates into the variable.
          /// </summary>
          [TestMethod]
          public void Accumulation()
          {
              var x = tf.Variable(10, name: "x");
              for (int i = 0; i < 5; i++)
              {
                  x.assign(x + 1);
              }
              Assert.AreEqual(15, (int)x.numpy());
          }

          /// <summary>
          /// <c>tf.negative</c> keeps the input shape and negates every element.
          /// </summary>
          [TestMethod]
          public void ShouldReturnNegative()
          {
              var x = tf.constant(new[,] { { 1, 2 } });
              var neg_x = tf.negative(x);
              Assert.IsTrue(Enumerable.SequenceEqual(new[] { 1, 2 }, neg_x.shape));
              Assert.IsTrue(Enumerable.SequenceEqual(new[] { -1, -2 }, neg_x.numpy().ToArray<int>()));
          }
      }
  }