|
|
@@ -107,7 +107,7 @@ namespace Tensorflow |
|
|
|
/// </returns> |
|
|
|
public Operation minimize(Tensor loss, |
|
|
|
IVariableV1 global_step = null, |
|
|
|
List<ResourceVariable> var_list=null, |
|
|
|
List<IVariableV1> var_list=null, |
|
|
|
GateGradientType gate_gradients = GateGradientType.GATE_OP, |
|
|
|
int? aggregation_method=null, |
|
|
|
bool colocate_gradients_with_ops = false, string name=null, Tensor grad_loss=null) |
|
|
@@ -142,17 +142,17 @@ namespace Tensorflow |
|
|
|
/// <returns> |
|
|
|
/// An `Operation` that applies the specified gradients. If `global_step` |
|
|
|
/// was not None, that operation also increments `global_step`.</returns> |
|
|
|
public Operation apply_gradients(Tuple<Tensor, ResourceVariable>[] grads_and_vars, IVariableV1 global_step = null, string name = null) |
|
|
|
public Operation apply_gradients(Tuple<Tensor, IVariableV1>[] grads_and_vars, IVariableV1 global_step = null, string name = null) |
|
|
|
{ |
|
|
|
// No DistributionStrategy case. |
|
|
|
var converted_grads_and_vars = new List<(Tensor, ResourceVariable, _OptimizableVariable)>(); |
|
|
|
var converted_grads_and_vars = new List<(Tensor, IVariableV1, _OptimizableVariable)>(); |
|
|
|
foreach (var (g, v) in grads_and_vars) |
|
|
|
{ |
|
|
|
if(g != null) |
|
|
|
{ |
|
|
|
// Convert the grad to Tensor or IndexedSlices if necessary. |
|
|
|
var gR = ops.convert_to_tensor_or_indexed_slices(g); |
|
|
|
var p = optimizer._get_processor(v); |
|
|
|
var p = optimizer._get_processor(v as ResourceVariable); |
|
|
|
converted_grads_and_vars.Add((gR, v, p)); |
|
|
|
} |
|
|
|
} |
|
|
@@ -230,7 +230,7 @@ namespace Tensorflow |
|
|
|
/// silently ignored). |
|
|
|
/// </summary> |
|
|
|
/// <param name="var_list"></param> |
|
|
|
protected virtual void _create_slots(ResourceVariable[] var_list) |
|
|
|
protected virtual void _create_slots(IVariableV1[] var_list) |
|
|
|
{ |
|
|
|
|
|
|
|
} |
|
|
@@ -369,8 +369,8 @@ namespace Tensorflow |
|
|
|
/// A list of (gradient, variable) pairs. Variable is always present, but |
|
|
|
/// gradient can be `None`. |
|
|
|
/// </returns> |
|
|
|
public Tuple<Tensor, ResourceVariable>[] compute_gradients(Tensor loss, |
|
|
|
List<ResourceVariable> var_list = null, |
|
|
|
public Tuple<Tensor, IVariableV1>[] compute_gradients(Tensor loss, |
|
|
|
List<IVariableV1> var_list = null, |
|
|
|
int? aggregation_method = null, |
|
|
|
GateGradientType gate_gradients = GateGradientType.GATE_OP, |
|
|
|
bool colocate_gradients_with_ops = false, |
|
|
@@ -381,26 +381,13 @@ namespace Tensorflow |
|
|
|
|
|
|
|
if(var_list == null) |
|
|
|
{ |
|
|
|
var vars = ops.get_collection<ResourceVariable>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES); |
|
|
|
var vars = ops.get_collection<IVariableV1>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES); |
|
|
|
var tmp = variables.trainable_variables(); |
|
|
|
switch (tmp) |
|
|
|
{ |
|
|
|
case List<ResourceVariable> values: |
|
|
|
var_list = values.Concat(vars).ToList(); |
|
|
|
break; |
|
|
|
/*case List<RefVariable> values: |
|
|
|
var_list = values.Concat(vars).ToList(); |
|
|
|
break; |
|
|
|
case List<IVariableV1> values: |
|
|
|
var_list = values.Select(x => x as RefVariable).Concat(vars).ToList(); |
|
|
|
break;*/ |
|
|
|
default: |
|
|
|
throw new NotImplementedException(""); |
|
|
|
} |
|
|
|
var_list = (tmp as List<IVariableV1>).Concat(vars).ToList(); |
|
|
|
} |
|
|
|
|
|
|
|
var_list = var_list.Concat(ops.get_collection<ResourceVariable>(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList(); |
|
|
|
var processors = var_list.Select(v => optimizer._get_processor(v)).ToList(); |
|
|
|
var_list = var_list.Concat(ops.get_collection<IVariableV1>(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList(); |
|
|
|
var processors = var_list.Select(v => optimizer._get_processor(v as ResourceVariable)).ToList(); |
|
|
|
var var_refs = processors.Select(x => x.target()).ToArray(); |
|
|
|
|
|
|
|
var grads = gradients_impl.gradients(new Tensor[] { loss }, var_refs, grad_ys: grad_loss == null ? null : new Tensor[] { grad_loss }, |
|
|
@@ -412,7 +399,7 @@ namespace Tensorflow |
|
|
|
grads = control_flow_ops.tuple(grads); |
|
|
|
|
|
|
|
var grads_and_vars = zip(grads, var_list) |
|
|
|
.Select(x => new Tuple<Tensor, ResourceVariable>(x.Item1, x.Item2)) |
|
|
|
.Select(x => new Tuple<Tensor, IVariableV1>(x.Item1, x.Item2)) |
|
|
|
.ToArray(); |
|
|
|
|
|
|
|
return grads_and_vars; |
|
|
|