diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs
index 1f24afb0..94e434cb 100644
--- a/src/TensorFlowNET.Core/Gradients/math_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs
@@ -36,7 +36,7 @@ namespace Tensorflow.Gradients
return new Tensor[] { gen_ops.mul(grad, gen_math_ops.sign(x)) };
}
- [RegisterGradient("Add")]
+ [RegisterGradient("AddV2")]
public static Tensor[] _AddGrad(Operation op, Tensor[] grads)
{
var x = op.inputs[0];
@@ -107,7 +107,9 @@ namespace Tensorflow.Gradients
var y = op.outputs[0]; // y = e^x
return tf_with(ops.control_dependencies(new Operation[] { grad }), dp => {
y = math_ops.conj(y);
- return new Tensor[] { math_ops.mul_no_nan(y, grad) };
+ // TODO: restore mul_no_nan once past the forward-compatibility horizon (2019-09-14):
+ // return new Tensor[] { math_ops.mul_no_nan(y, grad) };
+ return new Tensor[] { grad * y };
});
}
@@ -167,8 +169,7 @@ namespace Tensorflow.Gradients
new TF_DataType[] { tf.int32, tf.float32 }.Contains(grad.dtype))
return new Tensor[] { gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x) };
- var sx = array_ops.shape(x);
- var sy = array_ops.shape(y);
+ var (sx, sy) = SmartBroadcastGradientArgs(x, y);
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
x = math_ops.conj(x);
@@ -355,8 +356,8 @@ namespace Tensorflow.Gradients
: gen_math_ops.less_equal(x, y);
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
var xgrad = array_ops.where(xmask, grad, zeros);
- var ygrad = array_ops.where(xmask, zeros, grad);
var gx = array_ops.reshape(math_ops.reduce_sum(xgrad, rx), sx);
+ var ygrad = array_ops.where(xmask, zeros, grad);
var gy = array_ops.reshape(math_ops.reduce_sum(ygrad, ry), sy);
return new Tensor[] { gx, gy };
}
@@ -397,14 +398,13 @@ namespace Tensorflow.Gradients
_ShapesFullySpecifiedAndEqual(x, y, grad))
return new Tensor[] { grad, -grad };
- var sx = array_ops.shape(x);
- var sy = array_ops.shape(y);
+ var (sx, sy) = SmartBroadcastGradientArgs(x, y);
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
- var r1 = gen_array_ops.reshape(math_ops.reduce_sum(grad, rx), sx);
- var r2 = gen_array_ops.reshape(-math_ops.reduce_sum(grad, ry), sy);
+ var gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx);
+ var gy = array_ops.reshape(math_ops.reduce_sum(-grad, ry), sy);
- return new Tensor[] { r1, r2 };
+ return new Tensor[] { gx, gy };
}
public static bool _ShapesFullySpecifiedAndEqual(Tensor x, Tensor y, Tensor grad)
@@ -468,15 +468,16 @@ namespace Tensorflow.Gradients
x = math_ops.conj(x);
y = math_ops.conj(y);
- var realdiv1 = gen_math_ops.real_div(-x, y);
- var realdiv2 = gen_math_ops.real_div(realdiv1, y);
- var reduce_sum1 = math_ops.reduce_sum(grad * realdiv2, ry);
- var reshape1 = gen_array_ops.reshape(reduce_sum1, sy);
- var realdiv3 = gen_math_ops.real_div(grad, y);
- var reduce_sum2 = math_ops.reduce_sum(realdiv3, rx);
- var reshape2 = gen_array_ops.reshape(reduce_sum2, sx);
+ var reshape1 = array_ops.reshape(
+ math_ops.reduce_sum(
+ math_ops.realdiv(grad, y), rx),
+ sx);
+ var reshape2 = array_ops.reshape(
+ math_ops.reduce_sum(
+ grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry),
+ sy);
- return new Tensor[] { reshape2, reshape1 };
+ return new Tensor[] { reshape1, reshape2 };
}
[RegisterGradient("Sigmoid")]
@@ -602,14 +603,12 @@ namespace Tensorflow.Gradients
var y = op.inputs[1];
var z = op.outputs[0];
- var sx = array_ops.shape(x);
- var sy = array_ops.shape(y);
+ var (sx, sy) = SmartBroadcastGradientArgs(x, y);
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
x = math_ops.conj(x);
y = math_ops.conj(y);
z = math_ops.conj(z);
- var pow = gen_math_ops.pow(x, y - 1.0f);
- var mul = grad * y * pow;
+ var mul = grad * y * math_ops.pow(x, y - 1.0f);
var reduce_sum = math_ops.reduce_sum(mul, rx);
var gx = gen_array_ops.reshape(reduce_sum, sx);
@@ -630,5 +629,29 @@ namespace Tensorflow.Gradients
return new Tensor[] { gx, gy };
}
+
+ /// <summary>
+ /// Optimized version of `broadcast_gradient_args`: returns the shape tensors of
+ /// <paramref name="x"/> and <paramref name="y"/>, skipping shape constant folding
+ /// when either static shape is not fully defined.
+ /// </summary>
+ /// <returns>The shape tensors (sx, sy) for <paramref name="x"/> and <paramref name="y"/>.</returns>
+ private static (Tensor, Tensor) SmartBroadcastGradientArgs(Tensor x, Tensor y)
+ {
+ Tensor sx, sy;
+ if (x.TensorShape.is_fully_defined() &&
+ y.TensorShape.is_fully_defined())
+ {
+ sx = array_ops.shape(x);
+ sy = array_ops.shape(y);
+ }
+ else
+ {
+ sx = array_ops.shape_internal(x, optimize: false);
+ sy = array_ops.shape_internal(y, optimize: false);
+ }
+
+ return (sx, sy);
+ }
}
}