You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

SGD.cs 2.7 kB

5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using Tensorflow.Keras.ArgsDefinition;
  5. namespace Tensorflow.Keras.Optimizers
  6. {
  /// <summary>
  /// Stochastic gradient descent optimizer with optional momentum.
  /// Without momentum the variable is updated through
  /// <c>resource_apply_gradient_descent</c>; with momentum a per-variable
  /// "momentum" slot is created and <c>resource_apply_keras_momentum</c> is used.
  /// </summary>
  public class SGD : OptimizerV2
  {
      protected override string _name => "SGD";

      // Forwarded as use_nesterov to the momentum update op.
      private readonly bool nesterov;

      /// <summary>
      /// Creates the optimizer and registers its hyperparameters.
      /// </summary>
      /// <param name="learning_rate">Stored as the "learning_rate" hyperparameter.</param>
      /// <param name="momentum">
      /// Stored as the "momentum" hyperparameter; must lie in [0, 1].
      /// Any value greater than 0 enables the momentum code path.
      /// </param>
      /// <param name="nesterov">Passed as <c>use_nesterov</c> to the momentum update op.</param>
      /// <param name="decay">Stored as the "decay" hyperparameter.</param>
      /// <exception cref="ValueError">
      /// Thrown when <paramref name="momentum"/> is below 0 or above 1.
      /// </exception>
      public SGD(float learning_rate,
          float momentum = 0.0f,
          bool nesterov = false,
          float decay = 0.0f) : base(new OptimizerV2Args { })
      {
          _set_hyper("learning_rate", learning_rate);
          _set_hyper("decay", decay);

          // _momentum is not declared in this class (presumably inherited from OptimizerV2);
          // it switches slot creation and the update op used below.
          _momentum = momentum > 0;
          if (momentum < 0 || momentum > 1)
              throw new ValueError($"momentum must be a number between 0 and 1, got {momentum}.");
          _set_hyper("momentum", momentum);

          // The parameter shadows the field, so "this." is required. The previous
          // "nesterov = nesterov;" was a self-assignment that left the field false.
          this.nesterov = nesterov;
      }

      /// <summary>
      /// Adds a "momentum" slot for every variable when momentum is enabled;
      /// does nothing otherwise.
      /// </summary>
      /// <param name="var_list">Variables that will be optimized.</param>
      protected override void _create_slots(IVariableV1[] var_list)
      {
          if (_momentum)
          {
              foreach (var var in var_list)
              {
                  add_slot(var, "momentum");
              }
          }
      }

      /// <summary>
      /// Runs the base preparation, then stores an identity tensor of the
      /// "momentum" hyperparameter under the "momentum" key for this device/dtype.
      /// </summary>
      /// <param name="device_dtype">Key of the state entry to populate.</param>
      /// <param name="_apply_state">Per-device/dtype tensor cache that is updated in place.</param>
      protected override void _prepare_local(DeviceDType device_dtype,
          Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state)
      {
          base._prepare_local(device_dtype, _apply_state);

          _apply_state[device_dtype]["momentum"] = array_ops.identity(
              _get_hyper("momentum", device_dtype.DType));
      }

      /// <summary>
      /// Builds the update operation for one variable.
      /// </summary>
      /// <param name="var">Variable to update.</param>
      /// <param name="grad">Gradient for <paramref name="var"/>.</param>
      /// <param name="_apply_state">
      /// Per-device/dtype tensor cache; its "lr_t" entry is read on the plain
      /// gradient-descent path.
      /// </param>
      /// <returns>The momentum or plain gradient-descent update operation.</returns>
      protected override Operation _resource_apply_dense(IVariableV1 var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state)
      {
          if (_momentum)
          {
              var momentum_var = get_slot(var, "momentum");
              // NOTE(review): this path reads the raw hyperparameters rather than the
              // cached "lr_t"/"momentum" entries in _apply_state - confirm this is intended.
              return gen_training_ops.resource_apply_keras_momentum(
                  var.Handle,
                  momentum_var.Handle,
                  _get_hyper("learning_rate", var.dtype),
                  grad,
                  _get_hyper("momentum", var.dtype),
                  use_locking: _use_locking,
                  use_nesterov: nesterov);
          }

          // Find the cached state whose device and base dtype match this variable.
          var device_dtype = _apply_state.Keys.FirstOrDefault(x => x.Device == var.Device && x.DType == var.dtype.as_base_dtype());

          return gen_training_ops.resource_apply_gradient_descent(var.Handle,
              _apply_state[device_dtype]["lr_t"],
              grad,
              use_locking: _use_locking);
      }
  }
  62. }