using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;

namespace Tensorflow.Keras.Optimizers
{
    /// <summary>
    /// Stochastic gradient descent optimizer with optional (Nesterov) momentum.
    /// </summary>
    /// <remarks>
    /// When momentum is 0, each variable is updated with
    /// <c>gen_training_ops.resource_apply_gradient_descent</c>. When momentum is greater than 0,
    /// a per-variable "momentum" slot is created and updates go through
    /// <c>gen_training_ops.resource_apply_keras_momentum</c>.
    /// </remarks>
    public class SGD : OptimizerV2
    {
        protected override string _name => "SGD";

        // Forwarded as `use_nesterov` to the momentum update; only consulted when momentum > 0.
        private readonly bool nesterov;

        /// <summary>Creates an SGD optimizer.</summary>
        /// <param name="learning_rate">Stored as the "learning_rate" hyperparameter.</param>
        /// <param name="momentum">
        /// Momentum factor in [0, 1]. Any value greater than 0 enables the "momentum" slot
        /// and the momentum update path.
        /// </param>
        /// <param name="nesterov">Whether the momentum update uses the Nesterov variant.</param>
        /// <param name="decay">Stored as the "decay" hyperparameter.</param>
        /// <exception cref="ValueError">
        /// Thrown when <paramref name="momentum"/> is outside [0, 1] or is NaN.
        /// </exception>
        public SGD(float learning_rate, float momentum = 0.0f, bool nesterov = false, float decay = 0.0f)
            : base(new OptimizerV2Args { })
        {
            // Negated in-range test so that NaN (for which every comparison is false) is rejected too.
            if (!(momentum >= 0 && momentum <= 1))
            {
                throw new ValueError($"momentum must be a number between 0 and 1, got {momentum}.");
            }

            _set_hyper("learning_rate", learning_rate);
            _set_hyper("decay", decay);
            _momentum = momentum > 0;
            _set_hyper("momentum", momentum);

            // `this.` is required: a bare `nesterov = nesterov` would assign the parameter to itself
            // and leave the field permanently false.
            this.nesterov = nesterov;
        }

        /// <summary>Adds a "momentum" slot to every variable when momentum is enabled.</summary>
        /// <param name="var_list">Variables the optimizer will update.</param>
        protected override void _create_slots(IVariableV1[] var_list)
        {
            if (!_momentum)
            {
                return;
            }

            foreach (var var in var_list)
            {
                add_slot(var, "momentum");
            }
        }

        /// <summary>
        /// Runs the base preparation for <paramref name="device_dtype"/>, then stores the
        /// "momentum" hyperparameter (cast to that dtype) in the per-device/dtype state.
        /// </summary>
        /// <param name="device_dtype">Device and dtype whose coefficients are being prepared.</param>
        /// <param name="_apply_state">Per-device/dtype coefficient tensors, keyed by name.</param>
        protected override void _prepare_local(DeviceDType device_dtype,
            Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state)
        {
            base._prepare_local(device_dtype, _apply_state);
            _apply_state[device_dtype]["momentum"] = array_ops.identity(
                _get_hyper("momentum", device_dtype.DType));
        }

        /// <summary>Applies a dense gradient update to <paramref name="var"/>.</summary>
        /// <param name="var">Variable to update.</param>
        /// <param name="grad">Gradient for <paramref name="var"/>.</param>
        /// <param name="_apply_state">
        /// Coefficients prepared by <see cref="_prepare_local"/>, keyed by device and base dtype.
        /// </param>
        /// <returns>The update operation.</returns>
        protected override Operation _resource_apply_dense(IVariableV1 var, Tensor grad,
            Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state)
        {
            // Find the prepared coefficients matching this variable's device and base dtype.
            var device_dtype = _apply_state.Keys.FirstOrDefault(x =>
                x.Device == var.Device && x.DType == var.dtype.as_base_dtype());
            var coefficients = _apply_state[device_dtype];

            if (_momentum)
            {
                var momentum_var = get_slot(var, "momentum");
                // Read "lr_t" and "momentum" from the prepared state so the momentum path uses the
                // same learning-rate tensor as the plain gradient-descent path below.
                return gen_training_ops.resource_apply_keras_momentum(
                    var.Handle,
                    momentum_var.Handle,
                    coefficients["lr_t"],
                    grad,
                    coefficients["momentum"],
                    use_locking: _use_locking,
                    use_nesterov: nesterov);
            }

            return gen_training_ops.resource_apply_gradient_descent(
                var.Handle, coefficients["lr_t"], grad, use_locking: _use_locking);
        }
    }
}