|
|
@@ -212,7 +212,7 @@ namespace Tensorflow.Gradients |
|
|
|
}; |
|
|
|
} |
|
|
|
|
|
|
|
var broads = SmartBroadcastGradientArgs(x, y); |
|
|
|
var broads = SmartBroadcastGradientArgs(x, y, grad); |
|
|
|
var (sx, rx, must_reduce_x) = broads[0]; |
|
|
|
var (sy, ry, must_reduce_y) = broads[1]; |
|
|
|
|
|
|
@@ -468,7 +468,7 @@ namespace Tensorflow.Gradients |
|
|
|
_ShapesFullySpecifiedAndEqual(x, y, grad)) |
|
|
|
return new Tensor[] { grad, -grad }; |
|
|
|
|
|
|
|
var broads = SmartBroadcastGradientArgs(x, y); |
|
|
|
var broads = SmartBroadcastGradientArgs(x, y, grad); |
|
|
|
var (sx, rx, must_reduce_x) = broads[0]; |
|
|
|
var (sy, ry, must_reduce_y) = broads[1]; |
|
|
|
|
|
|
@@ -718,7 +718,7 @@ namespace Tensorflow.Gradients |
|
|
|
|
|
|
|
var z = op.outputs[0]; |
|
|
|
|
|
|
|
var broads = SmartBroadcastGradientArgs(x, y); |
|
|
|
var broads = SmartBroadcastGradientArgs(x, y, grad); |
|
|
|
var (sx, rx, must_reduce_x) = broads[0]; |
|
|
|
var (sy, ry, must_reduce_y) = broads[1]; |
|
|
|
|
|
|
@@ -753,7 +753,7 @@ namespace Tensorflow.Gradients |
|
|
|
/// <param name="x"></param> |
|
|
|
/// <param name="y"></param> |
|
|
|
/// <returns></returns> |
|
|
|
private static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Tensor y) |
|
|
|
private static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Tensor y, Tensor grad) |
|
|
|
{ |
|
|
|
Tensor sx, sy; |
|
|
|
if (x.TensorShape.is_fully_defined() && |
|
|
@@ -771,8 +771,8 @@ namespace Tensorflow.Gradients |
|
|
|
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); |
|
|
|
return new[] |
|
|
|
{ |
|
|
|
(sx, rx, true), |
|
|
|
(sy, ry, true) |
|
|
|
(sx, rx, !x.TensorShape.Equals(grad.TensorShape)), |
|
|
|
(sy, ry, !y.TensorShape.Equals(grad.TensorShape)) |
|
|
|
}; |
|
|
|
} |
|
|
|
} |
|
|
|