You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

RMSprop.cs 3.1 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. using System;
  2. using System.Collections.Generic;
  3. using Tensorflow.Keras.ArgsDefinition;
  4. namespace Tensorflow.Keras.Optimizers
  5. {
  /// <summary>
  /// Optimizer that implements the RMSprop algorithm.
  /// </summary>
  /// <remarks>
  /// Only the plain variant is implemented by the dense update in this class:
  /// <c>rms_t = rho * rms + (1 - rho) * grad^2</c> followed by
  /// <c>var_t = var - lr_t * grad / (sqrt(rms_t) + epsilon)</c>.
  /// The momentum and centered variants throw <see cref="NotImplementedException"/>.
  /// </remarks>
  public class RMSprop : OptimizerV2
  {
      private readonly RMSpropArgs args;

      /// <summary>True when the centered variant was requested through the arguments.</summary>
      bool centered => args.Centered;

      protected override string _name => "RMSprop";

      /// <summary>
      /// Creates the optimizer and registers the "rho" and "momentum" hyperparameters.
      /// </summary>
      /// <param name="args">RMSprop configuration (rho, momentum, epsilon, centered).</param>
      public RMSprop(RMSpropArgs args) : base(args)
      {
          this.args = args;
          _set_hyper("rho", args.RHO);
          _set_hyper("momentum", args.Momentum);
          // NOTE(review): `_momentum` is never assigned in this class; presumably the base
          // class owns it. Confirm it is set when args.Momentum is non-zero, otherwise the
          // "momentum" hyperparameter has no effect on the dense update below.
      }

      /// <summary>
      /// Adds the per-variable slots: "rms" always, "momentum" when momentum is enabled,
      /// and "mg" when the centered variant is enabled.
      /// </summary>
      /// <param name="var_list">Variables that need slots.</param>
      protected override void _create_slots(IVariableV1[] var_list)
      {
          // Separate passes keep the slot creation order: all "rms" slots first,
          // then all "momentum" slots, then all "mg" slots.
          foreach (var variable in var_list)
          {
              add_slot(variable, "rms");
          }

          if (_momentum)
          {
              foreach (var variable in var_list)
              {
                  add_slot(variable, "momentum");
              }
          }

          if (centered)
          {
              foreach (var variable in var_list)
              {
                  add_slot(variable, "mg");
              }
          }
      }

      /// <summary>
      /// Extends the state prepared by the base class for <paramref name="device_dtype"/>
      /// with the entries "neg_lr_t", "epsilon", "rho", "momentum" and "one_minus_rho".
      /// </summary>
      /// <param name="device_dtype">Device/dtype pair the coefficients are prepared for.</param>
      /// <param name="_apply_state">Per device/dtype coefficient tables, updated in place.</param>
      protected override void _prepare_local(DeviceDType device_dtype, Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state)
      {
          base._prepare_local(device_dtype, _apply_state);

          var rho = array_ops.identity(_get_hyper("rho", device_dtype.DType));
          _apply_state[device_dtype]["neg_lr_t"] = -_apply_state[device_dtype]["lr_t"];
          _apply_state[device_dtype]["epsilon"] = ops.convert_to_tensor(args.Epsilon, dtype: device_dtype.DType);
          _apply_state[device_dtype]["rho"] = rho;
          _apply_state[device_dtype]["momentum"] = array_ops.identity(_get_hyper("momentum", device_dtype.DType));
          _apply_state[device_dtype]["one_minus_rho"] = 1.0f - rho;
      }

      /// <summary>
      /// Applies one plain RMSprop step to <paramref name="var"/> using a dense gradient.
      /// </summary>
      /// <param name="var">Variable to update; its "rms" slot is updated as well.</param>
      /// <param name="grad">Gradient for <paramref name="var"/>.</param>
      /// <param name="_apply_state">Coefficient tables keyed by device/dtype.</param>
      /// <returns>The operation of the assignment that writes the new variable value.</returns>
      /// <exception cref="NotImplementedException">
      /// Momentum or the centered variant is enabled.
      /// </exception>
      /// <exception cref="InvalidOperationException">
      /// No coefficient table matches the variable's device and base dtype.
      /// </exception>
      protected override Operation _resource_apply_dense(IVariableV1 var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state)
      {
          // Reject unsupported variants up front, before any assign op is created,
          // so a failing call cannot leave the "rms" accumulator modified.
          if (_momentum)
          {
              throw new NotImplementedException("RMSprop with momentum is not implemented.");
          }

          if (centered)
          {
              throw new NotImplementedException("Centered RMSprop is not implemented.");
          }

          // Pick the coefficient table prepared for this variable's device and base dtype.
          Dictionary<string, Tensor> coefficients = null;
          foreach (var state in _apply_state)
          {
              if (state.Key.DType == var.dtype.as_base_dtype()
                  && state.Key.Device == var.Device)
              {
                  coefficients = state.Value;
                  break;
              }
          }

          if (coefficients == null)
          {
              throw new InvalidOperationException(
                  $"No apply state found for dtype {var.dtype.as_base_dtype()} on device '{var.Device}'.");
          }

          var rms = get_slot(var, "rms");

          // Moving average of the squared gradient, written back to the slot.
          var rms_t = coefficients["rho"] * rms.AsTensor() + coefficients["one_minus_rho"] * math_ops.square(grad);
          rms_t = state_ops.assign(rms, rms_t, use_locking: _use_locking);

          // Plain variant: the denominator is the moving average itself.
          var denom_t = rms_t;
          var var_t = var.AsTensor() - coefficients["lr_t"] * grad / (math_ops.sqrt(denom_t) + coefficients["epsilon"]);
          return state_ops.assign(var, var_t, use_locking: _use_locking).op;
      }
  }
  72. }