Browse Source

add SmartBroadcastGradientArgs

tags/v0.20
Oceania2018 5 years ago
parent
commit
11017d4f14
1 changed file with 45 additions and 22 deletions
  1. +45
    -22
      src/TensorFlowNET.Core/Gradients/math_grad.cs

+ 45
- 22
src/TensorFlowNET.Core/Gradients/math_grad.cs View File

@@ -36,7 +36,7 @@ namespace Tensorflow.Gradients
return new Tensor[] { gen_ops.mul(grad, gen_math_ops.sign(x)) };
}

[RegisterGradient("Add")]
[RegisterGradient("AddV2")]
public static Tensor[] _AddGrad(Operation op, Tensor[] grads)
{
var x = op.inputs[0];
@@ -107,7 +107,9 @@ namespace Tensorflow.Gradients
var y = op.outputs[0]; // y = e^x
return tf_with(ops.control_dependencies(new Operation[] { grad }), dp => {
y = math_ops.conj(y);
return new Tensor[] { math_ops.mul_no_nan(y, grad) };
// forward_compatible(2019, 9, 14)
// return new Tensor[] { math_ops.mul_no_nan(y, grad) };
return new Tensor[] { grad * y };
});
}

@@ -167,8 +169,7 @@ namespace Tensorflow.Gradients
new TF_DataType[] { tf.int32, tf.float32 }.Contains(grad.dtype))
return new Tensor[] { gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x) };

var sx = array_ops.shape(x);
var sy = array_ops.shape(y);
var (sx, sy) = SmartBroadcastGradientArgs(x, y);
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);

x = math_ops.conj(x);
@@ -355,8 +356,8 @@ namespace Tensorflow.Gradients
: gen_math_ops.less_equal(x, y);
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
var xgrad = array_ops.where(xmask, grad, zeros);
var ygrad = array_ops.where(xmask, zeros, grad);
var gx = array_ops.reshape(math_ops.reduce_sum(xgrad, rx), sx);
var ygrad = array_ops.where(xmask, zeros, grad);
var gy = array_ops.reshape(math_ops.reduce_sum(ygrad, ry), sy);
return new Tensor[] { gx, gy };
}
@@ -397,14 +398,13 @@ namespace Tensorflow.Gradients
_ShapesFullySpecifiedAndEqual(x, y, grad))
return new Tensor[] { grad, -grad };

var sx = array_ops.shape(x);
var sy = array_ops.shape(y);
var (sx, sy) = SmartBroadcastGradientArgs(x, y);
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);

var r1 = gen_array_ops.reshape(math_ops.reduce_sum(grad, rx), sx);
var r2 = gen_array_ops.reshape(-math_ops.reduce_sum(grad, ry), sy);
var gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx);
var gy = array_ops.reshape(math_ops.reduce_sum(-grad, ry), sy);

return new Tensor[] { r1, r2 };
return new Tensor[] { gx, gy };
}

public static bool _ShapesFullySpecifiedAndEqual(Tensor x, Tensor y, Tensor grad)
@@ -468,15 +468,16 @@ namespace Tensorflow.Gradients
x = math_ops.conj(x);
y = math_ops.conj(y);

var realdiv1 = gen_math_ops.real_div(-x, y);
var realdiv2 = gen_math_ops.real_div(realdiv1, y);
var reduce_sum1 = math_ops.reduce_sum(grad * realdiv2, ry);
var reshape1 = gen_array_ops.reshape(reduce_sum1, sy);
var realdiv3 = gen_math_ops.real_div(grad, y);
var reduce_sum2 = math_ops.reduce_sum(realdiv3, rx);
var reshape2 = gen_array_ops.reshape(reduce_sum2, sx);
var reshape1 = array_ops.reshape(
math_ops.reduce_sum(
math_ops.realdiv(grad, y), rx),
sx);
var reshape2 = array_ops.reshape(
math_ops.reduce_sum(
grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry),
sy);

return new Tensor[] { reshape2, reshape1 };
return new Tensor[] { reshape1, reshape2 };
}

[RegisterGradient("Sigmoid")]
@@ -602,14 +603,12 @@ namespace Tensorflow.Gradients
var y = op.inputs[1];
var z = op.outputs[0];

var sx = array_ops.shape(x);
var sy = array_ops.shape(y);
var (sx, sy) = SmartBroadcastGradientArgs(x, y);
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
x = math_ops.conj(x);
y = math_ops.conj(y);
z = math_ops.conj(z);
var pow = gen_math_ops.pow(x, y - 1.0f);
var mul = grad * y * pow;
var mul = grad * y * math_ops.pow(x, y - 1.0f);
var reduce_sum = math_ops.reduce_sum(mul, rx);
var gx = gen_array_ops.reshape(reduce_sum, sx);

@@ -630,5 +629,29 @@ namespace Tensorflow.Gradients

return new Tensor[] { gx, gy };
}

/// <summary>
/// Returns the pair of shape tensors for <paramref name="x"/> and <paramref name="y"/>
/// to feed into `broadcast_gradient_args` when computing gradients of broadcasting
/// binary ops. When both operands have fully defined static shapes, the default
/// `array_ops.shape` is used; otherwise `shape_internal` is called with
/// `optimize: false`.
/// NOTE(review): unlike the TensorFlow Python helper of the same name, this port
/// does NOT cache results — no memoization occurs in this body.
/// </summary>
/// <param name="x">Left-hand operand of the broadcasting binary op.</param>
/// <param name="y">Right-hand operand of the broadcasting binary op.</param>
/// <returns>Tuple (sx, sy): the shape tensor of x and the shape tensor of y.</returns>
private static (Tensor, Tensor) SmartBroadcastGradientArgs(Tensor x, Tensor y)
{
Tensor sx, sy;
// Both static shapes fully known: use the default shape op.
if (x.TensorShape.is_fully_defined() &&
y.TensorShape.is_fully_defined())
{
sx = array_ops.shape(x);
sy = array_ops.shape(y);
}
else
{
// At least one shape is only known at runtime.
// NOTE(review): presumably `optimize: false` prevents folding the shape into a
// constant so it is computed at runtime — confirm against array_ops.shape_internal.
sx = array_ops.shape_internal(x, optimize: false);
sy = array_ops.shape_internal(y, optimize: false);
}

return (sx, sy);
}
}
}

Loading…
Cancel
Save