using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using distribute_lib = Tensorflow.Distribute;
namespace Tensorflow
{
    /// <summary>
    /// Base class for optimizers.
    /// This class defines the API to add Ops to train a model. You never use this
    /// class directly, but instead instantiate one of its subclasses such as
    /// `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.
    /// </summary>
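    /// <example>
    /// A minimal usage sketch, assuming a `GradientDescentOptimizer` subclass with a
    /// single learning-rate constructor is available elsewhere in the library:
    /// <code>
    /// var optimizer = new GradientDescentOptimizer(0.01);
    /// // assumes `loss` is a scalar Tensor built elsewhere
    /// optimizer.minimize(loss); // builds gradient ops for all trainable variables
    /// </code>
    /// </example>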
    public abstract class Optimizer
    {
        public string Name { get; set; }
        public double LearningRate { get; set; }
        public Tensor LearningRateTensor { get; set; }
        public bool _use_locking;
        public Dictionary<string, object> _slots;
        public Dictionary<string, object> _non_slot_dict;
        public Dictionary<string, object> _deferred_slot_restorations;

        public Optimizer(double learning_rate, bool use_locking, string name = "")
        {
            if (string.IsNullOrEmpty(name))
                throw new ArgumentException("Must specify the optimizer name");

            Name = name;
            LearningRate = learning_rate;
            _use_locking = use_locking;
            // Dictionary of slots.
            _slots = new Dictionary<string, object>();
            _non_slot_dict = new Dictionary<string, object>();
            _deferred_slot_restorations = new Dictionary<string, object>();
        }
        /// <summary>
        /// Add operations to minimize `loss` by updating `var_list`.
        /// </summary>
        /// <param name="loss">A `Tensor` containing the value to minimize.</param>
        /// <param name="gate_gradients">How to gate the computation of gradients.</param>
        /// <param name="colocate_gradients_with_ops">If true, try colocating gradients with the corresponding op.</param>
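        /// <example>
        /// A hedged sketch of the optional arguments (assumes `loss` is a scalar
        /// `Tensor` built elsewhere):
        /// <code>
        /// optimizer.minimize(loss,
        ///     gate_gradients: GateGradientType.GATE_OP,
        ///     colocate_gradients_with_ops: true);
        /// </code>
        /// </example>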
        public Optimizer minimize(Tensor loss,
            GateGradientType gate_gradients = GateGradientType.GATE_OP,
            bool colocate_gradients_with_ops = false)
        {
            // TODO: feed the (gradient, variable) pairs to an apply_gradients step
            // once it is implemented; for now only the gradient ops are built.
            var grads_and_vars = compute_gradients(loss,
                gate_gradients: gate_gradients,
                colocate_gradients_with_ops: colocate_gradients_with_ops);

            return this;
        }
        /// <summary>
        /// Compute gradients of `loss` for the variables in `var_list`.
        /// </summary>
        /// <param name="loss">A `Tensor` containing the value to minimize.</param>
        /// <param name="var_list">Optional list of variables to update to minimize `loss`;
        /// defaults to the list of trainable variables.</param>
        /// <returns>A list of (gradient, variable) pairs.</returns>
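        /// <example>
        /// A hedged sketch of consuming the returned pairs (assumes each pair holds
        /// the gradient as `Key`, the `RefVariable` as `Value`, and that
        /// `RefVariable` exposes a `Name` property):
        /// <code>
        /// foreach (var pair in optimizer.compute_gradients(loss))
        ///     Console.WriteLine($"{pair.Value.Name}: {pair.Key}");
        /// </code>
        /// </example>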
        public List<KeyValuePair<Tensor, RefVariable>> compute_gradients(Tensor loss,
            List<RefVariable> var_list = null,
            int? aggregation_method = null,
            GateGradientType gate_gradients = GateGradientType.GATE_OP,
            bool colocate_gradients_with_ops = false,
            List<Tensor> grad_loss = null)
        {
            int num_towers = 1;
            if (distribute_lib.get_loss_reduction() == VariableAggregationType.MEAN)
            {
                // TODO: set num_towers from the active distribution strategy and
                // scale grad_loss by 1 / num_towers, as the Python implementation does.
            }

            // Default to the trainable variables when the caller passes no var_list.
            if (var_list == null)
            {
                var tmp = variables.trainable_variables();
                switch (tmp)
                {
                    case List<RefVariable> values:
                        var_list = values;
                        break;
                }
            }

            var processors = var_list.Select(v => optimizer._get_processor(v)).ToList();
            var var_refs = processors.Select(x => x.target()).ToList();
            var grads = gradients_impl.gradients(loss, var_refs, grad_ys: grad_loss,
                gate_gradients: (gate_gradients == GateGradientType.GATE_OP),
                aggregation_method: aggregation_method,
                colocate_gradients_with_ops: colocate_gradients_with_ops);

            // Pair each computed gradient with the variable it belongs to.
            return grads.Zip(var_list, (g, v) => new KeyValuePair<Tensor, RefVariable>(g, v))
                .ToList();
        }
    }
}