diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs index 3a674e83..34303bf9 100644 --- a/src/TensorFlowNET.Core/APIs/tf.array.cs +++ b/src/TensorFlowNET.Core/APIs/tf.array.cs @@ -53,6 +53,16 @@ namespace Tensorflow public Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0) => array_ops.boolean_mask(tensor, mask, name: name, axis: axis); + /// <summary> + /// Broadcast an array for a compatible shape. + /// </summary> + /// <param name="input"></param> + /// <param name="shape"></param> + /// <param name="name"></param> + public Tensor broadcast_to(Tensor input, TensorShape shape, string name = null) + => gen_array_ops.broadcast_to(input, shape, name: name); + public Tensor check_numerics(Tensor tensor, string message, string name = null) => gen_array_ops.check_numerics(tensor, message, name: name); diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs index f07d2825..74e9ef10 100644 --- a/src/TensorFlowNET.Core/Gradients/array_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs @@ -27,6 +27,27 @@ namespace Tensorflow.Gradients [RegisterGradient("array_grad")] public class array_grad { + [RegisterGradient("BroadcastTo")] + public static Tensor[] _BroadcastToGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var input_value = op.inputs[0]; + var broadcast_shape = op.inputs[1]; + var input_value_shape = array_ops.shape(input_value); + var (_, reduction_axes) = gen_array_ops.broadcast_gradient_args(broadcast_shape, + input_value_shape); + var updates_grad_reshaped = math_ops.reduce_sum(grad, + axis: reduction_axes, + keepdims: true); + var updates_grad = array_ops.reshape(updates_grad_reshaped, input_value_shape); + + return new Tensor[] + { + updates_grad, + null + }; + } + [RegisterGradient("ConcatV2")] public static Tensor[] _ConcatGradV2(Operation op, Tensor[] grads) { diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs index 49dcbc45..b3c620c0 100644 ---
a/src/TensorFlowNET.Core/Gradients/math_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs @@ -58,6 +58,20 @@ namespace Tensorflow.Gradients return new Tensor[] { r1, r2 }; } + [RegisterGradient("Cumsum")] + public static Tensor[] _CumsumGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var axis = op.inputs[1]; + var exclusive = op.get_attr<bool>("exclusive"); + var reverse = op.get_attr<bool>("reverse"); + return new Tensor[] + { + math_ops.cumsum(grad, axis, exclusive: exclusive, reverse: !reverse), + null + }; + } + [RegisterGradient("DivNoNan")] public static Tensor[] _DivNoNanGrad(Operation op, Tensor[] grads) { diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 36837477..847ace24 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -515,5 +515,19 @@ namespace Tensorflow return _op.outputs[0]; } + + /// <summary> + /// Broadcast an array for a compatible shape. 
+ /// </summary> + /// <param name="input"></param> + /// <param name="shape"></param> + /// <param name="name"></param> + /// <returns></returns> + public static Tensor broadcast_to(Tensor input, int[] shape, string name = null) + { + var _op = _op_def_lib._apply_op_helper("BroadcastTo", name, args: new { input, shape, name }); + + return _op.outputs[0]; + } } } diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 81870e5b..7192dc57 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -238,7 +238,7 @@ namespace Tensorflow return _op.outputs[0]; } - public static Tensor cumsum(Tensor x, int axis = 0, bool exclusive = false, bool reverse = false, string name = null) + public static Tensor cumsum<T>(Tensor x, T axis, bool exclusive = false, bool reverse = false, string name = null) { var _op = _op_def_lib._apply_op_helper("Cumsum", name, args: new { x, axis, exclusive, reverse }); diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index 94c42ba2..d4dfc12b 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -80,7 +80,7 @@ namespace Tensorflow }); } - public static Tensor cumsum(Tensor x, int axis = 0, bool exclusive = false, bool reverse = false, string name = null) + public static Tensor cumsum<T>(Tensor x, T axis = default, bool exclusive = false, bool reverse = false, string name = null) { return tf_with(ops.name_scope(name, "Cumsum", new {x}), scope => { diff --git a/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs b/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs index ecd69977..2fae1e5b 100644 --- a/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs +++ b/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs @@ -11,6 +11,39 @@ namespace TensorFlowNET.UnitTest.gradients_test [TestClass] public class GradientsTest : PythonTest { + [TestMethod] + public void BroadcastToGrad() + { + var graph = 
tf.Graph().as_default(); + + var x = tf.constant(2, dtype: dtypes.float32); + var y = tf.broadcast_to(x, (2, 4, 3)); + var grad = tf.gradients(y, x); + + using (var sess = tf.Session(graph)) + { + float result = sess.run(grad[0]); + Assert.AreEqual(result, 24.0f); + } + } + + [TestMethod] + public void CumsumGrad() + { + var graph = tf.Graph().as_default(); + + var x = tf.constant(2, dtype: dtypes.float32); + var y = tf.broadcast_to(x, (2, 4, 3)); + var z = tf.cumsum(y, axis: 1); + var grad = tf.gradients(z, x); + + using (var sess = tf.Session(graph)) + { + float result = sess.run(grad[0]); + Assert.AreEqual(result, 60.0f); + } + } + [Ignore("TODO")] [TestMethod] public void testGradients()