
mindspore_optimizers.py

#! /usr/bin/python
# -*- coding: utf-8 -*-
"""TensorFlow-style optimizer wrappers that dispatch to MindSpore optimizers."""

from __future__ import absolute_import, division, print_function

from mindspore.nn import optim as optimizer
import mindspore as ms
from mindspore.nn import Cell

__all__ = ['Adadelta', 'Adagrad', 'Adam', 'Adamax', 'Ftrl', 'Nadam', 'RMSprop', 'SGD', 'Momentum', 'Lamb', 'LARS']


class Adadelta(Cell):

    def __init__(self):
        pass

    def apply_gradients(self):
        raise NotImplementedError('Adadelta optimizer function not implemented')


class Adagrad(Cell):

    def __init__(self):
        pass

    def apply_gradients(self):
        raise NotImplementedError('Adagrad optimizer function not implemented')


class Adam(Cell):

    def __init__(
        self,
        learning_rate=0.001,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-8,
    ):
        super(Adam, self).__init__()
        self.adam = optimizer.Adam
        self.learn_rate = learning_rate
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon

    def apply_gradients(self, grads_and_vars):
        # Unzip the (gradient, variable) pairs, build a MindSpore Adam optimizer
        # over the variables, and apply the gradients in one update step.
        grads, vars = list(zip(*grads_and_vars))
        optimizer_adam = self.adam(
            vars, learning_rate=self.learn_rate, beta1=self.beta_1, beta2=self.beta_2, eps=self.epsilon
        )
        optimizer_adam(grads)


class Adamax(Cell):

    def __init__(self):
        pass

    def apply_gradients(self):
        raise NotImplementedError('Adamax optimizer function not implemented')


class Ftrl(Cell):

    def __init__(self):
        pass

    def apply_gradients(self):
        raise NotImplementedError('Ftrl optimizer function not implemented')


class Nadam(Cell):

    def __init__(self):
        pass

    def apply_gradients(self):
        raise NotImplementedError('Nadam optimizer function not implemented')


class RMSprop(Cell):

    def __init__(self):
        pass

    def apply_gradients(self):
        raise NotImplementedError('RMSprop optimizer function not implemented')


class SGD(Cell):

    def __init__(self, learning_rate, momentum):
        super(SGD, self).__init__()
        self.sgd = optimizer.SGD
        self.learn_rate = learning_rate
        self.momentum = momentum

    def apply_gradients(self, grads_and_vars):
        grads, vars = list(zip(*grads_and_vars))
        optimizer_sgd = self.sgd(vars, learning_rate=self.learn_rate, momentum=self.momentum)
        optimizer_sgd(grads)


class Momentum(Cell):

    def __init__(self, learning_rate, momentum):
        super(Momentum, self).__init__()
        self.mom = optimizer.Momentum
        self.learn_rate = learning_rate
        self.momentum = momentum

    def apply_gradients(self, grads_and_vars, **kwargs):
        grads, vars = list(zip(*grads_and_vars))
        optimizer_mom = self.mom(vars, learning_rate=self.learn_rate, momentum=self.momentum, **kwargs)
        optimizer_mom(grads)


class Lamb(Cell):

    def __init__(
        self, decay_steps, warmup_steps=0, start_learning_rate=0.1, end_learning_rate=0.0001, power=1.0, beta1=0.9,
        beta2=0.999, eps=1e-06, weight_decay=0.0
    ):
        super(Lamb, self).__init__()
        self.lamb = optimizer.Lamb
        self.decay_steps = decay_steps
        self.warmup_steps = warmup_steps
        self.start_learning_rate = start_learning_rate
        self.end_learning_rate = end_learning_rate
        self.power = power
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.weight_decay = weight_decay

    def apply_gradients(self, grads_and_vars):
        grads, vars = list(zip(*grads_and_vars))
        optimizer_lamb = self.lamb(
            params=vars, decay_steps=self.decay_steps, warmup_steps=self.warmup_steps,
            start_learning_rate=self.start_learning_rate, end_learning_rate=self.end_learning_rate, power=self.power,
            beta1=self.beta1, beta2=self.beta2, eps=self.eps, weight_decay=self.weight_decay
        )
        optimizer_lamb(grads)


class LARS(object):

    def __init__(self, optimizer, **kwargs):
        # Wrap an already-constructed MindSpore optimizer with LARS scaling.
        self.lars = ms.nn.LARS(optimizer=optimizer, **kwargs)

    def apply_gradients(self, grads_and_vars):
        grads, _ = list(zip(*grads_and_vars))
        self.lars(grads)
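
For reference, here is a minimal usage sketch (not part of the repository) showing how TensorFlow-style (gradient, variable) pairs can be fed to the SGD wrapper above. The toy model, the PyNative mode setting, and the use of ops.GradOperation are illustrative assumptions; any MindSpore autodiff entry point that yields per-parameter gradients would be paired with the parameters in the same way.

# Illustrative sketch only: compute per-parameter gradients with MindSpore autodiff
# and hand them to the SGD wrapper defined in mindspore_optimizers.py.
import numpy as np
import mindspore as ms
from mindspore import ParameterTuple, context, nn, ops

context.set_context(mode=context.PYNATIVE_MODE)  # eager mode keeps the example simple

net = nn.Dense(4, 2)                              # hypothetical toy model
loss_net = nn.WithLossCell(net, nn.MSELoss())     # attaches an MSE loss to the model
weights = ParameterTuple(net.trainable_params())
grad_fn = ops.GradOperation(get_by_list=True)(loss_net, weights)

x = ms.Tensor(np.random.randn(8, 4).astype(np.float32))
y = ms.Tensor(np.random.randn(8, 2).astype(np.float32))
grads = grad_fn(x, y)                             # gradients ordered like `weights`

opt = SGD(learning_rate=0.01, momentum=0.9)
opt.apply_gradients(list(zip(grads, weights)))    # TensorFlow-style (grad, var) pairs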

TensorLayer 3.0 is a deep learning library that supports multiple deep learning frameworks as computational backends. Planned backends include TensorFlow, PyTorch, MindSpore, and Paddle.