|
|
@@ -53,10 +53,76 @@ namespace Tensorflow |
|
|
|
using (var namescope = new ops.name_scope<Tensor>(name, "gradients", values: all)) |
|
|
|
{ |
|
|
|
grad_scope = namescope; |
|
|
|
// Get a uid for this call to gradients that can be used to help |
|
|
|
// cluster ops for compilation. |
|
|
|
var gradient_uid = ops.get_default_graph().unique_name("uid"); |
|
|
|
|
|
|
|
var to_ops = ys1.Select(x => x.op).ToList(); |
|
|
|
var from_ops = xs1.Select(x => x.op).ToList(); |
|
|
|
var stop_gradient_ops = stop_gradients1.Select(x => x.op).ToList(); |
|
|
|
_PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs1); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// |
|
|
|
/// </summary> |
|
|
|
/// <param name="grad_ys"></param> |
|
|
|
/// <param name="ys"></param> |
|
|
|
/// <param name="colocate_gradients_with_ops"></param> |
|
|
|
/// <param name="gradient_uid"></param> |
|
|
|
private void _DefaultGradYs(List<Tensor> grad_ys, List<Tensor> ys, bool colocate_gradients_with_ops, string gradient_uid = "__unsupported__") |
|
|
|
{ |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Initialize the pending count for ops between two lists of Operations. |
|
|
|
/// 'pending_count[op]' indicates the number of backprop inputs |
|
|
|
/// to this operation. |
|
|
|
/// </summary> |
|
|
|
/// <param name="to_ops"></param> |
|
|
|
/// <param name="from_ops"></param> |
|
|
|
/// <param name="colocate_gradients_with_ops"></param> |
|
|
|
/// <param name="func_graphs"></param> |
|
|
|
/// <param name="xs"></param> |
|
|
|
private static void _PendingCount(List<Operation> to_ops, List<Operation> from_ops, bool colocate_gradients_with_ops, List<object> func_graphs, List<Tensor> xs) |
|
|
|
{ |
|
|
|
List<Operation> reached_ops = new List<Operation>(); |
|
|
|
_MarkReachedOps(from_ops, reached_ops, func_graphs); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Mark all ops reached from "from_ops" |
|
|
|
/// </summary> |
|
|
|
/// <param name="from_ops"></param> |
|
|
|
/// <param name="reached_ops"></param> |
|
|
|
/// <param name="func_graphs"></param> |
|
|
|
private static void _MarkReachedOps(List<Operation> from_ops, List<Operation> reached_ops, List<object> func_graphs) |
|
|
|
{ |
|
|
|
foreach(var op in from_ops) |
|
|
|
{ |
|
|
|
reached_ops.Add(op); |
|
|
|
foreach(var output in op.outputs) |
|
|
|
{ |
|
|
|
reached_ops.AddRange(_Consumers(output, func_graphs)); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
reached_ops.Reverse(); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Returns the consumers of t, crossing closure boundaries where necessary. |
|
|
|
/// </summary> |
|
|
|
/// <param name="t"></param> |
|
|
|
/// <param name="func_graphs"></param> |
|
|
|
private static List<Operation> _Consumers(Tensor t, List<object> func_graphs) |
|
|
|
{ |
|
|
|
var consumers = t.consumers(); |
|
|
|
return consumers; |
|
|
|
} |
|
|
|
|
|
|
|
private static List<Tensor> _AsList(object ys) |
|
|
|
{ |
|
|
|
List<Tensor> ret = null; |
|
|
|