diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs index 1f24afb0..94e434cb 100644 --- a/src/TensorFlowNET.Core/Gradients/math_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs @@ -36,7 +36,7 @@ namespace Tensorflow.Gradients return new Tensor[] { gen_ops.mul(grad, gen_math_ops.sign(x)) }; } - [RegisterGradient("Add")] + [RegisterGradient("AddV2")] public static Tensor[] _AddGrad(Operation op, Tensor[] grads) { var x = op.inputs[0]; @@ -107,7 +107,9 @@ namespace Tensorflow.Gradients var y = op.outputs[0]; // y = e^x return tf_with(ops.control_dependencies(new Operation[] { grad }), dp => { y = math_ops.conj(y); - return new Tensor[] { math_ops.mul_no_nan(y, grad) }; + // forward_compatible(2019, 9, 14) + // return new Tensor[] { math_ops.mul_no_nan(y, grad) }; + return new Tensor[] { grad * y }; }); } @@ -167,8 +169,7 @@ namespace Tensorflow.Gradients new TF_DataType[] { tf.int32, tf.float32 }.Contains(grad.dtype)) return new Tensor[] { gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x) }; - var sx = array_ops.shape(x); - var sy = array_ops.shape(y); + var (sx, sy) = SmartBroadcastGradientArgs(x, y); var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); x = math_ops.conj(x); @@ -355,8 +356,8 @@ namespace Tensorflow.Gradients : gen_math_ops.less_equal(x, y); var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); var xgrad = array_ops.where(xmask, grad, zeros); - var ygrad = array_ops.where(xmask, zeros, grad); var gx = array_ops.reshape(math_ops.reduce_sum(xgrad, rx), sx); + var ygrad = array_ops.where(xmask, zeros, grad); var gy = array_ops.reshape(math_ops.reduce_sum(ygrad, ry), sy); return new Tensor[] { gx, gy }; } @@ -397,14 +398,13 @@ namespace Tensorflow.Gradients _ShapesFullySpecifiedAndEqual(x, y, grad)) return new Tensor[] { grad, -grad }; - var sx = array_ops.shape(x); - var sy = array_ops.shape(y); + var (sx, sy) = SmartBroadcastGradientArgs(x, y); var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); - var r1 = gen_array_ops.reshape(math_ops.reduce_sum(grad, rx), sx); - var r2 = gen_array_ops.reshape(-math_ops.reduce_sum(grad, ry), sy); + var gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx); + var gy = array_ops.reshape(math_ops.reduce_sum(-grad, ry), sy); - return new Tensor[] { r1, r2 }; + return new Tensor[] { gx, gy }; } public static bool _ShapesFullySpecifiedAndEqual(Tensor x, Tensor y, Tensor grad) @@ -468,15 +468,16 @@ namespace Tensorflow.Gradients x = math_ops.conj(x); y = math_ops.conj(y); - var realdiv1 = gen_math_ops.real_div(-x, y); - var realdiv2 = gen_math_ops.real_div(realdiv1, y); - var reduce_sum1 = math_ops.reduce_sum(grad * realdiv2, ry); - var reshape1 = gen_array_ops.reshape(reduce_sum1, sy); - var realdiv3 = gen_math_ops.real_div(grad, y); - var reduce_sum2 = math_ops.reduce_sum(realdiv3, rx); - var reshape2 = gen_array_ops.reshape(reduce_sum2, sx); + var reshape1 = array_ops.reshape( + math_ops.reduce_sum( + math_ops.realdiv(grad, y), rx), + sx); + var reshape2 = array_ops.reshape( + math_ops.reduce_sum( + grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry), + sy); - return new Tensor[] { reshape2, reshape1 }; + return new Tensor[] { reshape1, reshape2 }; } [RegisterGradient("Sigmoid")] @@ -602,14 +603,12 @@ namespace Tensorflow.Gradients var y = op.inputs[1]; var z = op.outputs[0]; - var sx = array_ops.shape(x); - var sy = array_ops.shape(y); + var (sx, sy) = SmartBroadcastGradientArgs(x, y); var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); x = math_ops.conj(x); y = math_ops.conj(y); z = math_ops.conj(z); - var pow = gen_math_ops.pow(x, y - 1.0f); - var mul = grad * y * pow; + var mul = grad * y * math_ops.pow(x, y - 1.0f); var reduce_sum = math_ops.reduce_sum(mul, rx); var gx = gen_array_ops.reshape(reduce_sum, sx); @@ -630,5 +629,29 @@ namespace Tensorflow.Gradients return new Tensor[] { gx, gy }; } + + /// + /// Optimized version of `broadcast_gradient_args` that caches results. + /// + /// + /// + /// + private static (Tensor, Tensor) SmartBroadcastGradientArgs(Tensor x, Tensor y) + { + Tensor sx, sy; + if (x.TensorShape.is_fully_defined() && + y.TensorShape.is_fully_defined()) + { + sx = array_ops.shape(x); + sy = array_ops.shape(y); + } + else + { + sx = array_ops.shape_internal(x, optimize: false); + sy = array_ops.shape_internal(y, optimize: false); + } + + return (sx, sy); + } } }