diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs
index ff43c206..f438f870 100644
--- a/src/TensorFlowNET.Core/APIs/tf.math.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.math.cs
@@ -32,6 +32,28 @@ namespace Tensorflow
///
public Tensor erf(Tensor x, string name = null)
=> math_ops.erf(x, name);
+
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public Tensor bincount(Tensor arr, Tensor weights = null,
+ Tensor minlength = null,
+ Tensor maxlength = null,
+ TF_DataType dtype = TF_DataType.TF_INT32,
+ string name = null,
+ TensorShape axis = null,
+ bool binary_output = false)
+ => math_ops.bincount(arr, weights: weights, minlength: minlength, maxlength: maxlength,
+ dtype: dtype, name: name, axis: axis, binary_output: binary_output);
}
public Tensor abs(Tensor x, string name = null)
diff --git a/src/TensorFlowNET.Core/APIs/tf.strings.cs b/src/TensorFlowNET.Core/APIs/tf.strings.cs
index b2c9c8a9..53a611d6 100644
--- a/src/TensorFlowNET.Core/APIs/tf.strings.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.strings.cs
@@ -67,7 +67,7 @@ namespace Tensorflow
string name = null, string @uint = "BYTE")
=> ops.substr(input, pos, len, @uint: @uint, name: name);
- public SparseTensor split(Tensor input, string sep = "", int maxsplit = -1, string name = null)
+ public RaggedTensor split(Tensor input, string sep = "", int maxsplit = -1, string name = null)
=> ops.string_split_v2(input, sep: sep, maxsplit : maxsplit, name : name);
}
}
diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
index d990b13b..f6775ad9 100644
--- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
@@ -249,13 +249,6 @@ namespace Tensorflow
return _op.outputs[0];
}
- public static Tensor cumsum(Tensor x, T axis, bool exclusive = false, bool reverse = false, string name = null)
- {
- var _op = tf.OpDefLib._apply_op_helper("Cumsum", name, args: new { x, axis, exclusive, reverse });
-
- return _op.outputs[0];
- }
-
///
/// Computes the sum along segments of a tensor.
///
diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs
index dfff531e..ef7988fe 100644
--- a/src/TensorFlowNET.Core/Operations/math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/math_ops.cs
@@ -168,15 +168,12 @@ namespace Tensorflow
}
public static Tensor cumsum(Tensor x, T axis = default, bool exclusive = false, bool reverse = false, string name = null)
- {
- return tf_with(ops.name_scope(name, "Cumsum", new { x }), scope =>
- {
- name = scope;
- x = ops.convert_to_tensor(x, name: "x");
-
- return gen_math_ops.cumsum(x, axis: axis, exclusive: exclusive, reverse: reverse, name: name);
- });
- }
+ => tf_with(ops.name_scope(name, "Cumsum", new { x }), scope =>
+ {
+ name = scope;
+ return tf.Context.ExecuteOp("Cumsum", name, new ExecuteOpArgs(x, axis)
+ .SetAttributes(new { exclusive, reverse }));
+ });
///
/// Computes Psi, the derivative of Lgamma (the log of the absolute value of
@@ -807,6 +804,31 @@ namespace Tensorflow
.SetAttributes(new { adj_x, adj_y }));
});
+ public static Tensor bincount(Tensor arr, Tensor weights = null,
+ Tensor minlength = null,
+ Tensor maxlength = null,
+ TF_DataType dtype = TF_DataType.TF_INT32,
+ string name = null,
+ TensorShape axis = null,
+ bool binary_output = false)
+ => tf_with(ops.name_scope(name, "bincount"), scope =>
+ {
+ name = scope;
+ if(!binary_output && axis == null)
+ {
+ var array_is_nonempty = math_ops.reduce_prod(array_ops.shape(arr)) > 0;
+ var output_size = math_ops.cast(array_is_nonempty, dtypes.int32) * (math_ops.reduce_max(arr) + 1);
+ if (minlength != null)
+ output_size = math_ops.maximum(minlength, output_size);
+ if (maxlength != null)
+ output_size = math_ops.minimum(maxlength, output_size);
+ var weights = constant_op.constant(new long[0], dtype: dtype);
+ return tf.Context.ExecuteOp("Bincount", name, new ExecuteOpArgs(arr, output_size, weights));
+ }
+
+ throw new NotImplementedException("");
+ });
+
///
/// Returns the complex conjugate of a complex number.
///
diff --git a/src/TensorFlowNET.Core/Operations/string_ops.cs b/src/TensorFlowNET.Core/Operations/string_ops.cs
index f7178194..650f9a87 100644
--- a/src/TensorFlowNET.Core/Operations/string_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/string_ops.cs
@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/
+using NumSharp;
using Tensorflow.Framework;
using static Tensorflow.Binding;
@@ -43,7 +44,7 @@ namespace Tensorflow
=> tf.Context.ExecuteOp("Substr", name, new ExecuteOpArgs(input, pos, len)
.SetAttributes(new { unit = @uint }));
- public SparseTensor string_split_v2(Tensor input, string sep = "", int maxsplit = -1, string name = null)
+ public RaggedTensor string_split_v2(Tensor input, string sep = "", int maxsplit = -1, string name = null)
{
return tf_with(ops.name_scope(name, "StringSplit"), scope =>
{
@@ -60,7 +61,12 @@ namespace Tensorflow
indices.set_shape(new TensorShape(-1, 2));
values.set_shape(new TensorShape(-1));
shape.set_shape(new TensorShape(2));
- return new SparseTensor(indices, values, shape);
+
+ var sparse_result = new SparseTensor(indices, values, shape);
+ return RaggedTensor.from_value_rowids(sparse_result.values,
+ value_rowids: sparse_result.indices[Slice.All, 0],
+ nrows: sparse_result.dense_shape[0],
+ validate: false);
});
}
}
diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
index 26cd5139..7c6e3e00 100644
--- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
+++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
@@ -50,6 +50,7 @@ tf.net 0.4x.x aligns with TensorFlow v2.4.1 native library.true
TRACE;DEBUG
x64
+ TensorFlow.NET.xml
diff --git a/src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs b/src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs
index 0851a12b..9b9a9085 100644
--- a/src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs
@@ -17,6 +17,7 @@
using System;
using System.Collections.Generic;
using System.Text;
+using System.Linq;
using Tensorflow.Framework;
using static Tensorflow.Binding;
@@ -27,9 +28,30 @@ namespace Tensorflow
///
public class RaggedTensor : CompositeTensor
{
- public RaggedTensor(Tensor values, RowPartition row_partition, bool validate = true)
+ Tensor _values;
+ RowPartition _row_partition;
+ public TF_DataType dtype => _values.dtype;
+ public TensorShape shape
{
+ get
+ {
+ var nrows = _row_partition.static_nrows;
+ var ncols = _row_partition.static_uniform_row_length;
+ return new TensorShape(nrows, ncols);
+ }
+ }
+ public RaggedTensor(Tensor values,
+ bool @internal = true,
+ RowPartition row_partition = null)
+ {
+ _values = values;
+ _row_partition = row_partition;
+ }
+
+ public static RaggedTensor from_row_partition(Tensor values, RowPartition row_partition, bool validate = true)
+ {
+ return new RaggedTensor(values, @internal: true, row_partition: row_partition);
}
///
@@ -49,8 +71,21 @@ namespace Tensorflow
var row_partition = RowPartition.from_value_rowids(value_rowids,
nrows: nrows,
validate: validate);
- return new RaggedTensor(values, row_partition, validate: validate);
+ return from_row_partition(values, row_partition, validate: validate);
});
}
+
+ public override string ToString()
+ => $"tf.RaggedTensor: shape={shape} [{string.Join(", ", _values.StringData().Take(10))}]";
+
+ public static implicit operator Tensor(RaggedTensor indexedSlices)
+ {
+ return indexedSlices._values;
+ }
+
+ public static implicit operator RaggedTensor(Tensor tensor)
+ {
+ return tensor.Tag as RaggedTensor;
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs b/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs
index 587cc154..d58226d8 100644
--- a/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs
+++ b/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs
@@ -27,11 +27,35 @@ namespace Tensorflow
///
public class RowPartition : CompositeTensor
{
+ Tensor _row_splits;
+ Tensor _row_lengths;
+ Tensor _value_rowids;
+ Tensor _nrows;
+
+ public int static_nrows
+ {
+ get
+ {
+ return _row_splits.shape[0] - 1;
+ }
+ }
+
+ public int static_uniform_row_length
+ {
+ get
+ {
+ return -1;
+ }
+ }
+
public RowPartition(Tensor row_splits,
Tensor row_lengths = null, Tensor value_rowids = null, Tensor nrows = null,
Tensor uniform_row_length = null)
{
-
+ _row_splits = row_splits;
+ _row_lengths = row_lengths;
+ _value_rowids = value_rowids;
+ _nrows = nrows;
}
///
@@ -47,8 +71,18 @@ namespace Tensorflow
{
return tf_with(ops.name_scope(null, "RowPartitionFromValueRowIds"), scope =>
{
- Tensor row_lengths = null;
- Tensor row_splits = null;
+ var value_rowids_int32 = math_ops.cast(value_rowids, dtypes.int32);
+ var nrows_int32 = math_ops.cast(nrows, dtypes.int32);
+ var row_lengths = tf.math.bincount(value_rowids_int32,
+ minlength: nrows_int32,
+ maxlength: nrows_int32,
+ dtype: value_rowids.dtype);
+ var row_splits = array_ops.concat(new object[]
+ {
+ ops.convert_to_tensor(new long[] { 0 }),
+ tf.cumsum(row_lengths)
+ }, axis: 0);
+
return new RowPartition(row_splits,
row_lengths: row_lengths,
value_rowids: value_rowids,
diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
index 6325707d..0c50a5a1 100644
--- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
+++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
@@ -49,6 +49,10 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
false
+
+ Tensorflow.Keras.xml
+
+
diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs
index 82cd6eea..d98c5207 100644
--- a/test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs
+++ b/test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs
@@ -62,8 +62,9 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
[TestMethod]
public void StringSplit()
{
- var tensor = tf.constant(new[] { "hello world", "tensorflow .net" });
- tf.strings.split(tensor);
+ var tensor = tf.constant(new[] { "hello world", "tensorflow .net csharp", "fsharp" });
+ var ragged_tensor = tf.strings.split(tensor);
+ Assert.AreEqual((3, -1), ragged_tensor.shape);
}
}
}