From 6862d3a0432b0623bba23e51c14d42ac1974e22f Mon Sep 17 00:00:00 2001
From: Beacontownfc <19636977267@qq.com>
Date: Fri, 7 Jul 2023 00:25:38 +0000
Subject: [PATCH 1/2] Add AdamW optimizer
---
src/TensorFlowNET.Core/Keras/IOptimizerApi.cs | 21 ++++++
src/TensorFlowNET.Keras/Optimizers/AdamW.cs | 67 +++++++++++++++++++
.../Optimizers/OptimizerApi.cs | 16 +++++
3 files changed, 104 insertions(+)
create mode 100644 src/TensorFlowNET.Keras/Optimizers/AdamW.cs
diff --git a/src/TensorFlowNET.Core/Keras/IOptimizerApi.cs b/src/TensorFlowNET.Core/Keras/IOptimizerApi.cs
index 961ce91a..d0d3a74f 100644
--- a/src/TensorFlowNET.Core/Keras/IOptimizerApi.cs
+++ b/src/TensorFlowNET.Core/Keras/IOptimizerApi.cs
@@ -25,6 +25,27 @@ namespace Tensorflow.Keras
bool amsgrad = false,
string name = "Adam");
+ /// <summary>
+ /// Adam enables L2 weight decay on gradients.
+ /// </summary>
+ /// <param name="learning_rate"></param>
+ /// <param name="weight_decay"></param>
+ /// <param name="beta_1"></param>
+ /// <param name="beta_2"></param>
+ /// <param name="epsilon"></param>
+ /// <param name="amsgrad"></param>
+ /// <param name="no_decay_params"></param>
+ /// <param name="name"></param>
+ /// <returns></returns>
+ IOptimizer AdamW(float learning_rate = 0.001f,
+ float weight_decay = 0.004f,
+ float beta_1 = 0.9f,
+ float beta_2 = 0.999f,
+ float epsilon = 1e-7f,
+ bool amsgrad = false,
+ List<string> no_decay_params = null,
+ string name = "AdamW");
+
/// <summary>
/// Construct a new RMSprop optimizer.
/// </summary>
diff --git a/src/TensorFlowNET.Keras/Optimizers/AdamW.cs b/src/TensorFlowNET.Keras/Optimizers/AdamW.cs
new file mode 100644
index 00000000..469b8ad2
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Optimizers/AdamW.cs
@@ -0,0 +1,67 @@
+namespace Tensorflow.Keras.Optimizers
+{
+ public class AdamW : Adam
+ {
+ string name;
+ float weight_decay;
+ DeviceDType deType;
+ List<string> no_decay_params = null;
+ public AdamW(float learning_rate = 0.001f,
+ float weight_decay = 0.004f,
+ float beta_1 = 0.9f,
+ float beta_2 = 0.999f,
+ float epsilon = 1e-7f,
+ bool amsgrad = false,
+ List<string> no_decay_params = null,
+ string name = "AdamW") : base(learning_rate, beta_1, beta_2, epsilon, amsgrad)
+ {
+ this.name = name;
+ this.weight_decay = weight_decay;
+ this.no_decay_params = no_decay_params;
+ }
+
+ protected Operation _decay_weights_op(IVariableV1 var, float learning_rate, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
+ {
+ var device_dtype = new DeviceDType();
+ device_dtype.DType = var.dtype;
+ device_dtype.Device = var.Device;
+ bool do_decay = _do_use_weight_decay(var.Name);
+ if (do_decay) return var.assign_add(
+ -learning_rate * var.AsTensor() * apply_state[deType]["weight_decay"]);
+ return tf.no_op();
+ }
+
+
+ protected bool _do_use_weight_decay(string param_name)
+ {
+ // Whether to use L2 weight decay for `param_name`.
+ if (this.weight_decay == 0)
+ return false;
+
+ if (this.no_decay_params != null)
+ {
+ foreach (var name in no_decay_params)
+ {
+ if (param_name.Contains(name)) return false;
+ }
+
+ }
+ return true;
+ }
+
+ protected override Operation _resource_apply_dense(IVariableV1 var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
+ {
+ var decay = _decay_weights_op(var, _hyper["learning_rate"], apply_state);
+ tf.control_dependencies(new[] { decay });
+ return base._resource_apply_dense(var, grad, apply_state);
+ }
+
+ protected override void _prepare_local(DeviceDType device_dtype, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
+ {
+ this.deType = device_dtype;
+ base._prepare_local(device_dtype, apply_state);
+ apply_state[device_dtype]["weight_decay"] = tf.constant(
+ weight_decay, name: "adam_weight_decay_rate");
+ }
+ }
+}
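A quick note on what the class above computes: `_decay_weights_op` subtracts `learning_rate * weight_decay * w` from a decayed variable, and `_resource_apply_dense` runs that step before the inherited Adam update, so the decay acts on the weights directly rather than being folded into the gradient. The standalone sketch below reproduces just that arithmetic on a plain array; it uses no TensorFlow.NET types, and the names and numbers are illustrative only.

using System;

class DecayStepSketch
{
    // Decoupled weight decay, as in _decay_weights_op: w <- w - lr * wd * w.
    // The gradient-based Adam update is applied separately afterwards.
    static void DecayWeights(float[] w, float learningRate, float weightDecay)
    {
        for (int i = 0; i < w.Length; i++)
            w[i] -= learningRate * weightDecay * w[i];
    }

    static void Main()
    {
        var kernel = new[] { 1.0f, -2.0f };
        DecayWeights(kernel, learningRate: 0.001f, weightDecay: 0.004f);
        // Prints 0.999996, -1.999992: each weight shrinks toward zero by a
        // factor of (1 - lr * wd), independent of the gradient.
        Console.WriteLine(string.Join(", ", kernel));
    }
}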
diff --git a/src/TensorFlowNET.Keras/Optimizers/OptimizerApi.cs b/src/TensorFlowNET.Keras/Optimizers/OptimizerApi.cs
index 31eb88be..28069426 100644
--- a/src/TensorFlowNET.Keras/Optimizers/OptimizerApi.cs
+++ b/src/TensorFlowNET.Keras/Optimizers/OptimizerApi.cs
@@ -29,6 +29,22 @@ namespace Tensorflow.Keras.Optimizers
amsgrad: amsgrad,
name: name);
+ public IOptimizer AdamW(float learning_rate = 0.001f,
+ float weight_decay = 0.004f,
+ float beta_1 = 0.9f,
+ float beta_2 = 0.999f,
+ float epsilon = 1e-7f,
+ bool amsgrad = false,
+ List<string> no_decay_params = null,
+ string name = "AdamW") => new AdamW(learning_rate: learning_rate,
+ beta_1: beta_1,
+ beta_2: beta_2,
+ epsilon: epsilon,
+ amsgrad: amsgrad,
+ name: name,
+ weight_decay: weight_decay,
+ no_decay_params: no_decay_params);
+
/// <summary>
/// Construct a new RMSprop optimizer.
/// </summary>
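For context, this is roughly how user code would reach the new factory method once the patch is applied. Only the `keras.optimizers.AdamW(...)` call itself comes from the diffs above; the `using static Tensorflow.KerasApi` pattern is the library's usual entry point, and the `model.compile` hand-off is left as a comment because model construction is outside this PR.

using System.Collections.Generic;
using static Tensorflow.KerasApi;

class AdamWUsageSketch
{
    static void Main()
    {
        // Parameters mirror the defaults declared in IOptimizerApi.AdamW above.
        var optimizer = keras.optimizers.AdamW(
            learning_rate: 0.001f,
            weight_decay: 0.004f,
            no_decay_params: new List<string> { "bias" }); // entries are matched as substrings of variable names

        // Typical hand-off (model construction elided):
        // model.compile(optimizer: optimizer, loss: ..., metrics: ...);
    }
}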
From cc6ddc144fa85010b111df2b4c596c7230052080 Mon Sep 17 00:00:00 2001
From: Beacontownfc <19636977267@qq.com>
Date: Fri, 7 Jul 2023 00:33:41 +0000
Subject: [PATCH 2/2] Add AdamW optimizer
---
src/TensorFlowNET.Keras/Optimizers/AdamW.cs | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/src/TensorFlowNET.Keras/Optimizers/AdamW.cs b/src/TensorFlowNET.Keras/Optimizers/AdamW.cs
index 469b8ad2..d111b5d3 100644
--- a/src/TensorFlowNET.Keras/Optimizers/AdamW.cs
+++ b/src/TensorFlowNET.Keras/Optimizers/AdamW.cs
@@ -1,4 +1,4 @@
-namespace Tensorflow.Keras.Optimizers
+namespace Tensorflow.Keras.Optimizers
{
public class AdamW : Adam
{
@@ -22,9 +22,6 @@
protected Operation _decay_weights_op(IVariableV1 var, float learning_rate, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
{
- var device_dtype = new DeviceDType();
- device_dtype.DType = var.dtype;
- device_dtype.Device = var.Device;
bool do_decay = _do_use_weight_decay(var.Name);
if (do_decay) return var.assign_add(
-learning_rate * var.AsTensor() * apply_state[deType]["weight_decay"]);
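One behavioral detail of `_do_use_weight_decay`, untouched by the second patch: entries in `no_decay_params` are matched as substrings of the variable name via `Contains`, not as exact names, so an entry like "layer_norm" exempts every variable whose name contains it. The snippet below isolates that matching rule (assuming a non-zero `weight_decay`); the variable names are made up for the example.

using System;
using System.Collections.Generic;
using System.Linq;

class NoDecayMatchingSketch
{
    // Same predicate as _do_use_weight_decay for a non-zero weight_decay:
    // true means the variable is decayed, false means it is exempt.
    static bool Decays(string varName, List<string> noDecayParams) =>
        noDecayParams == null || !noDecayParams.Any(n => varName.Contains(n));

    static void Main()
    {
        var noDecay = new List<string> { "bias", "layer_norm" };
        Console.WriteLine(Decays("encoder/layer_norm/gamma", noDecay)); // False: substring match exempts it
        Console.WriteLine(Decays("dense/kernel", noDecay));             // True: decayed
        Console.WriteLine(Decays("dense/bias", noDecay));               // False
    }
}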