Browse Source

Adding `Cumsum` operation (#322)

Unit testing the operation too.
tags/v0.12
Antonio Haiping 6 years ago
parent
commit
f8912135a7
4 changed files with 62 additions and 0 deletions
  1. +3
    -0
      src/TensorFlowNET.Core/APIs/tf.math.cs
  2. +7
    -0
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  3. +11
    -0
      src/TensorFlowNET.Core/Operations/math_ops.cs
  4. +41
    -0
      test/TensorFlowNET.UnitTest/OperationsTest.cs

+ 3
- 0
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -348,6 +348,9 @@ namespace Tensorflow
public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
=> math_ops.cast(x, dtype, name); => math_ops.cast(x, dtype, name);


public static Tensor cumsum(Tensor x, int axis = 0, bool exclusive = false, bool reverse = false, string name = null)
=> math_ops.cumsum(x, axis: axis, exclusive: exclusive, reverse: reverse, name: name);

public static Tensor argmax(Tensor input, int axis = -1, string name = null, int? dimension = null, TF_DataType output_type = TF_DataType.TF_INT64) public static Tensor argmax(Tensor input, int axis = -1, string name = null, int? dimension = null, TF_DataType output_type = TF_DataType.TF_INT64)
=> gen_math_ops.arg_max(input, axis, name: name, output_type: output_type); => gen_math_ops.arg_max(input, axis, name: name, output_type: output_type);




+ 7
- 0
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -238,6 +238,13 @@ namespace Tensorflow
return _op.outputs[0]; return _op.outputs[0];
} }
public static Tensor cumsum(Tensor x, int axis = 0, bool exclusive = false, bool reverse = false, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Cumsum", name, args: new { x, axis, exclusive, reverse });
return _op.outputs[0];
}
/// <summary> /// <summary>
/// Computes the sum along segments of a tensor. /// Computes the sum along segments of a tensor.
/// </summary> /// </summary>


+ 11
- 0
src/TensorFlowNET.Core/Operations/math_ops.cs View File

@@ -80,6 +80,17 @@ namespace Tensorflow
}); });
} }


public static Tensor cumsum(Tensor x, int axis = 0, bool exclusive = false, bool reverse = false, string name = null)
{
return with(ops.name_scope(name, "Cumsum", new {x}), scope =>
{
name = scope;
x = ops.convert_to_tensor(x, name: "x");

return gen_math_ops.cumsum(x, axis: axis, exclusive: exclusive, reverse: reverse, name: name);
});
}

/// <summary> /// <summary>
/// Computes Psi, the derivative of Lgamma (the log of the absolute value of /// Computes Psi, the derivative of Lgamma (the log of the absolute value of
/// `Gamma(x)`), element-wise. /// `Gamma(x)`), element-wise.


+ 41
- 0
test/TensorFlowNET.UnitTest/OperationsTest.cs View File

@@ -89,6 +89,47 @@ namespace TensorFlowNET.UnitTest
} }
} }


[TestMethod]
public void cumSumTest()
{
var a = tf.constant(new[] { 1, 1, 2, 3, 4, 5 });
var b = tf.cumsum(a);
var check = np.array(1, 2, 4, 7, 11, 16);

using (var sess = tf.Session())
{
var o = sess.run(b);
Assert.IsTrue(o.array_equal(check));
}

b = tf.cumsum(a, exclusive: true);
check = np.array(0, 1, 2, 4, 7, 11);

using (var sess = tf.Session())
{
var o = sess.run(b);
Assert.IsTrue(o.array_equal(check));
}

b = tf.cumsum(a, reverse: true);
check = np.array(16, 15, 14, 12, 9, 5);

using (var sess = tf.Session())
{
var o = sess.run(b);
Assert.IsTrue(o.array_equal(check));
}

b = tf.cumsum(a, exclusive:true, reverse: true);
check = np.array(15, 14, 12, 9, 5, 0);

using (var sess = tf.Session())
{
var o = sess.run(b);
Assert.IsTrue(o.array_equal(check));
}
}

[TestMethod] [TestMethod]
public void addOpTests() public void addOpTests()
{ {


Loading…
Cancel
Save