
amsgrad.py
#! /usr/bin/python
# -*- coding: utf-8 -*-
"""AMSGrad implementation based on the paper: "On the Convergence of Adam and Beyond" (ICLR 2018)

Article Link: https://openreview.net/pdf?id=ryQu7f-RZ
Original Implementation by: https://github.com/taki0112/AMSGrad-Tensorflow

"""

from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import (control_flow_ops, math_ops, resource_variable_ops, state_ops, variable_scope)
from tensorflow.python.training import optimizer


class AMSGrad(optimizer.Optimizer):
    """Implementation of the AMSGrad optimization algorithm.

    See: `On the Convergence of Adam and Beyond - [Reddi et al., 2018] <https://openreview.net/pdf?id=ryQu7f-RZ>`__.

    Parameters
    ----------
    learning_rate: float
        A Tensor or a floating point value. The learning rate.
    beta1: float
        A float value or a constant float tensor.
        The exponential decay rate for the 1st moment estimates.
    beta2: float
        A float value or a constant float tensor.
        The exponential decay rate for the 2nd moment estimates.
    epsilon: float
        A small constant for numerical stability.
        This epsilon is "epsilon hat" in the Kingma and Ba paper
        (in the formula just before Section 2.1), not the epsilon in Algorithm 1 of the paper.
    use_locking: bool
        If True, use locks for update operations.
    name: str
        Optional name for the operations created when applying gradients.
        Defaults to "AMSGrad".

    """

    def __init__(self, learning_rate=0.01, beta1=0.9, beta2=0.99, epsilon=1e-8, use_locking=False, name="AMSGrad"):
        """Construct a new AMSGrad optimizer."""
        super(AMSGrad, self).__init__(use_locking, name)
        self._lr = learning_rate
        self._beta1 = beta1
        self._beta2 = beta2
        self._epsilon = epsilon

        self._lr_t = None
        self._beta1_t = None
        self._beta2_t = None
        self._epsilon_t = None

        self._beta1_power = None
        self._beta2_power = None

    def _create_slots(self, var_list):
        first_var = min(var_list, key=lambda x: x.name)

        create_new = self._beta1_power is None
        if not create_new and context.in_graph_mode():
            create_new = (self._beta1_power.graph is not first_var.graph)

        if create_new:
            with ops.colocate_with(first_var):
                self._beta1_power = variable_scope.variable(self._beta1, name="beta1_power", trainable=False)
                self._beta2_power = variable_scope.variable(self._beta2, name="beta2_power", trainable=False)

        # Create slots for the first and second moments.
        for v in var_list:
            self._zeros_slot(v, "m", self._name)
            self._zeros_slot(v, "v", self._name)
            self._zeros_slot(v, "vhat", self._name)

    def _prepare(self):
        self._lr_t = ops.convert_to_tensor(self._lr)
        self._beta1_t = ops.convert_to_tensor(self._beta1)
        self._beta2_t = ops.convert_to_tensor(self._beta2)
        self._epsilon_t = ops.convert_to_tensor(self._epsilon)

    def _apply_dense(self, grad, var):
        beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype)
        beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype)
        lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
        beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
        beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
        epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
        lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))

        # m_t = beta1 * m + (1 - beta1) * g_t
        m = self.get_slot(var, "m")
        m_scaled_g_values = grad * (1 - beta1_t)
        m_t = state_ops.assign(m, beta1_t * m + m_scaled_g_values, use_locking=self._use_locking)

        # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
        v = self.get_slot(var, "v")
        v_scaled_g_values = (grad * grad) * (1 - beta2_t)
        v_t = state_ops.assign(v, beta2_t * v + v_scaled_g_values, use_locking=self._use_locking)

        # amsgrad
        vhat = self.get_slot(var, "vhat")
        vhat_t = state_ops.assign(vhat, math_ops.maximum(v_t, vhat))
        v_sqrt = math_ops.sqrt(vhat_t)

        var_update = state_ops.assign_sub(var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking)
        return control_flow_ops.group(*[var_update, m_t, v_t, vhat_t])

    def _resource_apply_dense(self, grad, var):
        var = var.handle
        beta1_power = math_ops.cast(self._beta1_power, grad.dtype.base_dtype)
        beta2_power = math_ops.cast(self._beta2_power, grad.dtype.base_dtype)
        lr_t = math_ops.cast(self._lr_t, grad.dtype.base_dtype)
        beta1_t = math_ops.cast(self._beta1_t, grad.dtype.base_dtype)
        beta2_t = math_ops.cast(self._beta2_t, grad.dtype.base_dtype)
        epsilon_t = math_ops.cast(self._epsilon_t, grad.dtype.base_dtype)
        lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))

        # m_t = beta1 * m + (1 - beta1) * g_t
        m = self.get_slot(var, "m").handle
        m_scaled_g_values = grad * (1 - beta1_t)
        m_t = state_ops.assign(m, beta1_t * m + m_scaled_g_values, use_locking=self._use_locking)

        # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
        v = self.get_slot(var, "v").handle
        v_scaled_g_values = (grad * grad) * (1 - beta2_t)
        v_t = state_ops.assign(v, beta2_t * v + v_scaled_g_values, use_locking=self._use_locking)

        # amsgrad
        vhat = self.get_slot(var, "vhat").handle
        vhat_t = state_ops.assign(vhat, math_ops.maximum(v_t, vhat))
        v_sqrt = math_ops.sqrt(vhat_t)

        var_update = state_ops.assign_sub(var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking)
        return control_flow_ops.group(*[var_update, m_t, v_t, vhat_t])

    def _apply_sparse_shared(self, grad, var, indices, scatter_add):
        beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype)
        beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype)
        lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
        beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
        beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
        epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
        lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))

        # m_t = beta1 * m + (1 - beta1) * g_t
        m = self.get_slot(var, "m")
        m_scaled_g_values = grad * (1 - beta1_t)
        m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking)
        with ops.control_dependencies([m_t]):
            m_t = scatter_add(m, indices, m_scaled_g_values)

        # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
        v = self.get_slot(var, "v")
        v_scaled_g_values = (grad * grad) * (1 - beta2_t)
        v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
        with ops.control_dependencies([v_t]):
            v_t = scatter_add(v, indices, v_scaled_g_values)

        # amsgrad
        vhat = self.get_slot(var, "vhat")
        vhat_t = state_ops.assign(vhat, math_ops.maximum(v_t, vhat))
        v_sqrt = math_ops.sqrt(vhat_t)

        var_update = state_ops.assign_sub(var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking)
        return control_flow_ops.group(*[var_update, m_t, v_t, vhat_t])

    def _apply_sparse(self, grad, var):
        return self._apply_sparse_shared(
            grad.values,
            var,
            grad.indices,
            lambda x, i, v: state_ops.scatter_add(  # pylint: disable=g-long-lambda
                x, i, v, use_locking=self._use_locking
            )
        )

    def _resource_scatter_add(self, x, i, v):
        with ops.control_dependencies([resource_variable_ops.resource_scatter_add(x.handle, i, v)]):
            return x.value()

    def _resource_apply_sparse(self, grad, var, indices):
        return self._apply_sparse_shared(grad, var, indices, self._resource_scatter_add)

    def _finish(self, update_ops, name_scope):
        # Update the power accumulators.
        with ops.control_dependencies(update_ops):
            with ops.colocate_with(self._beta1_power):
                update_beta1 = self._beta1_power.assign(
                    self._beta1_power * self._beta1_t, use_locking=self._use_locking
                )
                update_beta2 = self._beta2_power.assign(
                    self._beta2_power * self._beta2_t, use_locking=self._use_locking
                )
        return control_flow_ops.group(*update_ops + [update_beta1, update_beta2], name=name_scope)
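For reference, below is a minimal usage sketch, assuming a TensorFlow 1.x graph-mode environment (the class relies on `context.in_graph_mode()` and the 1.x `optimizer.Optimizer` base) and that this file is importable as the module `amsgrad`. The toy regression variables and the module name are illustrative, not part of this repository's API.

import numpy as np
import tensorflow as tf  # assumes TensorFlow 1.x

from amsgrad import AMSGrad  # assumes this file is importable as `amsgrad`

# Toy linear-regression graph; all names below are illustrative.
x = tf.placeholder(tf.float32, shape=[None, 1])
y = tf.placeholder(tf.float32, shape=[None, 1])
w = tf.Variable(tf.zeros([1, 1]))
b = tf.Variable(tf.zeros([1]))
pred = tf.matmul(x, w) + b
loss = tf.reduce_mean(tf.square(pred - y))

# AMSGrad is used here as a drop-in replacement for tf.train.AdamOptimizer.
train_op = AMSGrad(learning_rate=0.01, beta1=0.9, beta2=0.99, epsilon=1e-8).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    data_x = np.random.rand(64, 1).astype(np.float32)
    data_y = 3.0 * data_x + 1.0
    for _ in range(200):
        _, cur_loss = sess.run([train_op, loss], feed_dict={x: data_x, y: data_y})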

TensorLayer 3.0 is a deep learning library that supports multiple deep learning frameworks as compute backends. It plans to be compatible with TensorFlow, PyTorch, MindSpore, and Paddle.