You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

AdamOptimizer.cs 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using Tensorflow.Framework;
  6. using static Tensorflow.Python;
  namespace Tensorflow.Train
  {
      /// <summary>
      /// Optimizer that implements the Adam algorithm.
      /// http://arxiv.org/abs/1412.6980
      /// </summary>
      /// <remarks>
      /// Only the sparse update path (<see cref="_apply_sparse"/>) is overridden in this class.
      /// NOTE(review): nothing in this class advances the "beta1_power" / "beta2_power"
      /// accumulators after a step (the TensorFlow Python reference does this in <c>_finish</c>).
      /// Unless the base <see cref="Optimizer"/> does it, the bias correction below stays at its
      /// initial value — confirm.
      /// </remarks>
      public class AdamOptimizer : Optimizer
      {
          // Hyper-parameters captured by the constructor; only read afterwards.
          readonly float _beta1;
          readonly float _beta2;
          readonly float _epsilon;

          // Tensor forms of the hyper-parameters, created by _prepare() and read by the update ops.
          Tensor _lr_t, _beta1_t, _beta2_t, _epsilon_t;

          /// <summary>
          /// Creates an Adam optimizer.
          /// </summary>
          /// <param name="learning_rate">Step size; forwarded to the base class.</param>
          /// <param name="beta1">Decay factor applied to the first-moment slot "m" on every update.</param>
          /// <param name="beta2">Decay factor applied to the second-moment slot "v" on every update.</param>
          /// <param name="epsilon">Constant added to sqrt(v) in the update denominator.</param>
          /// <param name="use_locking">Passed as <c>use_locking</c> to the assign/scatter ops built by this optimizer.</param>
          /// <param name="name">Optimizer name; forwarded to the base class.</param>
          public AdamOptimizer(float learning_rate, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1e-8f, bool use_locking = false, string name = "Adam")
              : base(learning_rate, use_locking, name)
          {
              _beta1 = beta1;
              _beta2 = beta2;
              _epsilon = epsilon;
          }

          /// <summary>
          /// Applies a sparse gradient to <paramref name="var"/>, scatter-adding into the rows
          /// named by <c>grad.indices</c>.
          /// </summary>
          /// <param name="grad">Sparse gradient (values plus row indices).</param>
          /// <param name="var">Variable being trained.</param>
          /// <returns>An op grouping the variable, "m" and "v" updates.</returns>
          public override Operation _apply_sparse(IndexedSlices grad, RefVariable var)
          {
              return _apply_sparse_shared(grad.values, var, grad.indices,
                  (x, i, v) => state_ops.scatter_add(x, i, v, use_locking: _use_locking));
          }

          /// <summary>
          /// Builds the Adam update for a sparse gradient, given a scatter-add implementation.
          /// </summary>
          /// <param name="grad">Gradient values for the rows in <paramref name="indices"/>.</param>
          /// <param name="var">Variable being trained.</param>
          /// <param name="indices">Row indices that <paramref name="grad"/> refers to.</param>
          /// <param name="scatter_add">Adds its third argument into the rows of its first argument named by its second.</param>
          /// <returns>An op grouping the variable, "m" and "v" updates.</returns>
          private Operation _apply_sparse_shared(Tensor grad, RefVariable var, Tensor indices, Func<RefVariable, Tensor, Tensor, Tensor> scatter_add)
          {
              // All scalars are cast to the variable's base dtype so they can be mixed with it.
              var (beta1_power_v, beta2_power_v) = _get_beta_accumulators();
              Tensor beta1_power = math_ops.cast(beta1_power_v, var.dtype.as_base_dtype());
              Tensor beta2_power = math_ops.cast(beta2_power_v, var.dtype.as_base_dtype());
              var lr_t = math_ops.cast(_lr_t, var.dtype.as_base_dtype());
              var beta1_t = math_ops.cast(_beta1_t, var.dtype.as_base_dtype());
              var beta2_t = math_ops.cast(_beta2_t, var.dtype.as_base_dtype());
              var epsilon_t = math_ops.cast(_epsilon_t, var.dtype.as_base_dtype());

              // Bias-corrected step size: lr * sqrt(1 - beta2_power) / (1 - beta1_power).
              var lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power));

              // m_t = beta1 * m + (1 - beta1) * g: decay all of m, then (ordered by the control
              // dependency) scatter-add the scaled gradient into the touched rows.
              var m = get_slot(var, "m");
              var m_scaled_g_values = grad * (1 - beta1_t);
              var m_t = state_ops.assign(m, m * beta1_t, use_locking: _use_locking);
              with(ops.control_dependencies(new[] { m_t }), delegate
              {
                  m_t = scatter_add(m, indices, m_scaled_g_values);
              });

              // v_t = beta2 * v + (1 - beta2) * g^2, built the same way as m_t.
              var v = get_slot(var, "v");
              var v_scaled_g_values = (grad * grad) * (1 - beta2_t);
              var v_t = state_ops.assign(v, v * beta2_t, use_locking: _use_locking);
              with(ops.control_dependencies(new[] { v_t }), delegate
              {
                  v_t = scatter_add(v, indices, v_scaled_g_values);
              });

              // var -= lr * m_t / (sqrt(v_t) + epsilon)
              var v_sqrt = math_ops.sqrt(v_t);
              var var_update = state_ops.assign_sub(var, lr * m_t / (v_sqrt + epsilon_t), use_locking: _use_locking);
              return control_flow_ops.group(new[] { var_update, m_t, v_t });
          }

          /// <summary>
          /// Creates the shared "beta1_power"/"beta2_power" non-slot variables and the
          /// zero-initialised "m" and "v" slots for every variable in <paramref name="var_list"/>.
          /// </summary>
          /// <param name="var_list">Variables that will be optimized.</param>
          /// <exception cref="InvalidOperationException"><paramref name="var_list"/> is empty.</exception>
          protected override void _create_slots(RefVariable[] var_list)
          {
              // Colocate the accumulators with the variable whose name sorts first. The ordinal
              // comparison keeps this choice independent of the current culture (OrderBy on
              // strings is culture-sensitive). A single pass replaces the full sort; the strict
              // "< 0" keeps the first of equally-named variables, as OrderBy(...).First() did.
              var first_var = var_list.Aggregate(
                  (best, candidate) => string.CompareOrdinal(candidate.name, best.name) < 0 ? candidate : best);
              _create_non_slot_variable(initial_value: _beta1, name: "beta1_power", colocate_with: first_var);
              _create_non_slot_variable(initial_value: _beta2, name: "beta2_power", colocate_with: first_var);

              // Create slots for the first and second moments.
              foreach (var v in var_list)
              {
                  _zeros_slot(v, "m", Name);
                  _zeros_slot(v, "v", Name);
              }
          }

          /// <summary>
          /// Looks up the "beta1_power" and "beta2_power" non-slot variables in the default graph.
          /// </summary>
          /// <returns>The (beta1_power, beta2_power) variables.</returns>
          private (RefVariable, RefVariable) _get_beta_accumulators()
          {
              // NOTE(review): init_scope() is called as a bare statement, so no scope is held
              // around the lookups below (Python uses `with ops.init_scope():`) — confirm intended.
              ops.init_scope();
              var graph = ops.get_default_graph();
              return (_get_non_slot_variable("beta1_power", graph: graph),
                  _get_non_slot_variable("beta2_power", graph: graph));
          }

          /// <summary>
          /// Converts the learning rate and hyper-parameters to tensors so the update ops can use them.
          /// </summary>
          public override void _prepare()
          {
              var lr = _call_if_callable(_lr);
              var beta1 = _call_if_callable(_beta1);
              var beta2 = _call_if_callable(_beta2);
              var epsilon = _call_if_callable(_epsilon);
              _lr_t = ops.convert_to_tensor(lr, name: "learning_rate");
              _beta1_t = ops.convert_to_tensor(beta1, name: "beta1");
              _beta2_t = ops.convert_to_tensor(beta2, name: "beta2");
              _epsilon_t = ops.convert_to_tensor(epsilon, name: "epsilon");
          }
      }
  }