You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

GradientTest.cs 3.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using NumSharp;
  3. using System.Linq;
  4. using Tensorflow;
  5. using static Tensorflow.Python;
  6. namespace TensorFlowNET.UnitTest
  7. {
  /// <summary>
  /// Unit tests for <c>tf.gradients</c>: checks the names of the generated
  /// gradient tensors, the numeric value of simple polynomial gradients, and
  /// that gradients can be requested for a strided-slice result.
  /// </summary>
  [TestClass]
  public class GradientTest
  {
      /// <summary>
      /// Absolute tolerance used when comparing single-precision gradient values,
      /// so the tests do not depend on bit-exact floating-point results.
      /// </summary>
      private const float Tolerance = 1e-5f;

      /// <summary>
      /// Builds <c>ys = a + 2.0 * a</c> and verifies the names of the created ops,
      /// their inputs, and the tensors returned by <c>tf.gradients</c> when both
      /// <c>a</c> and <c>b</c> are also passed as <c>stop_gradients</c>.
      /// </summary>
      [TestMethod]
      public void Gradients()
      {
          // A fresh default graph keeps the auto-generated op names asserted below predictable.
          tf.Graph().as_default();

          var a = tf.constant(0.0);
          var b = 2.0 * a;
          Assert.AreEqual("mul:0", b.name);
          Assert.AreEqual("mul/x:0", b.op.inputs[0].name);
          Assert.AreEqual("Const:0", b.op.inputs[1].name);

          var ys = a + b;
          Assert.AreEqual("add:0", ys.name);
          Assert.AreEqual("Const:0", ys.op.inputs[0].name);
          Assert.AreEqual("mul:0", ys.op.inputs[1].name);

          var g = tf.gradients(ys, new Tensor[] { a, b }, stop_gradients: new Tensor[] { a, b });

          // Both requested gradients are expected to be the same tensor.
          Assert.AreEqual("gradients/Fill:0", g[0].name);
          Assert.AreEqual("gradients/Fill:0", g[1].name);
      }

      /// <summary>
      /// Verifies the gradient of <c>y = 0.1 * x * x</c> at <c>x = 7</c>,
      /// which is <c>0.2 * 7 = 1.4</c>.
      /// </summary>
      [TestMethod]
      public void Gradient2x()
      {
          var graph = tf.Graph().as_default();
          using (var sess = tf.Session(graph))
          {
              var x = tf.constant(7.0f);
              var y = x * x * tf.constant(0.1f);

              var grad = tf.gradients(y, x);
              Assert.AreEqual("gradients/AddN:0", grad[0].name);

              float r = sess.run(grad[0])[0];
              Assert.AreEqual(1.4f, r, Tolerance);
          }
      }

      /// <summary>
      /// Verifies the gradient of <c>y = 0.1 * x * x * x</c> at <c>x = 7</c>,
      /// which is <c>0.3 * 49 = 14.7</c>.
      /// </summary>
      [TestMethod]
      public void Gradient3x()
      {
          var graph = tf.Graph().as_default();
          tf_with(tf.Session(graph), sess =>
          {
              var x = tf.constant(7.0f);
              var y = x * x * x * tf.constant(0.1f);

              var grad = tf.gradients(y, x);
              Assert.AreEqual("gradients/AddN:0", grad[0].name);

              float r = sess.run(grad[0])[0];
              Assert.AreEqual(14.7f, r, Tolerance);
          });
      }

      /// <summary>
      /// Slices a 3x2x3 integer constant with stride 2 on every axis, requests
      /// gradients of <c>slice + slice</c> with respect to the slice, and verifies
      /// the shape and contents of the evaluated slice.
      /// </summary>
      [TestMethod]
      public void StridedSlice()
      {
          var graph = tf.Graph().as_default();

          var t = tf.constant(np.array(new int[,,]
          {
              {
                  { 11, 12, 13 },
                  { 21, 22, 23 }
              },
              {
                  { 31, 32, 33 },
                  { 41, 42, 43 }
              },
              {
                  { 51, 52, 53 },
                  { 61, 62, 63 }
              }
          }));

          var slice = tf.strided_slice(t,
              begin: new[] { 0, 0, 0 },
              end: new[] { 3, 2, 3 },
              strides: new[] { 2, 2, 2 });

          var y = slice + slice;

          // The result is intentionally not asserted on: the call is kept so the
          // test still fails if building the gradient ops for a strided slice throws.
          var g = tf.gradients(y, new Tensor[] { slice, slice });

          using (var sess = tf.Session(graph))
          {
              var r = sess.run(slice)[0];

              // Stride 2 keeps indices {0, 2} on axes 0 and 2, and index {0} on axis 1.
              Assert.IsTrue(Enumerable.SequenceEqual(r.shape, new[] { 2, 1, 2 }));
              Assert.IsTrue(Enumerable.SequenceEqual(r[0].GetData<int>(), new[] { 11, 13 }));
              Assert.IsTrue(Enumerable.SequenceEqual(r[1].GetData<int>(), new[] { 51, 53 }));
          }
      }
  }
  89. }