|
@@ -22,6 +22,8 @@ namespace Tensorflow.Keras.Optimizers |
|
|
_set_hyper("decay", decay); |
|
|
_set_hyper("decay", decay); |
|
|
|
|
|
|
|
|
_momentum = momentum > 0; |
|
|
_momentum = momentum > 0; |
|
|
|
|
|
if (momentum < 0 || momentum > 1) |
|
|
|
|
|
throw new ValueError($"momentum must be a number between 0 and 1, got {momentum}."); |
|
|
|
|
|
|
|
|
_set_hyper("momentum", momentum); |
|
|
_set_hyper("momentum", momentum); |
|
|
|
|
|
|
|
@@ -30,6 +32,13 @@ namespace Tensorflow.Keras.Optimizers |
|
|
#pragma warning restore CS1717 // Assignment made to same variable |
|
|
#pragma warning restore CS1717 // Assignment made to same variable |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
/// <summary>
/// Adds a "momentum" slot for every variable in <paramref name="var_list"/>.
/// Nothing is added when the <c>_momentum</c> flag is false.
/// </summary>
/// <param name="var_list">Variables that each receive a "momentum" slot.</param>
protected override void _create_slots(IVariableV1[] var_list)
{
    // Slots are only needed when momentum is enabled.
    if (!_momentum)
    {
        return;
    }

    foreach (var variable in var_list)
    {
        add_slot(variable, "momentum");
    }
}
|
|
|
|
|
|
|
|
protected override void _prepare_local(DeviceDType device_dtype, |
|
|
protected override void _prepare_local(DeviceDType device_dtype, |
|
|
Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state) |
|
|
Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state) |
|
|
{ |
|
|
{ |
|
@@ -43,7 +52,15 @@ namespace Tensorflow.Keras.Optimizers |
|
|
{ |
|
|
{ |
|
|
if (_momentum) |
|
|
if (_momentum) |
|
|
{ |
|
|
{ |
|
|
throw new NotImplementedException("_resource_apply_dense"); |
|
|
|
|
|
|
|
|
var momentum_var = get_slot(var, "momentum"); |
|
|
|
|
|
return gen_training_ops.resource_apply_keras_momentum( |
|
|
|
|
|
var.Handle, |
|
|
|
|
|
momentum_var.Handle, |
|
|
|
|
|
_get_hyper("learning_rate", var.dtype), |
|
|
|
|
|
grad, |
|
|
|
|
|
_get_hyper("momentum", var.dtype), |
|
|
|
|
|
use_locking: _use_locking, |
|
|
|
|
|
use_nesterov: nesterov); |
|
|
} |
|
|
} |
|
|
var device_dtype = _apply_state.Keys.FirstOrDefault(x => x.Device == var.Device && x.DType == var.dtype.as_base_dtype()); |
|
|
var device_dtype = _apply_state.Keys.FirstOrDefault(x => x.Device == var.Device && x.DType == var.dtype.as_base_dtype()); |
|
|
|
|
|
|
|
|