You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

GradientTest.cs 2.6 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow;
  3. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.ManagedAPI
  {
      /// <summary>
      /// Eager-mode tests for <c>tf.GradientTape</c>: each test records a small
      /// computation on a tape and checks the forward value and/or the gradients.
      /// </summary>
      [TestClass]
      public class GradientTest
      {
          /// <summary>
          /// y = x² on a float32 variable at x = 3: expects y = 9 and dy/dx = 2x = 6.
          /// </summary>
          [TestMethod]
          public void GradientFloatTest()
          {
              var x = tf.Variable(3.0, dtype: tf.float32);
              using var tape = tf.GradientTape();
              var y = tf.square(x);
              var y_grad = tape.gradient(y, x);

              Assert.AreEqual(9.0f, (float)y);
              // d(x²)/dx = 2x, evaluated at x = 3.
              Assert.AreEqual(6.0f, (float)y_grad);
          }

          /// <summary>
          /// y = x² on a variable created without an explicit dtype at x = 3.
          /// The double casts assume the double literal yields a float64 variable.
          /// Expects y = 9 and dy/dx = 6.
          /// </summary>
          [TestMethod]
          public void GradientDefaultTest()
          {
              var x = tf.Variable(3.0);
              using var tape = tf.GradientTape();
              var y = tf.square(x);
              var y_grad = tape.gradient(y, x);

              Assert.AreEqual(9.0, (double)y);
              // d(x²)/dx = 2x, evaluated at x = 3.
              Assert.AreEqual(6.0, (double)y_grad);
          }

          /// <summary>
          /// y = x² on a float64 variable at x = 3: expects y = 9 and dy/dx = 6.
          /// </summary>
          [TestMethod]
          public void GradientDoubleTest()
          {
              var x = tf.Variable(3.0, dtype: tf.float64);
              using var tape = tf.GradientTape();
              var y = tf.square(x);
              var y_grad = tape.gradient(y, x);

              Assert.AreEqual(9.0, (double)y);
              // d(x²)/dx = 2x, evaluated at x = 3.
              Assert.AreEqual(6.0, (double)y_grad);
          }

          /// <summary>
          /// y = x * w with constant x = 0 and variable w = [1, 1] (operator overload path).
          /// dy/dw is x broadcast over w, i.e. [0, 0].
          /// </summary>
          [TestMethod]
          public void GradientOperatorMulTest()
          {
              var x = tf.constant(0f);
              var w = tf.Variable(new float[] { 1, 1 });
              using var gt = tf.GradientTape();
              var y = x * w;
              var gr = gt.gradient(y, w);

              // Assert.AreEqual(float[], NDArray) falls back to object.Equals, which is
              // reference equality for arrays; compare the elements instead.
              CollectionAssert.AreEqual(new float[] { 0, 0 }, gr.numpy().ToArray<float>());
          }

          /// <summary>
          /// Gradients flow back through tf.slice. The test only checks that both
          /// gradients (weight and bias) are produced.
          /// </summary>
          [TestMethod]
          public void GradientSliceTest()
          {
              var X = tf.zeros(new TensorShape(10));
              var W = tf.Variable(-0.06f, name: "weight");
              var b = tf.Variable(-0.73f, name: "bias");
              using var g = tf.GradientTape();
              var pred = W * X + b;
              var test = tf.slice(pred, new[] { 0 }, pred.shape);
              var gradients = g.gradient(test, (W, b));

              Assert.IsNotNull(gradients.Item1);
              Assert.IsNotNull(gradients.Item2);
          }

          /// <summary>
          /// Gradients flow back through tf.concat and element indexing.
          /// X = zeros(10) and pred = W*X + b over a 10-element target, so the gradients
          /// summed over the target are dW = sum(X) = 0 and db = 10.
          /// </summary>
          [TestMethod]
          public void GradientConcatTest()
          {
              var X = tf.zeros(new TensorShape(10));
              // Rank-1 variables: ConcatV2 cannot concatenate rank-0 (scalar) tensors.
              var W = tf.Variable(new[] { -0.06f }, name: "weight");
              var b = tf.Variable(new[] { -0.73f }, name: "bias");
              using var g = tf.GradientTape();
              // The concat must run while the tape is recording. Otherwise the tape has
              // no path from W and b to pred, and their gradients come back null.
              var test = tf.concat(new Tensor[] { W, b }, 0);
              var pred = test[0] * X + test[1];
              var gradients = g.gradient(pred, (W, b));

              CollectionAssert.AreEqual(new[] { 0f }, gradients.Item1.numpy().ToArray<float>());
              CollectionAssert.AreEqual(new[] { 10f }, gradients.Item2.numpy().ToArray<float>());
          }
      }
  }