
Merge pull request #1136 from Beacontownfc/mybranch

Add a new optimizer
tags/v0.110.4-Transformer-Model
Haiping GitHub 2 years ago
commit eac68ff05c
3 changed files with 101 additions and 0 deletions:
  1. src/TensorFlowNET.Core/Keras/IOptimizerApi.cs (+21, -0)
  2. src/TensorFlowNET.Keras/Optimizers/AdamW.cs (+64, -0)
  3. src/TensorFlowNET.Keras/Optimizers/OptimizerApi.cs (+16, -0)

src/TensorFlowNET.Core/Keras/IOptimizerApi.cs (+21, -0)

@@ -25,6 +25,27 @@ namespace Tensorflow.Keras
bool amsgrad = false,
string name = "Adam");

/// <summary>
/// Adam with decoupled weight decay (AdamW): in addition to the Adam update,
/// each step shrinks the variables themselves by learning_rate * weight_decay.
/// </summary>
/// <param name="learning_rate">The learning rate.</param>
/// <param name="weight_decay">The weight decay rate applied to the variables.</param>
/// <param name="beta_1">Exponential decay rate for the first-moment estimates.</param>
/// <param name="beta_2">Exponential decay rate for the second-moment estimates.</param>
/// <param name="epsilon">Small constant for numerical stability.</param>
/// <param name="amsgrad">Whether to apply the AMSGrad variant.</param>
/// <param name="no_decay_params">Substrings of variable names to exclude from weight decay.</param>
/// <param name="name">Name of the optimizer.</param>
/// <returns></returns>
IOptimizer AdamW(float learning_rate = 0.001f,
float weight_decay = 0.004f,
float beta_1 = 0.9f,
float beta_2 = 0.999f,
float epsilon = 1e-7f,
bool amsgrad = false,
List<string> no_decay_params = null,
string name = "AdamW");

/// <summary>
/// Construct a new RMSprop optimizer.
/// </summary>
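
A minimal usage sketch of the new interface method (not part of this PR), assuming the optimizer API is reached through keras.optimizers from Tensorflow.KerasApi, as with the existing Adam factory; the substrings passed to no_decay_params are illustrative:

    using System.Collections.Generic;
    using static Tensorflow.KerasApi;

    // Construct the new AdamW optimizer; variables whose names contain
    // "bias" or "LayerNorm" (illustrative substrings) are excluded from
    // weight decay.
    var optimizer = keras.optimizers.AdamW(
        learning_rate: 0.001f,
        weight_decay: 0.004f,
        no_decay_params: new List<string> { "bias", "LayerNorm" });

The resulting IOptimizer can then be passed wherever an Adam instance was used before, e.g. to model.compile.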


src/TensorFlowNET.Keras/Optimizers/AdamW.cs (+64, -0)

@@ -0,0 +1,64 @@
using System.Collections.Generic;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Optimizers
{
    /// <summary>
    /// Adam with decoupled weight decay: each step additionally shrinks the
    /// variables by learning_rate * weight_decay before the Adam update.
    /// </summary>
    public class AdamW : Adam
    {
        string name;
        float weight_decay;
        DeviceDType deType;
        List<string> no_decay_params = null;

        public AdamW(float learning_rate = 0.001f,
            float weight_decay = 0.004f,
            float beta_1 = 0.9f,
            float beta_2 = 0.999f,
            float epsilon = 1e-7f,
            bool amsgrad = false,
            List<string> no_decay_params = null,
            string name = "AdamW") : base(learning_rate, beta_1, beta_2, epsilon, amsgrad)
        {
            this.name = name;
            this.weight_decay = weight_decay;
            this.no_decay_params = no_decay_params;
        }

        // Shrinks `var` by learning_rate * weight_decay, or returns a no-op
        // when the variable is excluded from weight decay.
        protected Operation _decay_weights_op(IVariableV1 var, float learning_rate, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
        {
            bool do_decay = _do_use_weight_decay(var.Name);
            if (do_decay)
                return var.assign_add(
                    -learning_rate * var.AsTensor() * apply_state[deType]["weight_decay"]);
            return tf.no_op();
        }

        // Whether to apply weight decay to `param_name`: skipped when the decay
        // rate is zero or the name contains any entry of no_decay_params.
        protected bool _do_use_weight_decay(string param_name)
        {
            if (this.weight_decay == 0)
                return false;

            if (this.no_decay_params != null)
            {
                foreach (var name in no_decay_params)
                {
                    if (param_name.Contains(name)) return false;
                }
            }
            return true;
        }

        protected override Operation _resource_apply_dense(IVariableV1 var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
        {
            // Apply the weight decay first, then delegate to the regular Adam update.
            var decay = _decay_weights_op(var, _hyper["learning_rate"], apply_state);
            tf.control_dependencies(new[] { decay });
            return base._resource_apply_dense(var, grad, apply_state);
        }

        protected override void _prepare_local(DeviceDType device_dtype, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
        {
            this.deType = device_dtype;
            base._prepare_local(device_dtype, apply_state);
            // Cache the weight decay rate as a tensor for this device/dtype.
            apply_state[device_dtype]["weight_decay"] = tf.constant(
                weight_decay, name: "adam_weight_decay_rate");
        }
    }
}
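
To make the decoupled update concrete, here is a self-contained arithmetic sketch (plain C#, no TensorFlow) of what _decay_weights_op contributes for a single scalar variable, using the default hyperparameters from this diff:

    // Decoupled weight decay in isolation: the variable itself is shrunk by
    // learning_rate * weight_decay each step, independently of the
    // gradient-based Adam update that follows.
    float theta = 1.0f;
    float learning_rate = 0.001f, weight_decay = 0.004f;
    theta += -learning_rate * theta * weight_decay;  // mirrors var.assign_add(-lr * var * wd)
    // theta is now ~0.999996: a shrink of lr * wd = 4e-6 per step.
    // Variables whose names contain any entry of no_decay_params keep their
    // value: e.g. with no_decay_params = ["bias"], "dense/bias" is skipped.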

src/TensorFlowNET.Keras/Optimizers/OptimizerApi.cs (+16, -0)

@@ -29,6 +29,22 @@ namespace Tensorflow.Keras.Optimizers
amsgrad: amsgrad,
name: name);

public IOptimizer AdamW(float learning_rate = 0.001f,
float weight_decay = 0.004f,
float beta_1 = 0.9f,
float beta_2 = 0.999f,
float epsilon = 1e-7f,
bool amsgrad = false,
List<string> no_decay_params = null,
string name = "AdamW") => new AdamW(learning_rate: learning_rate,
beta_1: beta_1,
beta_2: beta_2,
epsilon: epsilon,
amsgrad: amsgrad,
name: name,
weight_decay: weight_decay,
no_decay_params: no_decay_params);

/// <summary>
/// Construct a new RMSprop optimizer.
/// </summary>
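
Since the factory above only forwards its arguments to the constructor added in AdamW.cs, the two call sites below are equivalent ways to obtain the optimizer (a sketch using only names from this diff):

    using Tensorflow.Keras.Optimizers;
    using static Tensorflow.KerasApi;

    // Through the Keras optimizer API (the factory added in OptimizerApi.cs) ...
    var viaApi = keras.optimizers.AdamW(weight_decay: 0.004f);
    // ... or by constructing the class added in AdamW.cs directly.
    var direct = new AdamW(weight_decay: 0.004f);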

