|
@@ -9,7 +9,7 @@ namespace Tensorflow |
|
|
{ |
|
|
{ |
|
|
public class gradients_impl |
|
|
public class gradients_impl |
|
|
{ |
|
|
{ |
|
|
public static void gradients(Tensor[] ys, |
|
|
|
|
|
|
|
|
public static Tensor[] gradients(Tensor[] ys, |
|
|
Tensor[] xs, |
|
|
Tensor[] xs, |
|
|
Tensor[] grad_ys = null, |
|
|
Tensor[] grad_ys = null, |
|
|
string name = "gradients", |
|
|
string name = "gradients", |
|
@@ -17,7 +17,7 @@ namespace Tensorflow |
|
|
bool gate_gradients = false, |
|
|
bool gate_gradients = false, |
|
|
int? aggregation_method = null) |
|
|
int? aggregation_method = null) |
|
|
{ |
|
|
{ |
|
|
_GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients); |
|
|
|
|
|
|
|
|
return _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
public static Tensor[] _GradientsHelper(Tensor[] ys, |
|
|
public static Tensor[] _GradientsHelper(Tensor[] ys, |
|
@@ -91,7 +91,9 @@ namespace Tensorflow |
|
|
{ |
|
|
{ |
|
|
// 'ready' handles the case where one output gradient relies on |
|
|
// 'ready' handles the case where one output gradient relies on |
|
|
// another output's gradient. |
|
|
// another output's gradient. |
|
|
bool ready = !pending_count.ContainsKey(op.Name) || pending_count[op.Name] == 0; |
|
|
|
|
|
|
|
|
if (!pending_count.ContainsKey(op.Name)) |
|
|
|
|
|
pending_count[op.Name] = 0; |
|
|
|
|
|
bool ready = pending_count[op.Name] == 0; |
|
|
if(ready && !to_ops_set.Contains(op) && reachable_to_ops.Contains(op)) |
|
|
if(ready && !to_ops_set.Contains(op) && reachable_to_ops.Contains(op)) |
|
|
{ |
|
|
{ |
|
|
to_ops_set.Add(op); |
|
|
to_ops_set.Add(op); |
|
@@ -136,8 +138,12 @@ namespace Tensorflow |
|
|
|
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
else |
|
|
|
|
|
{ |
|
|
|
|
|
in_grads = _NonEagerInputs(op, xs).Select(x => new Tensor(IntPtr.Zero)).ToArray(); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
var inputs = (List<Tensor>)_NonEagerInputs(op, xs); |
|
|
|
|
|
|
|
|
var inputs = _NonEagerInputs(op, xs).ToList(); |
|
|
foreach (var (t_in, in_grad) in Python.zip(inputs, in_grads)) |
|
|
foreach (var (t_in, in_grad) in Python.zip(inputs, in_grads)) |
|
|
{ |
|
|
{ |
|
|
if(in_grad != null) |
|
|
if(in_grad != null) |
|
@@ -155,6 +161,15 @@ namespace Tensorflow |
|
|
return xs.Select(x => _GetGrad(grads, x)).ToArray(); |
|
|
return xs.Select(x => _GetGrad(grads, x)).ToArray(); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
/// <summary> |
|
|
|
|
|
/// Update pending count for the inputs of op and enqueue ready ops. |
|
|
|
|
|
/// </summary> |
|
|
|
|
|
/// <param name="grads"></param> |
|
|
|
|
|
/// <param name="op"></param> |
|
|
|
|
|
/// <param name="queue"></param> |
|
|
|
|
|
/// <param name="pending_count"></param> |
|
|
|
|
|
/// <param name="loop_state"></param> |
|
|
|
|
|
/// <param name="xs"></param> |
|
|
private static void _UpdatePendingAndEnqueueReady(Dictionary<string, Tensor[][]> grads, |
|
|
private static void _UpdatePendingAndEnqueueReady(Dictionary<string, Tensor[][]> grads, |
|
|
Operation op, |
|
|
Operation op, |
|
|
Queue<Operation> queue, |
|
|
Queue<Operation> queue, |
|
@@ -162,7 +177,28 @@ namespace Tensorflow |
|
|
object loop_state, |
|
|
object loop_state, |
|
|
Tensor[] xs) |
|
|
Tensor[] xs) |
|
|
{ |
|
|
{ |
|
|
|
|
|
foreach(var x in _NonEagerInputs(op, xs)) |
|
|
|
|
|
{ |
|
|
|
|
|
pending_count[x.op.Name] -= 1; |
|
|
|
|
|
var ready = pending_count[x.op.Name] == 0; |
|
|
|
|
|
|
|
|
|
|
|
if(loop_state != null && !ready) |
|
|
|
|
|
{ |
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if (ready) |
|
|
|
|
|
{ |
|
|
|
|
|
if (control_flow_util.IsLoopExit(x.op)) |
|
|
|
|
|
{ |
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
else |
|
|
|
|
|
{ |
|
|
|
|
|
queue.Enqueue(x.op); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
private static void _VerifyGeneratedGradients(Tensor[] grads, Operation op) |
|
|
private static void _VerifyGeneratedGradients(Tensor[] grads, Operation op) |
|
@@ -227,7 +263,10 @@ namespace Tensorflow |
|
|
bool is_stop_op = true; |
|
|
bool is_stop_op = true; |
|
|
foreach(var inp in _NonEagerInputs(op, xs)) |
|
|
foreach(var inp in _NonEagerInputs(op, xs)) |
|
|
{ |
|
|
{ |
|
|
if(pending_count.ContainsKey(op.Name) && pending_count[op.Name] > 0) |
|
|
|
|
|
|
|
|
if (!pending_count.ContainsKey(inp.op.Name)) |
|
|
|
|
|
pending_count[inp.op.Name] = 0; |
|
|
|
|
|
|
|
|
|
|
|
if (pending_count[inp.op.Name] > 0) |
|
|
{ |
|
|
{ |
|
|
is_stop_op = false; |
|
|
is_stop_op = false; |
|
|
break; |
|
|
break; |
|
@@ -267,14 +306,14 @@ namespace Tensorflow |
|
|
private static void _SetGrad(Dictionary<string, Tensor[][]> grads, Tensor t, Tensor grad) |
|
|
private static void _SetGrad(Dictionary<string, Tensor[][]> grads, Tensor t, Tensor grad) |
|
|
{ |
|
|
{ |
|
|
var op = t.op; |
|
|
var op = t.op; |
|
|
Tensor[][] op_grads = null; |
|
|
|
|
|
if (!grads.ContainsKey(op.Name)) |
|
|
|
|
|
|
|
|
Tensor[][] op_grads = grads.ContainsKey(op.Name) ? grads[op.Name] : null; |
|
|
|
|
|
if (op_grads == null) |
|
|
{ |
|
|
{ |
|
|
op_grads = op.outputs.Select(x => new Tensor[1]).ToArray(); |
|
|
op_grads = op.outputs.Select(x => new Tensor[1]).ToArray(); |
|
|
grads[op.Name] = op_grads; |
|
|
grads[op.Name] = op_grads; |
|
|
} |
|
|
} |
|
|
var t_grads = op_grads[t.value_index]; |
|
|
var t_grads = op_grads[t.value_index]; |
|
|
// t_grads[0] = grad; |
|
|
|
|
|
|
|
|
t_grads[0] = grad; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
/// <summary> |
|
|
/// <summary> |
|
@@ -348,7 +387,7 @@ namespace Tensorflow |
|
|
// Clear the boolean so we won't add the inputs again. |
|
|
// Clear the boolean so we won't add the inputs again. |
|
|
reached_ops.Remove(op); |
|
|
reached_ops.Remove(op); |
|
|
foreach (var inp in _NonEagerInputs(op, xs)) |
|
|
foreach (var inp in _NonEagerInputs(op, xs)) |
|
|
queue.Enqueue((inp as Tensor).op); |
|
|
|
|
|
|
|
|
queue.Enqueue(inp.op); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
// X in between_ops iff X is on a path of zero or more backpropagatable tensors |
|
|
// X in between_ops iff X is on a path of zero or more backpropagatable tensors |
|
@@ -363,19 +402,22 @@ namespace Tensorflow |
|
|
foreach(Tensor x in _NonEagerInputs(op, xs)) |
|
|
foreach(Tensor x in _NonEagerInputs(op, xs)) |
|
|
{ |
|
|
{ |
|
|
if (between_ops.Contains(x.op)) |
|
|
if (between_ops.Contains(x.op)) |
|
|
if (pending_count.ContainsKey(x.op.Name)) |
|
|
|
|
|
pending_count[x.op.Name] += 1; |
|
|
|
|
|
else |
|
|
|
|
|
pending_count[x.op.Name] = 1; |
|
|
|
|
|
|
|
|
{ |
|
|
|
|
|
if (!pending_count.ContainsKey(x.op.Name)) |
|
|
|
|
|
pending_count[x.op.Name] = 0; |
|
|
|
|
|
|
|
|
|
|
|
pending_count[x.op.Name] += 1; |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
return (reachable_to_ops.ToArray(), pending_count, loop_state); |
|
|
return (reachable_to_ops.ToArray(), pending_count, loop_state); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
private static InputList _NonEagerInputs(Operation op, Tensor[] xs) |
|
|
|
|
|
|
|
|
private static IEnumerable<Tensor> _NonEagerInputs(Operation op, Tensor[] xs) |
|
|
{ |
|
|
{ |
|
|
return op.inputs; |
|
|
|
|
|
|
|
|
for (int i = 0; i < op.inputs.Length; i++) |
|
|
|
|
|
yield return op.inputs[i]; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
/// <summary> |
|
|
/// <summary> |
|
|