You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

VariableTest.cs 2.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System.Linq;
  3. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.Basics
  {
      /// <summary>
      /// Eager-mode tests covering creation, reading and assignment of <c>tf.Variable</c>
      /// instances, plus a basic <c>tf.negative</c> sanity check.
      /// </summary>
      [TestClass]
      public class VariableTest : EagerModeTestBase
      {
          /// <summary>A variable created from an int literal is a scalar holding that value.</summary>
          [TestMethod]
          public void NewVariable()
          {
              var x = tf.Variable(10, name: "x");

              Assert.AreEqual(0, x.shape.ndim);
              Assert.AreEqual(10, (int)x.numpy());
          }

          /// <summary>
          /// String variables can be created either with an explicit string dtype or with the
          /// dtype inferred from the initial value; both are scalars.
          /// </summary>
          [TestMethod]
          public void StringVar()
          {
              var mammal1 = tf.Variable("Elephant", name: "var1", dtype: tf.@string);
              var mammal2 = tf.Variable("Tiger");

              // Assert on the results so the test fails on wrong dtype/rank, not only on a throw.
              Assert.AreEqual(tf.@string, mammal1.dtype);
              Assert.AreEqual(tf.@string, mammal2.dtype);
              Assert.AreEqual(0, mammal1.shape.ndim);
              Assert.AreEqual(0, mammal2.shape.ndim);
          }

          /// <summary>A variable can be initialised from a tensor expression.</summary>
          [TestMethod]
          public void VarSum()
          {
              var x = tf.constant(3, name: "x");

              var y = tf.Variable(x + 1, name: "y");

              Assert.AreEqual(4, (int)y.numpy());
          }

          /// <summary>
          /// <c>assign</c> returns the new value, and the variable itself reflects the update.
          /// </summary>
          [TestMethod]
          public void Assign1()
          {
              var variable = tf.Variable(31, name: "tree");

              var unread = variable.assign(12);

              Assert.AreEqual(12, (int)unread.numpy());
              // In eager mode assign takes effect immediately (Accumulation relies on this too).
              Assert.AreEqual(12, (int)variable.numpy());
          }

          /// <summary>
          /// Assigning an expression that reads the variable yields the updated value both from
          /// the <c>assign</c> result and from the variable.
          /// </summary>
          [TestMethod]
          public void Assign2()
          {
              var v1 = tf.Variable(10.0f, name: "v1");

              var v2 = v1.assign(v1 + 1.0f);

              // Compare scalar values, not NDArray instances: NDArray object equality is not
              // guaranteed to be value-based, so AreEqual on the arrays could pass or fail spuriously.
              Assert.AreEqual(11f, (float)v2.numpy());
              Assert.AreEqual(11f, (float)v1.numpy());
          }

          /// <summary>Repeated self-referencing assigns accumulate in eager mode.</summary>
          [TestMethod]
          public void Accumulation()
          {
              var x = tf.Variable(10, name: "x");

              for (int step = 0; step < 5; step++)
              {
                  x.assign(x + 1);
              }

              Assert.AreEqual(15, (int)x.numpy());
          }

          /// <summary><c>tf.negative</c> preserves shape and negates each element.</summary>
          [TestMethod]
          public void ShouldReturnNegative()
          {
              var x = tf.constant(new[,] { { 1, 2 } });

              var negated = tf.negative(x);

              Assert.IsTrue(Enumerable.SequenceEqual(new[] { 1, 2 }, negated.shape));
              // CollectionAssert reports the mismatching element on failure, unlike IsTrue(SequenceEqual(...)).
              CollectionAssert.AreEqual(new[] { -1, -2 }, negated.numpy().ToArray<int>());
          }
      }
  }