|
|
@@ -2,7 +2,6 @@ |
|
|
|
using System.Collections.Generic; |
|
|
|
using System.Linq; |
|
|
|
using System.Text; |
|
|
|
using distribute_lib = Tensorflow.Distribute; |
|
|
|
using static Tensorflow.Python; |
|
|
|
|
|
|
|
namespace Tensorflow |
|
|
@@ -82,7 +81,8 @@ namespace Tensorflow |
|
|
|
var grads_and_vars = compute_gradients(loss, var_list:var_list, |
|
|
|
gate_gradients: gate_gradients, |
|
|
|
aggregation_method:aggregation_method, |
|
|
|
colocate_gradients_with_ops: colocate_gradients_with_ops); |
|
|
|
colocate_gradients_with_ops: colocate_gradients_with_ops, |
|
|
|
grad_loss: grad_loss); |
|
|
|
|
|
|
|
var vars_with_grad = grads_and_vars.Where(x => x.Item1 != null).Select(x => x.Item2).ToArray(); |
|
|
|
if (vars_with_grad.Length == 0) |
|
|
@@ -232,30 +232,31 @@ namespace Tensorflow |
|
|
|
int? aggregation_method = null, |
|
|
|
GateGradientType gate_gradients = GateGradientType.GATE_OP, |
|
|
|
bool colocate_gradients_with_ops = false, |
|
|
|
Tensor[] grad_loss = null) |
|
|
|
Tensor grad_loss = null) |
|
|
|
{ |
|
|
|
// Scale loss if using a "mean" loss reduction and multiple replicas. |
|
|
|
loss = _scale_loss(loss); |
|
|
|
int num_towers = 1; |
|
|
|
if(distribute_lib.get_loss_reduction() == VariableAggregationType.MEAN) |
|
|
|
{ |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
var tmp = variables.trainable_variables(); |
|
|
|
var vars = ops.get_collection<RefVariable>(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES); |
|
|
|
switch (tmp) |
|
|
|
{ |
|
|
|
case List<RefVariable> values: |
|
|
|
var_list = values; |
|
|
|
var_list = values.Concat(vars).ToList(); |
|
|
|
break; |
|
|
|
case List<VariableV1> values: |
|
|
|
var_list = values.Select(x => x as RefVariable).ToList(); |
|
|
|
var_list = values.Select(x => x as RefVariable).Concat(vars).ToList(); |
|
|
|
break; |
|
|
|
} |
|
|
|
|
|
|
|
var_list = var_list.Concat(ops.get_collection<RefVariable>(ops.GraphKeys._STREAMING_MODEL_PORTS)).ToList(); |
|
|
|
var processors = var_list.Select(v => optimizer._get_processor(v)).ToList(); |
|
|
|
var var_refs = processors.Select(x => x.target()).ToArray(); |
|
|
|
|
|
|
|
var grads = gradients_impl.gradients(new Tensor[] { loss }, var_refs, grad_ys: grad_loss, |
|
|
|
gate_gradients: (gate_gradients == GateGradientType.GATE_OP), |
|
|
|
var grads = gradients_impl.gradients(new Tensor[] { loss }, var_refs, grad_ys: grad_loss == null ? null : new Tensor[] { grad_loss }, |
|
|
|
gate_gradients: gate_gradients == GateGradientType.GATE_OP, |
|
|
|
aggregation_method: aggregation_method, |
|
|
|
colocate_gradients_with_ops: colocate_gradients_with_ops); |
|
|
|
|
|
|
@@ -269,6 +270,14 @@ namespace Tensorflow |
|
|
|
return grads_and_vars; |
|
|
|
} |
|
|
|
|
|
|
|
private Tensor _scale_loss(Tensor loss_value) |
|
|
|
{ |
|
|
|
ops.get_default_graph()._is_loss_scaled_by_optimizer = false; |
|
|
|
// TODO |
|
|
|
// if distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN: |
|
|
|
return loss_value; |
|
|
|
} |
|
|
|
|
|
|
|
protected T _call_if_callable<T>(T param) |
|
|
|
{ |
|
|
|
return param; |
|
|
|