|
|
@@ -1,5 +1,6 @@ |
|
|
|
using System; |
|
|
|
using System.Collections.Generic; |
|
|
|
using System.Linq; |
|
|
|
using System.Text; |
|
|
|
using distribute_lib = Tensorflow.Distribute; |
|
|
|
|
|
|
@@ -48,8 +49,10 @@ namespace Tensorflow |
|
|
|
/// <param name="gate_gradients"></param> |
|
|
|
public List<KeyValuePair<object, object>> compute_gradients(Tensor loss, |
|
|
|
List<RefVariable> var_list = null, |
|
|
|
int? aggregation_method = null, |
|
|
|
GateGradientType gate_gradients = GateGradientType.GATE_OP, |
|
|
|
bool colocate_gradients_with_ops = false) |
|
|
|
bool colocate_gradients_with_ops = false, |
|
|
|
List<Tensor> grad_loss = null) |
|
|
|
{ |
|
|
|
int num_towers = 1; |
|
|
|
if(distribute_lib.get_loss_reduction() == VariableAggregationType.MEAN) |
|
|
@@ -65,10 +68,13 @@ namespace Tensorflow |
|
|
|
break; |
|
|
|
} |
|
|
|
|
|
|
|
foreach(var v in var_list) |
|
|
|
{ |
|
|
|
var processors = var_list.Select(v => optimizer._get_processor(v)); |
|
|
|
var var_refs = processors.Select(x => x.target()).ToList(); |
|
|
|
|
|
|
|
} |
|
|
|
gradients_impl.gradients(loss, var_refs, grad_ys: grad_loss, |
|
|
|
gate_gradients: (gate_gradients == GateGradientType.GATE_OP), |
|
|
|
aggregation_method: aggregation_method, |
|
|
|
colocate_gradients_with_ops: colocate_gradients_with_ops); |
|
|
|
|
|
|
|
return null; |
|
|
|
} |
|
|
|