GradientDescentOptimizerTests.cs

using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;
using Tensorflow.NumPy;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.Training
{
    [TestClass]
    public class GradientDescentOptimizerTest : PythonTest
    {
        private static TF_DataType GetTypeForNumericType<T>() where T : struct
        {
            return Type.GetTypeCode(typeof(T)) switch
            {
                TypeCode.Single => np.float32,
                TypeCode.Double => np.float64,
                _ => throw new NotImplementedException(),
            };
        }

        private void TestBasic<T>() where T : struct
        {
            var dtype = GetTypeForNumericType<T>();

            // train.GradientDescentOptimizer is V1 only API.
            tf.Graph().as_default();
            using (var sess = self.cached_session())
            {
                var var0 = tf.Variable(new[] { 1.0, 2.0 }, dtype: dtype);
                var var1 = tf.Variable(new[] { 3.0, 4.0 }, dtype: dtype);
                var grads0 = tf.constant(new[] { 0.1, 0.1 }, dtype: dtype);
                var grads1 = tf.constant(new[] { 0.01, 0.01 }, dtype: dtype);
                var optimizer = tf.train.GradientDescentOptimizer(3.0f);
                var grads_and_vars = new[]
                {
                    Tuple.Create(grads0, var0 as IVariableV1),
                    Tuple.Create(grads1, var1 as IVariableV1)
                };
                var sgd_op = optimizer.apply_gradients(grads_and_vars);

                var global_variables = tf.global_variables_initializer();
                sess.run(global_variables);

                var initialVar0 = sess.run(var0);
                var initialVar1 = sess.run(var1);

                // Fetch params to validate initial values
                self.assertAllCloseAccordingToType(new[] { 1.0, 2.0 }, self.evaluate<T[]>(var0));
                self.assertAllCloseAccordingToType(new[] { 3.0, 4.0 }, self.evaluate<T[]>(var1));

                // Run 1 step of sgd
                sgd_op.run();

                // Validate updated params
                self.assertAllCloseAccordingToType(
                    new[] { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 },
                    self.evaluate<T[]>(var0));
                self.assertAllCloseAccordingToType(
                    new[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 },
                    self.evaluate<T[]>(var1));

                // TODO: self.assertEqual(0, len(optimizer.variables()));
            }
        }

        [TestMethod]
        public void TestBasic()
        {
            // TODO: add np.half
            TestBasic<float>();
            TestBasic<double>();
        }

        private void TestTensorLearningRate<T>() where T : struct
        {
            var dtype = GetTypeForNumericType<T>();

            // train.GradientDescentOptimizer is V1 only API.
            tf.Graph().as_default();
            using (var sess = self.cached_session())
            {
                var var0 = tf.Variable(new[] { 1.0, 2.0 }, dtype: dtype);
                var var1 = tf.Variable(new[] { 3.0, 4.0 }, dtype: dtype);
                var grads0 = tf.constant(new[] { 0.1, 0.1 }, dtype: dtype);
                var grads1 = tf.constant(new[] { 0.01, 0.01 }, dtype: dtype);
                var lrate = constant_op.constant(3.0);
                var grads_and_vars = new[]
                {
                    Tuple.Create(grads0, var0 as IVariableV1),
                    Tuple.Create(grads1, var1 as IVariableV1)
                };
                var sgd_op = tf.train.GradientDescentOptimizer(lrate)
                    .apply_gradients(grads_and_vars);

                var global_variables = tf.global_variables_initializer();
                sess.run(global_variables);

                var initialVar0 = sess.run(var0);
                var initialVar1 = sess.run(var1);

                // Fetch params to validate initial values
                self.assertAllCloseAccordingToType(new[] { 1.0, 2.0 }, self.evaluate<T[]>(var0));
                self.assertAllCloseAccordingToType(new[] { 3.0, 4.0 }, self.evaluate<T[]>(var1));

                // Run 1 step of sgd
                sgd_op.run();

                // Validate updated params
                self.assertAllCloseAccordingToType(
                    new[] { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 },
                    self.evaluate<T[]>(var0));
                self.assertAllCloseAccordingToType(
                    new[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 },
                    self.evaluate<T[]>(var1));

                // TODO: self.assertEqual(0, len(optimizer.variables()));
            }
        }

        [TestMethod]
        public void TestTensorLearningRate()
        {
            // TODO: add np.half
            TestTensorLearningRate<float>();
            TestTensorLearningRate<double>();
        }
    }
}
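
For reference, the expected values asserted above follow the plain gradient-descent update var <- var - learning_rate * grad. The standalone sketch below (plain C#, no TensorFlow.NET dependency; the fixture values are copied from TestBasic purely for illustration and are not part of the test file) reproduces that arithmetic for var0:

    using System;

    class SgdStepSketch
    {
        static void Main()
        {
            // One gradient-descent step: var <- var - learning_rate * grad.
            // Values mirror var0/grads0 and the 3.0 learning rate used in TestBasic.
            double learningRate = 3.0;
            double[] var0 = { 1.0, 2.0 };
            double[] grads0 = { 0.1, 0.1 };

            for (int i = 0; i < var0.Length; i++)
                var0[i] -= learningRate * grads0[i];

            // Expected: { 0.7, 1.7 } (up to floating-point rounding), i.e.
            // { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 }, matching the test's assertion.
            Console.WriteLine(string.Join(", ", var0));
        }
    }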