You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

VariableTest.cs 2.1 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. using FluentAssertions;
  2. using Microsoft.VisualStudio.TestTools.UnitTesting;
  3. using NumSharp;
  4. using System.Linq;
  5. using Tensorflow;
  6. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.Basics
  {
      /// <summary>
      /// Unit tests for creating, reading, and assigning TensorFlow.NET variables
      /// created through <c>tf.Variable</c>.
      /// </summary>
      [TestClass]
      public class VariableTest
      {
          /// <summary>
          /// A scalar int variable reports the given name as "x:0", has rank 0,
          /// and reads back its initial value.
          /// </summary>
          [TestMethod]
          public void NewVariable()
          {
              var x = tf.Variable(10, name: "x");
              Assert.AreEqual("x:0", x.name);
              Assert.AreEqual(0, x.shape.ndim);
              Assert.AreEqual(10, (int)x.numpy());
          }

          /// <summary>
          /// String variables can be created both with an explicit string dtype and
          /// with the dtype inferred from a string literal.
          /// </summary>
          [TestMethod]
          public void StringVar()
          {
              var mammal1 = tf.Variable("Elephant", name: "var1", dtype: tf.@string);
              var mammal2 = tf.Variable("Tiger");

              // Assert on the variables' metadata so this test can actually fail;
              // the name check follows the same "<name>:0" form asserted in NewVariable.
              Assert.AreEqual("var1:0", mammal1.name);
              Assert.AreEqual(tf.@string, mammal1.dtype);
              Assert.AreEqual(tf.@string, mammal2.dtype);
          }

          /// <summary>
          /// A variable can be initialised from the result of a tensor expression
          /// (constant 3 plus 1 yields 4).
          /// </summary>
          [TestMethod]
          public void VarSum()
          {
              var x = tf.constant(3, name: "x");
              var y = tf.Variable(x + 1, name: "y");
              Assert.AreEqual(4, (int)y.numpy());
          }

          /// <summary>
          /// The value returned by <c>assign</c> reads back as the newly assigned value.
          /// </summary>
          [TestMethod]
          public void Assign1()
          {
              var variable = tf.Variable(31, name: "tree");
              var unread = variable.assign(12);
              Assert.AreEqual(12, (int)unread.numpy());
          }

          /// <summary>
          /// Assigning an expression that reads the variable itself updates the variable,
          /// and the value returned by <c>assign</c> matches the variable afterwards.
          /// </summary>
          [TestMethod]
          public void Assign2()
          {
              var v1 = tf.Variable(10.0f, name: "v1");
              var v2 = v1.assign(v1 + 1.0f);

              // Compare scalar values, not NDArray objects: Assert.AreEqual(object, object)
              // would defer to NumSharp's NDArray.Equals, which is not guaranteed to be an
              // element-wise value comparison.
              Assert.AreEqual((float)v1.numpy(), (float)v2.numpy());
              // 10 + 1 is exactly representable as a float, so exact equality is safe here.
              Assert.AreEqual(11f, (float)v1.numpy());
          }

          /// <summary>
          /// Adding 1 five times, reassigning the result to <c>x</c> each time,
          /// starting from 10 reads back as 15.
          /// </summary>
          [TestMethod]
          public void Accumulation()
          {
              var x = tf.Variable(10, name: "x");
              // NOTE(review): `x = x + 1` rebinds x to whatever TF.NET's overloaded `+`
              // returns; whether the original variable object is also updated in place
              // is not visible from this file — confirm against the operator definition.
              for (int i = 0; i < 5; i++)
              {
                  x = x + 1;
              }
              Assert.AreEqual(15, (int)x.numpy());
          }

          /// <summary>
          /// <c>tf.negative</c> keeps the input shape [1, 2] and negates every element.
          /// </summary>
          [TestMethod]
          public void ShouldReturnNegative()
          {
              var x = tf.constant(new[,] { { 1, 2 } });
              var neg_x = tf.negative(x);

              // CollectionAssert reports the first mismatching index on failure, whereas
              // Assert.IsTrue(SequenceEqual(...)) only reports that the condition was false.
              CollectionAssert.AreEqual(new[] { 1, 2 }, neg_x.shape.ToArray());
              CollectionAssert.AreEqual(new[] { -1, -2 }, neg_x.numpy().ToArray<int>());
          }
      }
  }