You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

GradientTest.cs 1.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.ManagedAPI
  {
      /// <summary>
      /// Tests for <c>tf.GradientTape</c>: each test records an operation on a
      /// variable and checks the gradient returned by <c>tape.gradient</c>.
      /// </summary>
      [TestClass]
      public class GradientTest
      {
          /// <summary>
          /// Squares a float32 variable holding 3 and checks both the forward
          /// value (9) and the gradient d(x^2)/dx = 2x (6).
          /// </summary>
          [TestMethod]
          public void GradientFloatTest()
          {
              var x = tf.Variable(3.0, dtype: tf.float32);
              using var tape = tf.GradientTape();

              var y = tf.square(x);
              var yGrad = tape.gradient(y, x);

              Assert.AreEqual(9.0f, (float)y, 1e-6f);
              // Previously the gradient was computed but never verified.
              Assert.AreEqual(6.0f, (float)yGrad, 1e-6f);
          }

          /// <summary>
          /// Same as <see cref="GradientFloatTest"/>, but lets
          /// <c>tf.Variable</c> pick the dtype from the double literal.
          /// </summary>
          [TestMethod]
          public void GradientDefaultTest()
          {
              var x = tf.Variable(3.0);
              using var tape = tf.GradientTape();

              var y = tf.square(x);
              var yGrad = tape.gradient(y, x);

              Assert.AreEqual(9.0, (double)y, 1e-12);
              // Previously the gradient was computed but never verified.
              Assert.AreEqual(6.0, (double)yGrad, 1e-12);
          }

          /// <summary>
          /// Same as <see cref="GradientFloatTest"/>, but with an explicit
          /// float64 variable.
          /// </summary>
          [TestMethod]
          public void GradientDoubleTest()
          {
              var x = tf.Variable(3.0, dtype: tf.float64);
              using var tape = tf.GradientTape();

              var y = tf.square(x);
              var yGrad = tape.gradient(y, x);

              Assert.AreEqual(9.0, (double)y, 1e-12);
              // Previously the gradient was computed but never verified.
              Assert.AreEqual(6.0, (double)yGrad, 1e-12);
          }

          /// <summary>
          /// Multiplies a constant 0 by a two-element variable via the
          /// <c>*</c> operator and checks the gradient with respect to the
          /// variable.
          /// </summary>
          [TestMethod]
          public void GradientOperatorMulTest()
          {
              var x = tf.constant(0f);
              var w = tf.Variable(new float[] { 1, 1 });
              using var gt = tf.GradientTape();

              var y = x * w;
              var gr = gt.gradient(y, w);

              // d(x * w)/dw equals x for every component of w, and x is 0.
              Assert.AreEqual(new float[] { 0, 0 }, gr.numpy());
          }
      }
  }