|
|
@@ -46,6 +46,8 @@ namespace Tensorflow
            all.AddRange(stop_gradients);
            all.AddRange(grad_ys);

            var grads = new Dictionary<string, object>();

            Python.with<ops.name_scope>(new ops.name_scope(name, "gradients", values: all), scope =>
            {
                string grad_scope = scope;
@@ -78,7 +80,7 @@ namespace Tensorflow
                 * aggregate the list of received gradients into an Add() Operation if there
                 * is more than one.
                 **/
                var grads = new Dictionary<string, Tensor[][]>();

                for (int i = 0; i < ys.Count(); i++)
                {
                    (var y, var grad_y) = Python.zip(ys, grad_ys, i);
@@ -111,6 +113,7 @@ namespace Tensorflow
                    //loop_state.EnterGradWhileContext(op, before: true);
                    var out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, aggregation_method);

                    Tensor[] in_grads = null;
                    var is_partitioned_call = _IsPartitionedCall(op);
                    var is_func_call = false;
                    var has_out_grads = true;
@@ -124,13 +127,60 @@ namespace Tensorflow
                    {
                        // A grad_fn must be defined, either as a function or as None
                        // for ops that do not have gradients.
                        var grad_fn = ops.get_gradient_function(op);

                        Python.with<ops.name_scope>(new ops.name_scope(op.Name + "_grad"), delegate
                        {
                            if (grad_fn != null)
                            {
                                in_grads = _MaybeCompile(grad_scope, op, out_grads[0], null, grad_fn);
                                _VerifyGeneratedGradients(in_grads, op);
                            }
                        });
                    }
                    }

                    for (int i = 0; i < in_grads.Length; i++)
                    {
                        var inputs = (List<Tensor>)_NonEagerInputs(op, xs);
                        var (t_in, in_grad) = Python.zip(inputs, in_grads, i);
                        if (in_grad != null)
                        {
                            in_grad.shape = t_in.shape;
                            _SetGrad(grads, t_in, in_grad);
                        }
                    }

                    // Update pending count for the inputs of op and enqueue ready ops.
                    _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state, xs);
                }
            });
            return null;
            return xs.Select(x => _GetGrad(grads, x)).ToArray();
        }
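
        /// <summary>
        /// Stub, body not implemented yet. In the reference Python implementation this
        /// decrements the pending count of each input op of "op" and enqueues every op
        /// whose count reaches zero (with extra handling for while-loop state), so the
        /// backward pass visits it next.
        /// </summary>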
        private static void _UpdatePendingAndEnqueueReady(Dictionary<string, Tensor[][]> grads,
            Operation op,
            Queue<Operation> queue,
            Dictionary<string, int> pending_count,
            object loop_state,
            Tensor[] xs)
        {

        }
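
        /// <summary>
        /// Sanity check: the gradient function must produce exactly one gradient per
        /// input of "op", otherwise a ValueError is thrown.
        /// </summary>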
        private static void _VerifyGeneratedGradients(Tensor[] grads, Operation op)
        {
            if (grads.Count() != op.inputs._inputs.Count())
                throw new ValueError($"Num gradients {grads.Length} generated for op {op.node_def} do not match num " +
                    $"inputs {op.inputs._inputs.Count()}");
        }
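
        /// <summary>
        /// Simplified port: the XLA auto-compilation path of the Python _MaybeCompile is
        /// skipped here; grad_fn is invoked directly and its two resulting gradients are
        /// returned.
        /// </summary>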
        private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor out_grads, Action func, Func<Operation, Tensor, (Tensor, Tensor)> grad_fn)
        {
            var in_grads = grad_fn(op, out_grads);
            return new Tensor[] { in_grads.Item1, in_grads.Item2 };
        }
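
        /// <summary>
        /// True if "op" invokes a TensorFlow function via PartitionedCall or
        /// StatefulPartitionedCall.
        /// </summary>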
        private static bool _IsPartitionedCall(Operation op)
@@ -138,9 +188,9 @@ namespace Tensorflow
        {
            return op.OpType == "PartitionedCall" || op.OpType == "StatefulPartitionedCall";
        }
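
        /// <summary>
        /// Collects the gradients recorded so far for each output of "op". In the
        /// reference Python implementation, an output that received more than one
        /// gradient gets them summed (AddN / accumulate_n, depending on
        /// aggregation_method) before they are passed to the op's gradient function.
        /// </summary>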
        private static Tensor[] _AggregatedGrads(Dictionary<string, Tensor[][]> grads, Operation op, string gradient_uid, object loop_state, int aggregation_method = 0)
        private static Tensor[] _AggregatedGrads(Dictionary<string, object> grads, Operation op, string gradient_uid, object loop_state, int aggregation_method = 0)
        {
            var out_grads = _GetGrads(grads, op);
            var out_grads = _GetGrads(grads, op) as object[];
            for (int i = 0; i < out_grads.Length; i++)
            {
                var out_grad = out_grads[i];
@@ -195,12 +245,22 @@ namespace Tensorflow
            return stop_ops.ToArray();
        }

        private static Tensor[][] _GetGrads(Dictionary<string, Tensor[][]> grads, Operation op)
        private static Tensor _GetGrad(Dictionary<string, Tensor[][]> grads, Tensor t)
        {
            var op = t.op;
            if (!grads.ContainsKey(op.Name))
                return null;
            Tensor[][] op_grads = grads[op.Name];
            var t_grad = op_grads[t.value_index];
            return t_grad[0];
        }
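
        /// <summary>
        /// Returns the gradient buckets recorded for "op": one list of received
        /// gradients per output, or empty buckets when nothing has been recorded yet.
        /// </summary>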
        private static object _GetGrads(Dictionary<string, object> grads, Operation op)
        {
            if (grads.ContainsKey(op.Name))
                return grads[op.Name];
            else
                return op.outputs.Select(x => new Tensor[0]).ToArray();
                return op.outputs.Select(x => new object[0]).ToArray();
        }

        /// <summary>
@@ -209,17 +269,17 @@ namespace Tensorflow
        /// <param name="grads"></param>
        /// <param name="t"></param>
        /// <param name="grad"></param>
        private static void _SetGrad(Dictionary<string, Tensor[][]> grads, Tensor t, Tensor grad)
        private static void _SetGrad(Dictionary<string, object> grads, Tensor t, Tensor grad)
        {
            var op = t.op;
            Tensor[][] op_grads = null;
            object op_grads = null;
            if (!grads.ContainsKey(op.Name))
            {
                op_grads = op.outputs.Select(x => new Tensor[1]).ToArray();
                op_grads = op.outputs.Select(x => new object[1]).ToList();
                grads[op.Name] = op_grads;
            }
            var t_grads = op_grads[t.value_index];
            t_grads[0] = grad;
            var t_grads = (op_grads as object[])[t.value_index];
            // t_grads[0] = grad;
        }

        /// <summary>
@@ -322,6 +382,7 @@ namespace Tensorflow
        {
            return op.inputs;
        }

        /// <summary>
        /// Mark all ops reached from "from_ops"
        /// </summary>
|
|
|