You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

AdamW.cs 2.4 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  namespace Tensorflow.Keras.Optimizers
  {
      /// <summary>
      /// Adam optimizer variant that applies a weight-decay step to each variable
      /// before delegating the dense update to <see cref="Adam"/>.
      /// The decay step adds <c>-learning_rate * var * weight_decay</c> to the variable.
      /// Variables whose name contains any entry of <c>no_decay_params</c> are not decayed.
      /// </summary>
      public class AdamW : Adam
      {
          // NOTE(review): assigned in the constructor but never read in this class, and not
          // forwarded to the base constructor -- confirm whether Adam should receive it.
          string name;

          // Decay rate; a value of 0 disables weight decay entirely.
          float weight_decay;

          // Key of the most recent _prepare_local call; used to look up the
          // "weight_decay" tensor inside apply_state.
          DeviceDType deType;

          // Name fragments identifying variables that must not be decayed (may be null).
          List<string> no_decay_params = null;

          /// <summary>
          /// Creates the optimizer.
          /// </summary>
          /// <param name="learning_rate">Forwarded to the Adam base constructor.</param>
          /// <param name="weight_decay">Decay rate; 0 turns weight decay off.</param>
          /// <param name="beta_1">Forwarded to the Adam base constructor.</param>
          /// <param name="beta_2">Forwarded to the Adam base constructor.</param>
          /// <param name="epsilon">Forwarded to the Adam base constructor.</param>
          /// <param name="amsgrad">Forwarded to the Adam base constructor.</param>
          /// <param name="no_decay_params">
          /// Optional list of substrings; a variable whose name contains any of them is
          /// excluded from weight decay. Null means every variable is decayed.
          /// </param>
          /// <param name="name">Stored on this instance.</param>
          public AdamW(float learning_rate = 0.001f,
              float weight_decay = 0.004f,
              float beta_1 = 0.9f,
              float beta_2 = 0.999f,
              float epsilon = 1e-7f,
              bool amsgrad = false,
              List<string> no_decay_params = null,
              string name = "AdamW") : base(learning_rate, beta_1, beta_2, epsilon, amsgrad)
          {
              this.name = name;
              this.weight_decay = weight_decay;
              this.no_decay_params = no_decay_params;
          }

          /// <summary>
          /// Builds the weight-decay operation for <paramref name="var"/>.
          /// </summary>
          /// <param name="var">Variable to decay.</param>
          /// <param name="learning_rate">Learning rate used to scale the decay.</param>
          /// <param name="apply_state">
          /// State populated by <see cref="_prepare_local"/>; must contain a
          /// "weight_decay" tensor under the stored device/dtype key.
          /// </param>
          /// <returns>
          /// An assign-add of <c>-learning_rate * var * weight_decay</c>, or a no-op when
          /// the variable is excluded from decay.
          /// </returns>
          protected Operation _decay_weights_op(IVariableV1 var, float learning_rate, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
          {
              if (!_do_use_weight_decay(var.Name))
                  return tf.no_op();

              // NOTE(review): deType is whatever key _prepare_local saw last, not necessarily
              // the device/dtype of `var`. With several device/dtype combinations this may
              // select another combination's constant -- confirm against callers.
              Tensor decay_rate = apply_state[deType]["weight_decay"];
              return var.assign_add(-learning_rate * var.AsTensor() * decay_rate);
          }

          /// <summary>
          /// Decides whether weight decay applies to the variable called
          /// <paramref name="param_name"/>.
          /// </summary>
          /// <returns>
          /// False when the decay rate is 0 or the name contains one of the
          /// <c>no_decay_params</c> fragments; true otherwise.
          /// </returns>
          protected bool _do_use_weight_decay(string param_name)
          {
              if (this.weight_decay == 0)
                  return false;

              if (this.no_decay_params == null)
                  return true;

              foreach (var fragment in this.no_decay_params)
              {
                  // A null entry would make string.Contains throw; treat it as "no match".
                  if (fragment == null)
                      continue;

                  if (param_name.Contains(fragment))
                      return false;
              }

              return true;
          }

          /// <summary>
          /// Applies weight decay to <paramref name="var"/> and then runs the Adam dense update.
          /// </summary>
          protected override Operation _resource_apply_dense(IVariableV1 var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
          {
              var decay = _decay_weights_op(var, _hyper["learning_rate"], apply_state);

              // The controller returned by tf.control_dependencies only takes effect while it
              // is entered, so the base update has to be built inside tf_with; calling
              // control_dependencies and discarding the result leaves the update unordered
              // with respect to the decay op.
              return tf_with(tf.control_dependencies(new[] { decay }),
                  _ => base._resource_apply_dense(var, grad, apply_state));
          }

          /// <summary>
          /// Lets the base class prepare its per-device/dtype state, then adds the
          /// "weight_decay" constant to the same entry and remembers the key.
          /// </summary>
          protected override void _prepare_local(DeviceDType device_dtype, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
          {
              this.deType = device_dtype;
              base._prepare_local(device_dtype, apply_state);
              apply_state[device_dtype]["weight_decay"] = tf.constant(
                  weight_decay, name: "adam_weight_decay_rate");
          }
      }
  }