GradientDescentOptimizerTests.cs 5.0 kB

using Microsoft.VisualStudio.TestPlatform.Utilities;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Diagnostics;
using System.Linq;
using Tensorflow.NumPy;
using TensorFlowNET.UnitTest;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.UnitTest.Optimizers
{
    [TestClass]
    public class GradientDescentOptimizerTest : PythonTest
    {
        private static TF_DataType GetTypeForNumericType<T>() where T : struct
        {
            return Type.GetTypeCode(typeof(T)) switch
            {
                TypeCode.Single => np.float32,
                TypeCode.Double => np.float64,
                _ => throw new NotImplementedException(),
            };
        }

        private void TestBasic<T>() where T : struct
        {
            var dtype = GetTypeForNumericType<T>();

            // train.GradientDescentOptimizer is V1 only API.
            tf.Graph().as_default();
            using (var sess = self.cached_session())
            {
                var var0 = tf.Variable(new[] { 1.0, 2.0 }, dtype: dtype);
                var var1 = tf.Variable(new[] { 3.0, 4.0 }, dtype: dtype);
                var grads0 = tf.constant(new[] { 0.1, 0.1 }, dtype: dtype);
                var grads1 = tf.constant(new[] { 0.01, 0.01 }, dtype: dtype);
                var optimizer = tf.train.GradientDescentOptimizer(3.0f);
                var grads_and_vars = new[] {
                    Tuple.Create(grads0, var0 as IVariableV1),
                    Tuple.Create(grads1, var1 as IVariableV1)
                };
                var sgd_op = optimizer.apply_gradients(grads_and_vars);

                var global_variables = tf.global_variables_initializer();
                sess.run(global_variables);

                var initialVar0 = sess.run(var0);
                var initialVar1 = sess.run(var1);

                // Fetch params to validate initial values
                self.assertAllCloseAccordingToType(new[] { 1.0, 2.0 }, self.evaluate<T[]>(var0));
                self.assertAllCloseAccordingToType(new[] { 3.0, 4.0 }, self.evaluate<T[]>(var1));

                // Run 1 step of sgd
                sgd_op.run();

                // Validate updated params
                self.assertAllCloseAccordingToType(
                    new[] { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 },
                    self.evaluate<T[]>(var0));
                self.assertAllCloseAccordingToType(
                    new[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 },
                    self.evaluate<T[]>(var1));
                // TODO: self.assertEqual(0, len(optimizer.variables()));
            }
        }

        [TestMethod]
        public void TestBasic()
        {
            //TODO: add np.half
            TestBasic<float>();
            TestBasic<double>();
        }

        private void TestTensorLearningRate<T>() where T : struct
        {
            var dtype = GetTypeForNumericType<T>();

            // train.GradientDescentOptimizer is V1 only API.
            tf.Graph().as_default();
            using (var sess = self.cached_session())
            {
                var var0 = tf.Variable(new[] { 1.0, 2.0 }, dtype: dtype);
                var var1 = tf.Variable(new[] { 3.0, 4.0 }, dtype: dtype);
                var grads0 = tf.constant(new[] { 0.1, 0.1 }, dtype: dtype);
                var grads1 = tf.constant(new[] { 0.01, 0.01 }, dtype: dtype);
                var lrate = constant_op.constant(3.0);
                var grads_and_vars = new[] {
                    Tuple.Create(grads0, var0 as IVariableV1),
                    Tuple.Create(grads1, var1 as IVariableV1)
                };
                var sgd_op = tf.train.GradientDescentOptimizer(lrate)
                    .apply_gradients(grads_and_vars);

                var global_variables = tf.global_variables_initializer();
                sess.run(global_variables);

                var initialVar0 = sess.run(var0);
                var initialVar1 = sess.run(var1);

                // Fetch params to validate initial values
                self.assertAllCloseAccordingToType(new[] { 1.0, 2.0 }, self.evaluate<T[]>(var0));
                self.assertAllCloseAccordingToType(new[] { 3.0, 4.0 }, self.evaluate<T[]>(var1));

                // Run 1 step of sgd
                sgd_op.run();

                // Validate updated params
                self.assertAllCloseAccordingToType(
                    new[] { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 },
                    self.evaluate<T[]>(var0));
                self.assertAllCloseAccordingToType(
                    new[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 },
                    self.evaluate<T[]>(var1));
                // TODO: self.assertEqual(0, len(optimizer.variables()));
            }
        }

        [TestMethod]
        public void TestTensorLearningRate()
        {
            //TODO: add np.half
            TestTensorLearningRate<float>();
            TestTensorLearningRate<double>();
        }
    }
}
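
For reference, the expected values in both tests follow from a single step of plain gradient descent, param_new = param - learning_rate * grad: with a learning rate of 3.0, var0 moves from [1.0, 2.0] to [0.7, 1.7] and var1 from [3.0, 4.0] to [2.97, 3.97], which is exactly what the assertAllCloseAccordingToType calls check. Below is a minimal standalone sketch of that arithmetic; it does not use the TensorFlow.NET API, and ApplySgdStep is a hypothetical helper introduced only for illustration.

using System;

static class SgdSketch
{
    // One step of vanilla SGD: param[i] - learningRate * grad[i] for each element.
    static double[] ApplySgdStep(double[] param, double[] grad, double learningRate)
    {
        var updated = new double[param.Length];
        for (int i = 0; i < param.Length; i++)
            updated[i] = param[i] - learningRate * grad[i];
        return updated;
    }

    static void Main()
    {
        // Same numbers as the tests: learning rate 3.0, gradients of 0.1 and 0.01.
        var var0 = ApplySgdStep(new[] { 1.0, 2.0 }, new[] { 0.1, 0.1 }, 3.0);   // ≈ [0.7, 1.7]
        var var1 = ApplySgdStep(new[] { 3.0, 4.0 }, new[] { 0.01, 0.01 }, 3.0); // ≈ [2.97, 3.97]
        Console.WriteLine(string.Join(", ", var0));
        Console.WriteLine(string.Join(", ", var1));
    }
}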