You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

GradientTest.cs 2.1 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using NumSharp;
  3. using System.Linq;
  4. using Tensorflow;
  5. namespace TensorFlowNET.UnitTest
  6. {
  /// <summary>
  /// Graph-mode tests that build small graphs, check the generated tensor
  /// names, and call <c>tf.gradients</c> on the results.
  /// </summary>
  [TestClass]
  public class GradientTest
  {
      /// <summary>
      /// Builds <c>ys = a + 2.0 * a</c> and checks the names assigned to the
      /// multiply/add tensors and their inputs, then checks the names of the
      /// gradient tensors returned when both <c>a</c> and <c>b</c> are passed
      /// as <c>stop_gradients</c>.
      /// </summary>
      [TestMethod]
      public void Gradients()
      {
          // Make a fresh graph the default so the names asserted below are not
          // affected by ops created elsewhere. The returned value is not needed.
          tf.Graph().as_default();

          var a = tf.constant(0.0);
          var b = 2.0 * a;

          // MSTest's signature is AreEqual(expected, actual); keeping that order
          // makes failure messages report the right value as "expected".
          Assert.AreEqual("mul:0", b.name);
          Assert.AreEqual("mul/x:0", b.op.inputs[0].name);
          Assert.AreEqual("Const:0", b.op.inputs[1].name);

          var ys = a + b;
          Assert.AreEqual("add:0", ys.name);
          Assert.AreEqual("Const:0", ys.op.inputs[0].name);
          Assert.AreEqual("mul:0", ys.op.inputs[1].name);

          var stopAt = new Tensor[] { a, b };
          var grads = tf.gradients(ys, new Tensor[] { a, b }, stop_gradients: stopAt);

          // Both returned gradients are expected to carry the same tensor name.
          Assert.AreEqual("gradients/Fill:0", grads[0].name);
          Assert.AreEqual("gradients/Fill:0", grads[1].name);
      }

      /// <summary>
      /// Slices a 3x2x3 integer constant with stride 2 on every axis, calls
      /// <c>tf.gradients</c> on <c>slice + slice</c>, and checks the evaluated
      /// slice's shape and contents.
      /// </summary>
      [TestMethod]
      public void StridedSlice()
      {
          // Make a fresh graph the default; the returned value is not needed.
          tf.Graph().as_default();

          var source = np.array(new int[,,]
          {
              {
                  { 11, 12, 13 },
                  { 21, 22, 23 }
              },
              {
                  { 31, 32, 33 },
                  { 41, 42, 43 }
              },
              {
                  { 51, 52, 53 },
                  { 61, 62, 63 }
              }
          });
          var t = tf.constant(source);

          var slice = tf.strided_slice(t,
              begin: new[] { 0, 0, 0 },
              end: new[] { 3, 2, 3 },
              strides: new[] { 2, 2, 2 });
          var y = slice + slice;

          // The result is not inspected; the call is kept so the test still
          // fails if building the gradient graph throws.
          tf.gradients(y, new Tensor[] { slice, slice });

          var r = slice.eval();

          Assert.IsTrue(
              Enumerable.SequenceEqual(r.shape, new[] { 2, 1, 2 }),
              "Unexpected shape for the strided slice result.");
          Assert.IsTrue(
              Enumerable.SequenceEqual(r[0].GetData<int>(), new[] { 11, 13 }),
              "Unexpected values in the first block of the slice.");
          Assert.IsTrue(
              Enumerable.SequenceEqual(r[1].GetData<int>(), new[] { 51, 53 }),
              "Unexpected values in the second block of the slice.");
      }
  }
  58. }