You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

VariableTest.cs 2.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. using FluentAssertions;
  2. using Microsoft.VisualStudio.TestTools.UnitTesting;
  3. using NumSharp;
  4. using System.Linq;
  5. using Tensorflow;
  6. using static Tensorflow.Binding;
  7. namespace TensorFlowNET.UnitTest.Basics
  8. {
  9. [TestClass]
  10. public class VariableTest
  11. {
  12. [TestMethod]
  13. public void NewVariable()
  14. {
  15. var x = tf.Variable(10, name: "x");
  16. Assert.AreEqual(0, x.shape.ndim);
  17. Assert.AreEqual(10, (int)x.numpy());
  18. }
  19. [TestMethod]
  20. public void StringVar()
  21. {
  22. var mammal1 = tf.Variable("Elephant", name: "var1", dtype: tf.@string);
  23. var mammal2 = tf.Variable("Tiger");
  24. }
  25. [TestMethod]
  26. public void VarSum()
  27. {
  28. var x = tf.constant(3, name: "x");
  29. var y = tf.Variable(x + 1, name: "y");
  30. Assert.AreEqual(4, (int)y.numpy());
  31. }
  32. [TestMethod]
  33. public void Assign1()
  34. {
  35. var variable = tf.Variable(31, name: "tree");
  36. var unread = variable.assign(12);
  37. Assert.AreEqual(12, (int)unread.numpy());
  38. }
  39. [TestMethod]
  40. public void Assign2()
  41. {
  42. var v1 = tf.Variable(10.0f, name: "v1");
  43. var v2 = v1.assign(v1 + 1.0f);
  44. Assert.AreEqual(v1.numpy(), v2.numpy());
  45. Assert.AreEqual(11f, (float)v1.numpy());
  46. }
  47. [TestMethod]
  48. public void Accumulation()
  49. {
  50. var x = tf.Variable(10, name: "x");
  51. for (int i = 0; i < 5; i++)
  52. x.assign(x + 1);
  53. Assert.AreEqual(15, (int)x.numpy());
  54. }
  55. [TestMethod]
  56. public void ShouldReturnNegative()
  57. {
  58. var x = tf.constant(new[,] { { 1, 2 } });
  59. var neg_x = tf.negative(x);
  60. Assert.IsTrue(Enumerable.SequenceEqual(new[] { 1, 2 }, neg_x.shape));
  61. Assert.IsTrue(Enumerable.SequenceEqual(new[] { -1, -2 }, neg_x.numpy().ToArray<int>()));
  62. }
  63. }
  64. }