You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

GradientDescentOptimizerTests.cs 2.7 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Linq;
  4. using System.Runtime.Intrinsics.X86;
  5. using System.Security.AccessControl;
  6. using Tensorflow.NumPy;
  7. using TensorFlowNET.UnitTest;
  8. using static Tensorflow.Binding;
  9. namespace Tensorflow.Keras.UnitTest.Optimizers
  10. {
  11. [TestClass]
  12. public class GradientDescentOptimizerTest : PythonTest
  13. {
  14. private void TestBasicGeneric<T>() where T : struct
  15. {
  16. var dtype = Type.GetTypeCode(typeof(T)) switch
  17. {
  18. TypeCode.Single => np.float32,
  19. TypeCode.Double => np.float64,
  20. _ => throw new NotImplementedException(),
  21. };
  22. // train.GradientDescentOptimizer is V1 only API.
  23. tf.Graph().as_default();
  24. using (self.cached_session())
  25. {
  26. var var0 = tf.Variable(new[] { 1.0, 2.0 }, dtype: dtype);
  27. var var1 = tf.Variable(new[] { 3.0, 4.0 }, dtype: dtype);
  28. var grads0 = tf.constant(new[] { 0.1, 0.1 }, dtype: dtype);
  29. var grads1 = tf.constant(new[] { 0.01, 0.01 }, dtype: dtype);
  30. var optimizer = tf.train.GradientDescentOptimizer(3.0f);
  31. var grads_and_vars = new[] {
  32. Tuple.Create(grads0, var0 as IVariableV1),
  33. Tuple.Create(grads1, var1 as IVariableV1)
  34. };
  35. var sgd_op = optimizer.apply_gradients(grads_and_vars);
  36. var global_variables = variables.global_variables_initializer();
  37. self.evaluate<T>(global_variables);
  38. // Fetch params to validate initial values
  39. // TODO: use self.evaluate<T[]> instead of self.evaluate<double[]>
  40. self.assertAllCloseAccordingToType(new double[] { 1.0, 2.0 }, self.evaluate<double[]>(var0));
  41. self.assertAllCloseAccordingToType(new double[] { 3.0, 4.0 }, self.evaluate<double[]>(var1));
  42. // Run 1 step of sgd
  43. sgd_op.run();
  44. // Validate updated params
  45. self.assertAllCloseAccordingToType(
  46. new double[] { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 },
  47. self.evaluate<double[]>(var0));
  48. self.assertAllCloseAccordingToType(
  49. new double[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 },
  50. self.evaluate<double[]>(var1));
  51. // TODO: self.assertEqual(0, len(optimizer.variables()));
  52. }
  53. }
  54. [TestMethod]
  55. public void TestBasic()
  56. {
  57. //TODO: add np.half
  58. TestBasicGeneric<float>();
  59. TestBasicGeneric<double>();
  60. }
  61. }
  62. }