|
|
@@ -1,4 +1,4 @@ |
|
|
|
namespace Tensorflow.Keras.Optimizers |
|
|
|
namespace Tensorflow.Keras.Optimizers |
|
|
|
{ |
|
|
|
public class AdamW : Adam |
|
|
|
{ |
|
|
@@ -22,9 +22,6 @@ |
|
|
|
|
|
|
|
protected Operation _decay_weights_op(IVariableV1 var, float learning_rate, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state) |
|
|
|
{ |
|
|
|
var device_dtype = new DeviceDType(); |
|
|
|
device_dtype.DType = var.dtype; |
|
|
|
device_dtype.Device = var.Device; |
|
|
|
bool do_decay = _do_use_weight_decay(var.Name); |
|
|
|
if (do_decay) return var.assign_add( |
|
|
|
-learning_rate * var.AsTensor() * apply_state[deType]["weight_decay"]); |
|
|
|