fix: add the momentum parameter's implementation of SGD
@@ -63,6 +63,6 @@ namespace Tensorflow.Keras
             bool centered = false,
             string name = "RMSprop");
 
-        IOptimizer SGD(float learning_rate);
+        IOptimizer SGD(float learning_rate, float momentum);
     }
 }
@@ -51,5 +51,9 @@ namespace Tensorflow
         public static Tensor resource_apply_gradient_descent(Tensor var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null)
             => tf.Context.ExecuteOp("ResourceApplyGradientDescent", name,
                 new ExecuteOpArgs(var, alpha, delta).SetAttributes(new { use_locking }));
+
+        public static Tensor resource_apply_keras_momentum(Tensor var, Tensor accum, Tensor lr, Tensor grad, Tensor momentum, bool use_locking = false, bool use_nesterov = false, string name = null)
+            => tf.Context.ExecuteOp("ResourceApplyKerasMomentum", name,
+                new ExecuteOpArgs(var, accum, lr, grad, momentum).SetAttributes(new { use_locking, use_nesterov }));
     }
 }
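For reference, the ResourceApplyKerasMomentum kernel bound above performs the Keras-style momentum update described in the TensorFlow op documentation. A minimal C# sketch of the scalar arithmetic, for illustration only; the real kernel mutates the resource variables in place on device:

    // accum <- accum * momentum - lr * grad
    // var   <- var + (use_nesterov ? accum * momentum - lr * grad : accum)
    static float Step(ref float var_, ref float accum, float lr, float grad,
                      float momentum, bool use_nesterov)
    {
        accum = accum * momentum - lr * grad;
        var_ += use_nesterov ? accum * momentum - lr * grad : accum;
        return var_;
    }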
@@ -71,7 +71,7 @@ namespace Tensorflow.Keras.Optimizers
                 Name = name
             });
 
-        public IOptimizer SGD(float learning_rate)
-            => new SGD(learning_rate);
+        public IOptimizer SGD(float learning_rate, float momentum)
+            => new SGD(learning_rate, momentum);
     }
 }
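Note that this replaces the single-parameter overload rather than adding a second one, so existing SGD(learning_rate) call sites must now pass a momentum value; 0f reproduces the old plain-gradient-descent behavior. A hedged usage sketch (entry-point names follow the usual TensorFlow.NET conventions and may differ by version):

    // momentum: 0f -> plain SGD; 0.9f -> heavy-ball momentum
    var optimizer = keras.optimizers.SGD(learning_rate: 0.01f, momentum: 0.9f);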
@@ -22,6 +22,8 @@ namespace Tensorflow.Keras.Optimizers | |||||
_set_hyper("decay", decay); | _set_hyper("decay", decay); | ||||
_momentum = momentum > 0; | _momentum = momentum > 0; | ||||
if (momentum < 0 || momentum > 1) | |||||
throw new ValueError($"momentum must be a number between 0 and 1, got {momentum}."); | |||||
_set_hyper("momentum", momentum); | _set_hyper("momentum", momentum); | ||||
@@ -30,6 +32,13 @@ namespace Tensorflow.Keras.Optimizers | |||||
#pragma warning restore CS1717 // Assignment made to same variable | #pragma warning restore CS1717 // Assignment made to same variable | ||||
} | } | ||||
protected override void _create_slots(IVariableV1[] var_list) | |||||
{ | |||||
if (_momentum) | |||||
foreach (var var in var_list) | |||||
add_slot(var, "momentum"); | |||||
} | |||||
protected override void _prepare_local(DeviceDType device_dtype, | protected override void _prepare_local(DeviceDType device_dtype, | ||||
Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state) | Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state) | ||||
{ | { | ||||
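_create_slots pairs every trainable variable with a zero-initialized "momentum" accumulator of the same shape, which get_slot retrieves in the dense-apply path below. A stripped-down sketch of the slot pattern in plain C# (hypothetical names, illustration only, not the library code):

    using System.Collections.Generic;

    class SlotDemo
    {
        // slot key: "<variable name>/<slot name>"
        readonly Dictionary<string, float[]> _slots = new();

        public void AddSlot(string varName, int size, string slotName = "momentum")
            => _slots[$"{varName}/{slotName}"] = new float[size]; // zero-initialized

        public float[] GetSlot(string varName, string slotName = "momentum")
            => _slots[$"{varName}/{slotName}"];
    }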
@@ -43,7 +52,15 @@ namespace Tensorflow.Keras.Optimizers
         {
             if (_momentum)
             {
-                throw new NotImplementedException("_resource_apply_dense");
+                var momentum_var = get_slot(var, "momentum");
+                return gen_training_ops.resource_apply_keras_momentum(
+                    var.Handle,
+                    momentum_var.Handle,
+                    _get_hyper("learning_rate", var.dtype),
+                    grad,
+                    _get_hyper("momentum", var.dtype),
+                    use_locking: _use_locking,
+                    use_nesterov: nesterov);
             }
             var device_dtype = _apply_state.Keys.FirstOrDefault(x => x.Device == var.Device && x.DType == var.dtype.as_base_dtype());
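With the slot and hyperparameters wired into resource_apply_keras_momentum (nesterov being the flag the optimizer already stores), the update is easy to trace numerically. Two steps in plain C# with lr = 0.1, momentum = 0.9 and a constant gradient of 1.0, illustration only:

    float var_ = 1.0f, accum = 0f;
    const float lr = 0.1f, momentum = 0.9f, grad = 1.0f;
    for (int i = 0; i < 2; i++)
    {
        accum = accum * momentum - lr * grad; // -0.1, then -0.19
        var_ += accum;                        // 0.9,  then 0.71
    }
    // The first step equals plain SGD (accum starts at zero); later steps
    // accelerate along directions where the gradient sign is consistent.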