Browse Source

add CumsumGrad, BroadcastToGrad

tags/v0.12
Oceania2018 6 years ago
parent
commit
5a414f0ddc
7 changed files with 94 additions and 2 deletions
  1. +10
    -0
      src/TensorFlowNET.Core/APIs/tf.array.cs
  2. +21
    -0
      src/TensorFlowNET.Core/Gradients/array_grad.cs
  3. +14
    -0
      src/TensorFlowNET.Core/Gradients/math_grad.cs
  4. +14
    -0
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  6. +1
    -1
      src/TensorFlowNET.Core/Operations/math_ops.cs
  7. +33
    -0
      test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs

+ 10
- 0
src/TensorFlowNET.Core/APIs/tf.array.cs View File

@@ -53,6 +53,16 @@ namespace Tensorflow
/// <summary>
/// Forwards the call to <c>array_ops.boolean_mask</c>, passing
/// <paramref name="name"/> and <paramref name="axis"/> as named arguments.
/// </summary>
/// <param name="tensor">Value handed to <c>array_ops.boolean_mask</c> as the tensor argument.</param>
/// <param name="mask">Value handed to <c>array_ops.boolean_mask</c> as the mask argument.</param>
/// <param name="name">Name forwarded to the underlying call; defaults to "boolean_mask".</param>
/// <param name="axis">Axis forwarded to the underlying call; defaults to 0.</param>
/// <returns>The tensor returned by <c>array_ops.boolean_mask</c>.</returns>
public Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0)
{
    return array_ops.boolean_mask(tensor, mask, name: name, axis: axis);
}

/// <summary>
/// Broadcast an array for a compatible shape.
/// Thin wrapper that forwards to <c>gen_array_ops.broadcast_to</c>.
/// </summary>
/// <param name="input">The tensor to broadcast.</param>
/// <param name="shape">The shape to broadcast <paramref name="input"/> to.</param>
/// <param name="name">Optional name forwarded to the underlying call.</param>
/// <returns>The tensor returned by <c>gen_array_ops.broadcast_to</c>.</returns>
public Tensor broadcast_to(Tensor input, TensorShape shape, string name = null)
{
    return gen_array_ops.broadcast_to(input, shape, name: name);
}

/// <summary>
/// Forwards the call to <c>gen_array_ops.check_numerics</c>.
/// </summary>
/// <param name="tensor">Tensor handed to the underlying call.</param>
/// <param name="message">Message handed to the underlying call.</param>
/// <param name="name">Optional name forwarded to the underlying call.</param>
/// <returns>The tensor returned by <c>gen_array_ops.check_numerics</c>.</returns>
public Tensor check_numerics(Tensor tensor, string message, string name = null)
{
    return gen_array_ops.check_numerics(tensor, message, name: name);
}



+ 21
- 0
src/TensorFlowNET.Core/Gradients/array_grad.cs View File

@@ -27,6 +27,27 @@ namespace Tensorflow.Gradients
[RegisterGradient("array_grad")]
public class array_grad
{
/// <summary>
/// Gradient function registered for the "BroadcastTo" op.
/// Sums the incoming gradient over the axes reported by
/// <c>broadcast_gradient_args</c> and reshapes the result to the shape of the
/// op's first input.
/// </summary>
/// <param name="op">The BroadcastTo operation; input 0 is the value, input 1 the target shape.</param>
/// <param name="grads">Incoming gradients; only the first element is used.</param>
/// <returns>
/// A two-element array: the gradient for the value input, and <c>null</c>
/// for the shape input.
/// </returns>
[RegisterGradient("BroadcastTo")]
public static Tensor[] _BroadcastToGrad(Operation op, Tensor[] grads)
{
    var upstream = grads[0];
    var source = op.inputs[0];
    var target_shape = op.inputs[1];

    // Shape of the original (pre-broadcast) value, as a tensor.
    var source_shape = array_ops.shape(source);

    // Only the second result (the axes to reduce over) is needed here.
    var (_, axes_to_reduce) = gen_array_ops.broadcast_gradient_args(target_shape, source_shape);

    // Reduce over those axes, keeping dimensions, then reshape to the input's shape.
    var summed = math_ops.reduce_sum(upstream, axis: axes_to_reduce, keepdims: true);
    var value_grad = array_ops.reshape(summed, source_shape);

    return new Tensor[] { value_grad, null };
}

[RegisterGradient("ConcatV2")]
public static Tensor[] _ConcatGradV2(Operation op, Tensor[] grads)
{


+ 14
- 0
src/TensorFlowNET.Core/Gradients/math_grad.cs View File

@@ -58,6 +58,20 @@ namespace Tensorflow.Gradients
return new Tensor[] { r1, r2 };
}

/// <summary>
/// Gradient function registered for the "Cumsum" op.
/// Applies <c>math_ops.cumsum</c> to the incoming gradient along the same axis,
/// with the same "exclusive" attribute and the "reverse" attribute negated.
/// </summary>
/// <param name="op">The Cumsum operation; input 1 is the axis.</param>
/// <param name="grads">Incoming gradients; only the first element is used.</param>
/// <returns>
/// A two-element array: the gradient for the summed input, and <c>null</c>
/// for the axis input.
/// </returns>
[RegisterGradient("Cumsum")]
public static Tensor[] _CumsumGrad(Operation op, Tensor[] grads)
{
    var upstream = grads[0];
    var axis_input = op.inputs[1];
    var is_exclusive = op.get_attr<bool>("exclusive");
    var is_reverse = op.get_attr<bool>("reverse");

    // The gradient accumulates in the opposite direction to the forward op.
    var input_grad = math_ops.cumsum(upstream, axis_input, exclusive: is_exclusive, reverse: !is_reverse);

    return new Tensor[] { input_grad, null };
}

[RegisterGradient("DivNoNan")]
public static Tensor[] _DivNoNanGrad(Operation op, Tensor[] grads)
{


+ 14
- 0
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -515,5 +515,19 @@ namespace Tensorflow

return _op.outputs[0];
}

/// <summary>
/// Broadcast an array for a compatible shape.
/// Builds a "BroadcastTo" op through the op-definition helper and returns its
/// first output.
/// </summary>
/// <param name="input">The tensor to broadcast.</param>
/// <param name="shape">The shape to broadcast <paramref name="input"/> to.</param>
/// <param name="name">Optional name passed to the op helper.</param>
/// <returns>Output 0 of the created "BroadcastTo" operation.</returns>
public static Tensor broadcast_to(Tensor input, int[] shape, string name = null)
{
    // Member names of the anonymous object are passed to the helper as-is; keep them unchanged.
    var op_args = new { input, shape, name };
    var broadcast_op = _op_def_lib._apply_op_helper("BroadcastTo", name, args: op_args);
    return broadcast_op.outputs[0];
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -238,7 +238,7 @@ namespace Tensorflow
return _op.outputs[0];
}
public static Tensor cumsum(Tensor x, int axis = 0, bool exclusive = false, bool reverse = false, string name = null)
public static Tensor cumsum<T>(Tensor x, T axis, bool exclusive = false, bool reverse = false, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Cumsum", name, args: new { x, axis, exclusive, reverse });


+ 1
- 1
src/TensorFlowNET.Core/Operations/math_ops.cs View File

@@ -80,7 +80,7 @@ namespace Tensorflow
});
}

public static Tensor cumsum(Tensor x, int axis = 0, bool exclusive = false, bool reverse = false, string name = null)
public static Tensor cumsum<T>(Tensor x, T axis = default, bool exclusive = false, bool reverse = false, string name = null)
{
return tf_with(ops.name_scope(name, "Cumsum", new {x}), scope =>
{


+ 33
- 0
test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs View File

@@ -11,6 +11,39 @@ namespace TensorFlowNET.UnitTest.gradients_test
[TestClass]
public class GradientsTest : PythonTest
{
/// <summary>
/// Checks the gradient of <c>tf.broadcast_to</c> with respect to its input.
/// A scalar is broadcast to shape (2, 4, 3), i.e. 2 * 4 * 3 = 24 elements,
/// and the test expects the gradient with respect to the scalar to be 24.
/// </summary>
[TestMethod]
public void BroadcastToGrad()
{
    var graph = tf.Graph().as_default();

    var x = tf.constant(2, dtype: dtypes.float32);
    var y = tf.broadcast_to(x, (2, 4, 3));
    var grad = tf.gradients(y, x);

    using (var sess = tf.Session(graph))
    {
        float result = sess.run(grad[0]);
        // MSTest signature is AreEqual(expected, actual); keep that order so
        // a failure message labels the two values correctly.
        Assert.AreEqual(24.0f, result);
    }
}

/// <summary>
/// Checks the gradient of <c>tf.cumsum</c> chained after <c>tf.broadcast_to</c>.
/// A scalar is broadcast to shape (2, 4, 3) and cumulatively summed along axis 1.
/// Positions along that axis contribute 1 + 2 + 3 + 4 = 10 copies of the scalar,
/// for each of the 2 * 3 remaining index pairs, so the expected gradient is 60.
/// </summary>
[TestMethod]
public void CumsumGrad()
{
    var graph = tf.Graph().as_default();

    var x = tf.constant(2, dtype: dtypes.float32);
    var y = tf.broadcast_to(x, (2, 4, 3));
    var z = tf.cumsum(y, axis: 1);
    var grad = tf.gradients(z, x);

    using (var sess = tf.Session(graph))
    {
        float result = sess.run(grad[0]);
        // MSTest signature is AreEqual(expected, actual); keep that order so
        // a failure message labels the two values correctly.
        Assert.AreEqual(60.0f, result);
    }
}

[Ignore("TODO")]
[TestMethod]
public void testGradients()


Loading…
Cancel
Save