using System.Collections.Generic;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Optimizers
{
    // Adam optimizer with decoupled weight decay (AdamW).
    public class AdamW : Adam
    {
        string name;
        float weight_decay;
        DeviceDType deType;
        List<string> no_decay_params = null;

        public AdamW(float learning_rate = 0.001f,
            float weight_decay = 0.004f,
            float beta_1 = 0.9f,
            float beta_2 = 0.999f,
            float epsilon = 1e-7f,
            bool amsgrad = false,
            List<string> no_decay_params = null,
            string name = "AdamW") : base(learning_rate, beta_1, beta_2, epsilon, amsgrad)
        {
            this.name = name;
            this.weight_decay = weight_decay;
            this.no_decay_params = no_decay_params;
        }

        // Decoupled weight decay step: var -= learning_rate * weight_decay * var.
        protected Operation _decay_weights_op(IVariableV1 var, float learning_rate,
            Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
        {
            bool do_decay = _do_use_weight_decay(var.Name);
            if (do_decay)
                return var.assign_add(
                    -learning_rate * var.AsTensor() * apply_state[deType]["weight_decay"]);
            return tf.no_op();
        }

        // Whether to use L2 weight decay for `param_name`.
        protected bool _do_use_weight_decay(string param_name)
        {
            if (this.weight_decay == 0)
                return false;

            if (this.no_decay_params != null)
            {
                foreach (var name in no_decay_params)
                {
                    if (param_name.Contains(name))
                        return false;
                }
            }
            return true;
        }

        protected override Operation _resource_apply_dense(IVariableV1 var, Tensor grad,
            Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
        {
            // Decay the weights, then run the regular Adam dense update.
            var decay = _decay_weights_op(var, _hyper["learning_rate"], apply_state);
            tf.control_dependencies(new[] { decay });
            return base._resource_apply_dense(var, grad, apply_state);
        }

        protected override void _prepare_local(DeviceDType device_dtype,
            Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
        {
            this.deType = device_dtype;
            base._prepare_local(device_dtype, apply_state);
            // Cache the decay rate as a constant so the dense update can reuse it.
            apply_state[device_dtype]["weight_decay"] = tf.constant(
                weight_decay, name: "adam_weight_decay_rate");
        }
    }
}
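// Usage sketch (an assumption, not part of the original file): the optimizer is
// constructed like any other Keras optimizer and can then be passed to
// model.compile. Entries in `no_decay_params` are matched as substrings of
// variable names, so "bias" below would exclude all bias variables from decay;
// the variable-name patterns shown are hypothetical.
//
//     var optimizer = new AdamW(
//         learning_rate: 0.001f,
//         weight_decay: 0.004f,
//         no_decay_params: new List<string> { "bias" });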