From 559921585defdd2732322ee37fd79a84dbb728ae Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 21 Feb 2021 10:28:40 -0600 Subject: [PATCH] Fix string_split_v2 return RaggedTensor. --- src/TensorFlowNET.Core/APIs/tf.math.cs | 22 ++++++++++ src/TensorFlowNET.Core/APIs/tf.strings.cs | 2 +- .../Operations/gen_math_ops.cs | 7 ---- src/TensorFlowNET.Core/Operations/math_ops.cs | 40 ++++++++++++++----- .../Operations/string_ops.cs | 10 ++++- .../Tensorflow.Binding.csproj | 1 + .../Tensors/Ragged/RaggedTensor.cs | 39 +++++++++++++++++- .../Tensors/Ragged/RowPartition.cs | 40 +++++++++++++++++-- .../Tensorflow.Keras.csproj | 4 ++ .../ManagedAPI/StringsApiTest.cs | 5 ++- 10 files changed, 144 insertions(+), 26 deletions(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index ff43c206..f438f870 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -32,6 +32,28 @@ namespace Tensorflow /// public Tensor erf(Tensor x, string name = null) => math_ops.erf(x, name); + + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public Tensor bincount(Tensor arr, Tensor weights = null, + Tensor minlength = null, + Tensor maxlength = null, + TF_DataType dtype = TF_DataType.TF_INT32, + string name = null, + TensorShape axis = null, + bool binary_output = false) + => math_ops.bincount(arr, weights: weights, minlength: minlength, maxlength: maxlength, + dtype: dtype, name: name, axis: axis, binary_output: binary_output); } public Tensor abs(Tensor x, string name = null) diff --git a/src/TensorFlowNET.Core/APIs/tf.strings.cs b/src/TensorFlowNET.Core/APIs/tf.strings.cs index b2c9c8a9..53a611d6 100644 --- a/src/TensorFlowNET.Core/APIs/tf.strings.cs +++ b/src/TensorFlowNET.Core/APIs/tf.strings.cs @@ -67,7 +67,7 @@ namespace Tensorflow string name = null, string @uint = "BYTE") => ops.substr(input, pos, len, @uint: @uint, name: name); - public SparseTensor split(Tensor input, string sep = "", int maxsplit = -1, string name = null) + public RaggedTensor split(Tensor input, string sep = "", int maxsplit = -1, string name = null) => ops.string_split_v2(input, sep: sep, maxsplit : maxsplit, name : name); } } diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index d990b13b..f6775ad9 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -249,13 +249,6 @@ namespace Tensorflow return _op.outputs[0]; } - public static Tensor cumsum(Tensor x, T axis, bool exclusive = false, bool reverse = false, string name = null) - { - var _op = tf.OpDefLib._apply_op_helper("Cumsum", name, args: new { x, axis, exclusive, reverse }); - - return _op.outputs[0]; - } - /// /// Computes the sum along segments of a tensor. /// diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index dfff531e..ef7988fe 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -168,15 +168,12 @@ namespace Tensorflow } public static Tensor cumsum(Tensor x, T axis = default, bool exclusive = false, bool reverse = false, string name = null) - { - return tf_with(ops.name_scope(name, "Cumsum", new { x }), scope => - { - name = scope; - x = ops.convert_to_tensor(x, name: "x"); - - return gen_math_ops.cumsum(x, axis: axis, exclusive: exclusive, reverse: reverse, name: name); - }); - } + => tf_with(ops.name_scope(name, "Cumsum", new { x }), scope => + { + name = scope; + return tf.Context.ExecuteOp("Cumsum", name, new ExecuteOpArgs(x, axis) + .SetAttributes(new { exclusive, reverse })); + }); /// /// Computes Psi, the derivative of Lgamma (the log of the absolute value of @@ -807,6 +804,31 @@ namespace Tensorflow .SetAttributes(new { adj_x, adj_y })); }); + public static Tensor bincount(Tensor arr, Tensor weights = null, + Tensor minlength = null, + Tensor maxlength = null, + TF_DataType dtype = TF_DataType.TF_INT32, + string name = null, + TensorShape axis = null, + bool binary_output = false) + => tf_with(ops.name_scope(name, "bincount"), scope => + { + name = scope; + if(!binary_output && axis == null) + { + var array_is_nonempty = math_ops.reduce_prod(array_ops.shape(arr)) > 0; + var output_size = math_ops.cast(array_is_nonempty, dtypes.int32) * (math_ops.reduce_max(arr) + 1); + if (minlength != null) + output_size = math_ops.maximum(minlength, output_size); + if (maxlength != null) + output_size = math_ops.minimum(maxlength, output_size); + var weights = constant_op.constant(new long[0], dtype: dtype); + return tf.Context.ExecuteOp("Bincount", name, new ExecuteOpArgs(arr, output_size, weights)); + } + + throw new NotImplementedException(""); + }); + /// /// Returns the complex conjugate of a complex number. /// diff --git a/src/TensorFlowNET.Core/Operations/string_ops.cs b/src/TensorFlowNET.Core/Operations/string_ops.cs index f7178194..650f9a87 100644 --- a/src/TensorFlowNET.Core/Operations/string_ops.cs +++ b/src/TensorFlowNET.Core/Operations/string_ops.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using NumSharp; using Tensorflow.Framework; using static Tensorflow.Binding; @@ -43,7 +44,7 @@ namespace Tensorflow => tf.Context.ExecuteOp("Substr", name, new ExecuteOpArgs(input, pos, len) .SetAttributes(new { unit = @uint })); - public SparseTensor string_split_v2(Tensor input, string sep = "", int maxsplit = -1, string name = null) + public RaggedTensor string_split_v2(Tensor input, string sep = "", int maxsplit = -1, string name = null) { return tf_with(ops.name_scope(name, "StringSplit"), scope => { @@ -60,7 +61,12 @@ namespace Tensorflow indices.set_shape(new TensorShape(-1, 2)); values.set_shape(new TensorShape(-1)); shape.set_shape(new TensorShape(2)); - return new SparseTensor(indices, values, shape); + + var sparse_result = new SparseTensor(indices, values, shape); + return RaggedTensor.from_value_rowids(sparse_result.values, + value_rowids: sparse_result.indices[Slice.All, 0], + nrows: sparse_result.dense_shape[0], + validate: false); }); } } diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj index 26cd5139..7c6e3e00 100644 --- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj +++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj @@ -50,6 +50,7 @@ tf.net 0.4x.x aligns with TensorFlow v2.4.1 native library.true TRACE;DEBUG x64 + TensorFlow.NET.xml diff --git a/src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs b/src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs index 0851a12b..9b9a9085 100644 --- a/src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs @@ -17,6 +17,7 @@ using System; using System.Collections.Generic; using System.Text; +using System.Linq; using Tensorflow.Framework; using static Tensorflow.Binding; @@ -27,9 +28,30 @@ namespace Tensorflow /// public class RaggedTensor : CompositeTensor { - public RaggedTensor(Tensor values, RowPartition row_partition, bool validate = true) + Tensor _values; + RowPartition _row_partition; + public TF_DataType dtype => _values.dtype; + public TensorShape shape { + get + { + var nrows = _row_partition.static_nrows; + var ncols = _row_partition.static_uniform_row_length; + return new TensorShape(nrows, ncols); + } + } + public RaggedTensor(Tensor values, + bool @internal = true, + RowPartition row_partition = null) + { + _values = values; + _row_partition = row_partition; + } + + public static RaggedTensor from_row_partition(Tensor values, RowPartition row_partition, bool validate = true) + { + return new RaggedTensor(values, @internal: true, row_partition: row_partition); } /// @@ -49,8 +71,21 @@ namespace Tensorflow var row_partition = RowPartition.from_value_rowids(value_rowids, nrows: nrows, validate: validate); - return new RaggedTensor(values, row_partition, validate: validate); + return from_row_partition(values, row_partition, validate: validate); }); } + + public override string ToString() + => $"tf.RaggedTensor: shape={shape} [{string.Join(", ", _values.StringData().Take(10))}]"; + + public static implicit operator Tensor(RaggedTensor indexedSlices) + { + return indexedSlices._values; + } + + public static implicit operator RaggedTensor(Tensor tensor) + { + return tensor.Tag as RaggedTensor; + } } } diff --git a/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs b/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs index 587cc154..d58226d8 100644 --- a/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs +++ b/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs @@ -27,11 +27,35 @@ namespace Tensorflow /// public class RowPartition : CompositeTensor { + Tensor _row_splits; + Tensor _row_lengths; + Tensor _value_rowids; + Tensor _nrows; + + public int static_nrows + { + get + { + return _row_splits.shape[0] - 1; + } + } + + public int static_uniform_row_length + { + get + { + return -1; + } + } + public RowPartition(Tensor row_splits, Tensor row_lengths = null, Tensor value_rowids = null, Tensor nrows = null, Tensor uniform_row_length = null) { - + _row_splits = row_splits; + _row_lengths = row_lengths; + _value_rowids = value_rowids; + _nrows = nrows; } /// @@ -47,8 +71,18 @@ namespace Tensorflow { return tf_with(ops.name_scope(null, "RowPartitionFromValueRowIds"), scope => { - Tensor row_lengths = null; - Tensor row_splits = null; + var value_rowids_int32 = math_ops.cast(value_rowids, dtypes.int32); + var nrows_int32 = math_ops.cast(nrows, dtypes.int32); + var row_lengths = tf.math.bincount(value_rowids_int32, + minlength: nrows_int32, + maxlength: nrows_int32, + dtype: value_rowids.dtype); + var row_splits = array_ops.concat(new object[] + { + ops.convert_to_tensor(new long[] { 0 }), + tf.cumsum(row_lengths) + }, axis: 0); + return new RowPartition(row_splits, row_lengths: row_lengths, value_rowids: value_rowids, diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj index 6325707d..0c50a5a1 100644 --- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj +++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj @@ -49,6 +49,10 @@ Keras is an API designed for human beings, not machines. Keras follows best prac false + + Tensorflow.Keras.xml + + diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs index 82cd6eea..d98c5207 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs @@ -62,8 +62,9 @@ namespace TensorFlowNET.UnitTest.ManagedAPI [TestMethod] public void StringSplit() { - var tensor = tf.constant(new[] { "hello world", "tensorflow .net" }); - tf.strings.split(tensor); + var tensor = tf.constant(new[] { "hello world", "tensorflow .net csharp", "fsharp" }); + var ragged_tensor = tf.strings.split(tensor); + Assert.AreEqual((3, -1), ragged_tensor.shape); } } }