You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

VariableTest.cs 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow.NumPy;
  3. using System.Linq;
  4. using static Tensorflow.Binding;
  5. using System;
  6. namespace TensorFlowNET.UnitTest.Basics
  7. {
  8. [TestClass]
  9. public class VariableTest : EagerModeTestBase
  10. {
  11. [TestMethod]
  12. public void NewVariable()
  13. {
  14. var x = tf.Variable(10, name: "x");
  15. Assert.AreEqual(0, x.shape.ndim);
  16. Assert.AreEqual(x.numpy(), 10);
  17. }
  18. [TestMethod]
  19. public void StringVar()
  20. {
  21. var mammal1 = tf.Variable("Elephant", name: "var1", dtype: tf.@string);
  22. var mammal2 = tf.Variable("Tiger");
  23. }
  24. [TestMethod]
  25. public void VarSum()
  26. {
  27. var x = tf.constant(3, name: "x");
  28. var y = tf.Variable(x + 1, name: "y");
  29. Assert.AreEqual(y.numpy(), 4);
  30. }
  31. [TestMethod]
  32. public void Assign1()
  33. {
  34. var variable = tf.Variable(31, name: "tree");
  35. var unread = variable.assign(12);
  36. Assert.AreEqual(unread.numpy(), 12);
  37. }
  38. [TestMethod]
  39. public void Assign2()
  40. {
  41. var v1 = tf.Variable(10.0f, name: "v1");
  42. var v2 = v1.assign(v1 + 1.0f);
  43. Assert.AreEqual(v1.numpy(), v2.numpy());
  44. Assert.AreEqual(v1.numpy(), 11f);
  45. }
  46. [TestMethod]
  47. public void Assign3()
  48. {
  49. var v1 = tf.Variable(10.0f, name: "v1");
  50. var v2 = tf.Variable(v1, name: "v2");
  51. Assert.AreEqual(v1.numpy(), v2.numpy());
  52. v1.assign(30.0f);
  53. Assert.AreNotEqual(v1.numpy(), v2.numpy());
  54. }
  55. /// <summary>
  56. /// Assign tensor to slice of other tensor.
  57. /// https://www.tensorflow.org/api_docs/python/tf/Variable#__getitem__
  58. /// </summary>
  59. [TestMethod]
  60. public void SliceAssign()
  61. {
  62. NDArray nd = new float[,]
  63. {
  64. { 1, 2, 3 },
  65. { 4, 5, 6 },
  66. { 7, 8, 9 }
  67. };
  68. var x = tf.Variable(nd);
  69. // get slice form variable
  70. var sliced = x[":2", ":2"];
  71. Assert.AreEqual(nd[0][":2"], sliced[0].numpy());
  72. Assert.AreEqual(nd[1][":2"], sliced[1].numpy());
  73. // assign to the sliced tensor
  74. sliced.assign(22 * tf.ones((2, 2)));
  75. // test assigned value
  76. nd = new float[,]
  77. {
  78. { 22, 22, 3 },
  79. { 22, 22, 6 },
  80. { 7, 8, 9 }
  81. };
  82. Assert.AreEqual(nd[0], x[0].numpy());
  83. Assert.AreEqual(nd[1], x[1].numpy());
  84. Assert.AreEqual(nd[2], x[2].numpy());
  85. }
  86. [TestMethod]
  87. [ExpectedException(typeof(ArrayTypeMismatchException))]
  88. public void TypeMismatchedSliceAssign()
  89. {
  90. NDArray intNd = new int[]
  91. {
  92. 1, -2, 3
  93. };
  94. NDArray doubleNd = new double[]
  95. {
  96. -5, 6, -7
  97. };
  98. var x = tf.Variable(doubleNd);
  99. x[":"].assign(intNd);
  100. }
  101. [TestMethod]
  102. public void Accumulation()
  103. {
  104. var x = tf.Variable(10, name: "x");
  105. for (int i = 0; i < 5; i++)
  106. x.assign(x + 1);
  107. Assert.AreEqual(x.numpy(), 15);
  108. }
  109. [TestMethod]
  110. public void ShouldReturnNegative()
  111. {
  112. var x = tf.constant(new[,] { { 1, 2 } });
  113. var neg_x = tf.negative(x);
  114. Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 1, 2 }, neg_x.shape.dims));
  115. Assert.IsTrue(Enumerable.SequenceEqual(new[] { -1, -2 }, neg_x.numpy().ToArray<int>()));
  116. }
  117. [TestMethod]
  118. public void IdentityOriginalTensor()
  119. {
  120. var a = tf.Variable(5);
  121. var a_identity = tf.identity(a);
  122. a.assign_add(1);
  123. Assert.AreEqual(a_identity.numpy(), 5);
  124. Assert.AreEqual(a.numpy(), 6);
  125. }
  126. }
  127. }