You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or number, can include dashes ('-') and can be up to 35 characters long.

GradientTapeTest.cs 3.3 kB

Performance optimization, refactoring and revamping. (#362) * Refactored DisposableObject * Added different build directory for TensorflowNET.Examples.GPU * _FetchHandler: Switched to NPTypeCode * gfile.cs, Walk(...): Handle case when directory top doesn't exist. * Tensor.Creation: Perf-opted when creating tensor from NDArray of string * Graph.cs: refactor and added docs * Tensor.Creation.cs: perf-ops * Tensor.Explicit.cs: perf-ops * Copied globals.regen from NumSharp - Added supported_numericals_TF_DataType * Tensor perf-ops and cleanup, Revamped dtypes.cs, some renames. - Cleanup and docs to all Tensor.cs files - Changed all uses of System.Convert to NumSharp.Utilities.Converts - Added all missing types in dtypes.cs - Renamed tensor.Data<T> to tensor.ToArray<T>, added obsolete message - Renamed tensor.Data() to tensor.BufferToArray(), added obsolete message - Made GraphKeys to use const string instead allocating strings at every use of GraphKeys. * Tensor: Added guards for explicit casts. * Tensor: Added explicit cast to string * Tensor.ToArray<T>(): Added support for cases when tensor is scalar. * Tensor.BufferToArray(): Fixed to use long instead of int. * TensorShape: Revamped and documented. * BaseSession: Added Session.run(ITensorOrOperation fetche, params FeedItem[] feed_dict) * Tensor: renamed _dtype to _override_dtype - Fixed all locations _dtype is used incorrectly. * Fixed unit tests * Tensor.Operations: Reverted commit * DisposableObject: sorted internal_dispose to properly handle Dispose() calls * Tensor.DisposeUnmanagedResources: Nullify _handle after delete. * TensorShape.this[...]: fixed guard check. * DisposableObject #362
6 years ago
Performance optimization, refactoring and revamping. (#362) * Refactored DisposableObject * Added different build directory for TensorflowNET.Examples.GPU * _FetchHandler: Switched to NPTypeCode * gfile.cs, Walk(...): Handle case when directory top doesn't exist. * Tensor.Creation: Perf-opted when creating tensor from NDArray of string * Graph.cs: refactor and added docs * Tensor.Creation.cs: perf-ops * Tensor.Explicit.cs: perf-ops * Copied globals.regen from NumSharp - Added supported_numericals_TF_DataType * Tensor perf-ops and cleanup, Revamped dtypes.cs, some renames. - Cleanup and docs to all Tensor.cs files - Changed all uses of System.Convert to NumSharp.Utilities.Converts - Added all missing types in dtypes.cs - Renamed tensor.Data<T> to tensor.ToArray<T>, added obsolete message - Renamed tensor.Data() to tensor.BufferToArray(), added obsolete message - Made GraphKeys to use const string instead allocating strings at every use of GraphKeys. * Tensor: Added guards for explicit casts. * Tensor: Added explicit cast to string * Tensor.ToArray<T>(): Added support for cases when tensor is scalar. * Tensor.BufferToArray(): Fixed to use long instead of int. * TensorShape: Revamped and documented. * BaseSession: Added Session.run(ITensorOrOperation fetche, params FeedItem[] feed_dict) * Tensor: renamed _dtype to _override_dtype - Fixed all locations _dtype is used incorrectly. * Fixed unit tests * Tensor.Operations: Reverted commit * DisposableObject: sorted internal_dispose to properly handle Dispose() calls * Tensor.DisposeUnmanagedResources: Nullify _handle after delete. * TensorShape.this[...]: fixed guard check. * DisposableObject #362
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using NumSharp;
  3. using System.Linq;
  4. using Tensorflow;
  5. using static Tensorflow.Binding;
  6. namespace TensorFlowNET.UnitTest.Gradient
  7. {
  8. [TestClass]
  9. public class GradientTapeTest
  10. {
  11. [TestMethod]
  12. public void GradientTape()
  13. {
  14. var x = tf.ones((2, 2));
  15. using (var t = tf.GradientTape())
  16. {
  17. t.watch(x);
  18. }
  19. }
  20. [TestMethod]
  21. public void Gradients()
  22. {
  23. var a = tf.constant(0.0);
  24. var b = 2.0 * a;
  25. //Assert.AreEqual(b.name, "mul:0");
  26. //Assert.AreEqual(b.op.inputs[0].name, "mul/x:0");
  27. //Assert.AreEqual(b.op.inputs[1].name, "Const:0");
  28. var ys = a + b;
  29. //Assert.AreEqual(ys.name, "add:0");
  30. //Assert.AreEqual(ys.op.inputs[0].name, "Const:0");
  31. //Assert.AreEqual(ys.op.inputs[1].name, "mul:0");
  32. //var g = tf.gradients(ys, new Tensor[] { a, b }, stop_gradients: new Tensor[] { a, b });
  33. //Assert.AreEqual(g[0].name, "gradients/Fill:0");
  34. //Assert.AreEqual(g[1].name, "gradients/Fill:0");
  35. }
  36. [TestMethod]
  37. public void Gradient2x()
  38. {
  39. var x = tf.constant(7.0f);
  40. var y = x * x * tf.constant(0.1f);
  41. //var grad = tf.gradients(y, x);
  42. //Assert.AreEqual(grad[0].name, "gradients/AddN:0");
  43. //float r = sess.run(grad[0]);
  44. //Assert.AreEqual(r, 1.4f);
  45. }
  46. [TestMethod]
  47. public void Gradient3x()
  48. {
  49. var graph = tf.Graph().as_default();
  50. tf_with(tf.Session(graph), sess => {
  51. var x = tf.constant(7.0f);
  52. var y = x * x * x * tf.constant(0.1f);
  53. var grad = tf.gradients(y, x);
  54. Assert.AreEqual(grad[0].name, "gradients/AddN:0");
  55. float r = sess.run(grad[0]);
  56. Assert.AreEqual(r, 14.700001f);
  57. });
  58. }
  59. [TestMethod]
  60. public void StridedSlice()
  61. {
  62. var graph = tf.Graph().as_default();
  63. var t = tf.constant(np.array(new int[,,]
  64. {
  65. {
  66. { 11, 12, 13 },
  67. { 21, 22, 23 }
  68. },
  69. {
  70. { 31, 32, 33 },
  71. { 41, 42, 43 }
  72. },
  73. {
  74. { 51, 52, 53 },
  75. { 61, 62, 63 }
  76. }
  77. }));
  78. var slice = tf.strided_slice(t,
  79. begin: new[] { 0, 0, 0 },
  80. end: new[] { 3, 2, 3 },
  81. strides: new[] { 2, 2, 2 });
  82. var y = slice + slice;
  83. var g = tf.gradients(y, new Tensor[] { slice, slice });
  84. using (var sess = tf.Session(graph))
  85. {
  86. var r = sess.run(slice);
  87. Assert.IsTrue(Enumerable.SequenceEqual(r.shape, new[] { 2, 1, 2 }));
  88. Assert.IsTrue(Enumerable.SequenceEqual(r[0].GetData<int>(), new[] { 11, 13 }));
  89. Assert.IsTrue(Enumerable.SequenceEqual(r[1].GetData<int>(), new[] { 51, 53 }));
  90. }
  91. }
  92. }
  93. }