|
|
@@ -61,7 +61,7 @@ namespace Tensorflow |
|
|
|
string grad_scope = scope; |
|
|
|
// Get a uid for this call to gradients that can be used to help |
|
|
|
// cluster ops for compilation. |
|
|
|
var gradient_uid = ops.get_default_graph().unique_name("uid"); |
|
|
|
var gradient_uid = curr_graph.unique_name("uid"); |
|
|
|
ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name: "y"); |
|
|
|
xs = ops.internal_convert_n_to_tensor_or_indexed_slices(xs, name: "x", as_ref: true); |
|
|
|
grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops, gradient_uid); |
|
|
@@ -80,7 +80,7 @@ namespace Tensorflow |
|
|
|
var to_ops = ys.Select(x => x.op).ToList(); |
|
|
|
var from_ops = xs.Select(x => x.op).ToList(); |
|
|
|
var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList(); |
|
|
|
(var reachable_to_ops, var pending_count, var loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs); |
|
|
|
var (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs); |
|
|
|
|
|
|
|
foreach (var (y, grad_y) in zip(ys, grad_ys)) |
|
|
|
_SetGrad(grads, y, grad_y); |
|
|
@@ -168,7 +168,7 @@ namespace Tensorflow |
|
|
|
{ |
|
|
|
if (in_grad != null) |
|
|
|
{ |
|
|
|
if (in_grad is Tensor && |
|
|
|
if (!(in_grad is null) && |
|
|
|
in_grad.Tag == null && // maybe a IndexedSlice |
|
|
|
t_in.dtype != TF_DataType.TF_RESOURCE) |
|
|
|
{ |
|
|
|