using System;
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
namespace Tensorflow.Keras.Optimizers
{
/// <summary>
/// Optimizer that implements the RMSprop algorithm.
/// </summary>
/// <remarks>
/// The dense update in this class only implements the plain variant; the
/// momentum and centered variants throw <see cref="NotImplementedException"/>.
/// </remarks>
public class RMSprop : OptimizerV2
{
    private readonly RMSpropArgs args;

    // Mirrors the "centered" option of the supplied arguments.
    private bool centered => args.Centered;

    protected override string _name => "RMSprop";

    /// <summary>
    /// Creates the optimizer and registers the "rho" and "momentum" hyper-parameters.
    /// </summary>
    /// <param name="args">Optimizer configuration; RHO, Momentum, Centered and Epsilon are read by this class.</param>
    public RMSprop(RMSpropArgs args) : base(args)
    {
        this.args = args;
        _set_hyper("rho", args.RHO);
        _set_hyper("momentum", args.Momentum);
        // NOTE(review): `_momentum` is declared outside this class and is not
        // derived from args.Momentum here — confirm the base class sets it.
    }

    /// <summary>
    /// Adds the per-variable slots used by the update: "rms" always,
    /// "momentum" when <c>_momentum</c> is set and "mg" when centered.
    /// </summary>
    /// <param name="var_list">Variables for which slots are created.</param>
    protected override void _create_slots(IVariableV1[] var_list)
    {
        foreach (var var in var_list)
        {
            add_slot(var, "rms");
        }

        if (_momentum)
        {
            foreach (var var in var_list)
            {
                add_slot(var, "momentum");
            }
        }

        if (centered)
        {
            foreach (var var in var_list)
            {
                add_slot(var, "mg");
            }
        }
    }

    /// <summary>
    /// Extends the coefficient table for <paramref name="device_dtype"/> with the
    /// tensors used by the update: "neg_lr_t", "epsilon", "rho", "momentum" and
    /// "one_minus_rho".
    /// </summary>
    /// <param name="device_dtype">Device / dtype pair whose coefficients are prepared.</param>
    /// <param name="_apply_state">Coefficient tables keyed by device / dtype pair.</param>
    protected override void _prepare_local(DeviceDType device_dtype, Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state)
    {
        // The code below expects the base call to have created the entry for
        // device_dtype together with its "lr_t" value.
        base._prepare_local(device_dtype, _apply_state);

        var rho = array_ops.identity(_get_hyper("rho", device_dtype.DType));
        _apply_state[device_dtype]["neg_lr_t"] = -_apply_state[device_dtype]["lr_t"];
        _apply_state[device_dtype]["epsilon"] = ops.convert_to_tensor(args.Epsilon, dtype: device_dtype.DType);
        _apply_state[device_dtype]["rho"] = rho;
        _apply_state[device_dtype]["momentum"] = array_ops.identity(_get_hyper("momentum", device_dtype.DType));
        _apply_state[device_dtype]["one_minus_rho"] = 1.0f - rho;
    }

    /// <summary>
    /// Applies one RMSprop step to <paramref name="var"/>:
    /// <c>rms_t = rho * rms + (1 - rho) * grad^2</c> and
    /// <c>var_t = var - lr_t * grad / (sqrt(rms_t) + epsilon)</c>.
    /// </summary>
    /// <param name="var">Variable to update.</param>
    /// <param name="grad">Gradient with respect to <paramref name="var"/>.</param>
    /// <param name="_apply_state">Coefficient tables keyed by device / dtype pair.</param>
    /// <returns>The operation that assigns the updated value to <paramref name="var"/>.</returns>
    /// <exception cref="NotImplementedException">The momentum or centered variant is requested.</exception>
    /// <exception cref="InvalidOperationException">No coefficient table matches the variable's device and base dtype.</exception>
    protected override Operation _resource_apply_dense(IVariableV1 var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state)
    {
        // Reject unsupported variants before touching any state, so a failed
        // call never leaves the "rms" slot partially updated.
        if (_momentum)
        {
            throw new NotImplementedException("RMSprop with momentum is not implemented.");
        }

        if (centered)
        {
            throw new NotImplementedException("Centered RMSprop is not implemented.");
        }

        var var_dtype = var.dtype.as_base_dtype();
        var var_device = var.Device;

        Dictionary<string, Tensor> coefficients = null;
        foreach (var state in _apply_state)
        {
            if (state.Key.DType == var_dtype && state.Key.Device == var_device)
            {
                coefficients = state.Value;
                break;
            }
        }

        if (coefficients == null)
        {
            throw new InvalidOperationException(
                $"No apply state found for device '{var_device}' and dtype '{var_dtype}'.");
        }

        var rms = get_slot(var, "rms");

        // Update the accumulator of squared gradients and store it in the slot.
        var rms_t = coefficients["rho"] * rms.AsTensor() + coefficients["one_minus_rho"] * math_ops.square(grad);
        rms_t = state_ops.assign(rms, rms_t, use_locking: _use_locking);

        // Scale the step by the root of the accumulator; epsilon keeps the divisor non-zero.
        var var_t = var.AsTensor() - coefficients["lr_t"] * grad / (math_ops.sqrt(rms_t) + coefficients["epsilon"]);
        return state_ops.assign(var, var_t, use_locking: _use_locking).op;
    }
}
}