diff --git a/src/TensorFlowNET.Core/APIs/tf.strings.cs b/src/TensorFlowNET.Core/APIs/tf.strings.cs
index 53a611d6..38a40eb4 100644
--- a/src/TensorFlowNET.Core/APIs/tf.strings.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.strings.cs
@@ -67,8 +67,26 @@ namespace Tensorflow
                 string name = null, string @uint = "BYTE")
                 => ops.substr(input, pos, len, @uint: @uint, name: name);
 
+            /// <summary>
+            /// String lengths of `input`.
+            /// </summary>
+            /// <param name="input"></param>
+            /// <param name="name"></param>
+            /// <param name="unit"></param>
+            /// <returns></returns>
+            public Tensor string_length(Tensor input, string name = null, string unit = "BYTE")
+                => ops.string_length(input, name: name, unit: unit);
+
             public RaggedTensor split(Tensor input, string sep = "", int maxsplit = -1, string name = null)
                 => ops.string_split_v2(input, sep: sep, maxsplit : maxsplit, name : name);
+
+            public (RaggedTensor, RaggedTensor) unicode_decode_with_offsets(Tensor input, string input_encoding,
+                string errors = "replace", int replacement_char = 0xFFFD,
+                bool replace_control_characters = false, string name = null)
+                => ops.unicode_decode_with_offsets(input, input_encoding, errors,
+                    replacement_char: replacement_char,
+                    replace_control_characters: replace_control_characters,
+                    name: name);
         }
     }
 }
diff --git a/src/TensorFlowNET.Core/Operations/string_ops.cs b/src/TensorFlowNET.Core/Operations/string_ops.cs
index 650f9a87..2d7c54c7 100644
--- a/src/TensorFlowNET.Core/Operations/string_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/string_ops.cs
@@ -44,6 +44,22 @@ namespace Tensorflow
            => tf.Context.ExecuteOp("Substr", name, new ExecuteOpArgs(input, pos, len)
                .SetAttributes(new { unit = @uint }));
 
+        /// <summary>
+        /// Computes the length of each string given in the input tensor.
+        /// </summary>
+        /// <param name="input"></param>
+        /// <param name="name"></param>
+        /// <param name="unit"></param>
+        /// <returns></returns>
+        public Tensor string_length(Tensor input, string name = null, string unit = "BYTE")
+            => tf.Context.ExecuteOp("StringLength", name, new ExecuteOpArgs(input)
+            {
+                GetGradientAttrs = op => new
+                {
+                    unit = op.get_attr("unit")
+                }
+            }.SetAttributes(new { unit }));
+
         public RaggedTensor string_split_v2(Tensor input, string sep = "", int maxsplit = -1, string name = null)
         {
             return tf_with(ops.name_scope(name, "StringSplit"), scope =>
@@ -69,5 +85,49 @@ namespace Tensorflow
                     validate: false);
             });
         }
+
+        public (RaggedTensor, RaggedTensor) unicode_decode_with_offsets(Tensor input, string input_encoding, string errors,
+            int replacement_char = 0xFFFD, bool replace_control_characters = false, string name = null)
+        {
+            return tf_with(ops.name_scope(name, "UnicodeDecodeWithOffsets"), scope =>
+            {
+                var (codepoints, byte_start_offsets) = _unicode_decode(input, input_encoding, errors,
+                    replacement_char, replace_control_characters,
+                    with_offsets: true, name: name);
+                return (codepoints, byte_start_offsets);
+            });
+        }
+
+        (RaggedTensor, RaggedTensor) _unicode_decode(Tensor input, string input_encoding, string errors, int replacement_char,
+            bool replace_control_characters, bool with_offsets, string name = null)
+        {
+            if (with_offsets)
+            {
+                var flat_result = tf.Context.ExecuteOp("UnicodeDecodeWithOffsets", name, new ExecuteOpArgs(input)
+                {
+                    GetGradientAttrs = op => new
+                    {
+                        input_encoding = op.get_attr("input_encoding"),
+                        errors = op.get_attr("errors"),
+                        replacement_char = op.get_attr("replacement_char"),
+                        replace_control_characters = op.get_attr("replace_control_characters"),
+                        Tsplits = op.get_attr("Tsplits")
+                    }
+                }.SetAttributes(new
+                {
+                    input_encoding,
+                    errors,
+                    replacement_char,
+                    replace_control_characters
+                }));
+
+                var codepoints = RaggedTensor.from_row_splits(flat_result[1], flat_result[0], validate: false);
+
+                var offsets = RaggedTensor.from_row_splits(flat_result[2], flat_result[0], validate: false);
+                return (codepoints, offsets);
+            }
+
+            return (null, null);
+        }
     }
 }
diff --git a/src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs b/src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs
index 9b9a9085..567014ab 100644
--- a/src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs
@@ -20,6 +20,7 @@ using System.Text;
 using System.Linq;
 using Tensorflow.Framework;
 using static Tensorflow.Binding;
+using NumSharp;
 
 namespace Tensorflow
 {
@@ -30,6 +31,8 @@ namespace Tensorflow
     {
         Tensor _values;
         RowPartition _row_partition;
+        Tensor _row_splits => _row_partition.row_splits;
+
         public TF_DataType dtype => _values.dtype;
         public TensorShape shape
         {
@@ -41,6 +44,28 @@ namespace Tensorflow
             }
         }
 
+        public RaggedTensor this[params Slice[] slices]
+        {
+            get
+            {
+                var row_key = slices[0];
+                var inner_keys = slices.Skip(1).ToArray();
+
+                var args = tensor_util.ParseSlices(slices);
+
+                return tf_with(ops.name_scope(null, "RaggedGetItem", args), scope =>
+                {
+                    string name = scope;
+                    return _ragged_getitem_inner_dimensions(this, inner_keys);
+                });
+            }
+        }
+
+        RaggedTensor _ragged_getitem_inner_dimensions(RaggedTensor input, Slice[] slices)
+        {
+            return input;
+        }
+
         public RaggedTensor(Tensor values,
             bool @internal = true,
             RowPartition row_partition = null)
@@ -75,13 +100,44 @@ namespace Tensorflow
             });
         }
 
+        public static RaggedTensor from_row_splits(Tensor values, Tensor row_splits,
+            string name = null, bool validate = true)
+        {
+            return tf_with(ops.name_scope(name, "RaggedFromRowSplits"), scope =>
+            {
+                var row_partition = RowPartition.from_row_splits(row_splits,
+                    validate: validate);
+                return from_row_partition(values, row_partition, validate: validate);
+            });
+        }
+
+        Tensor _to_variant(bool batched_input = false, string name = null)
+            => tf_with(ops.name_scope(name, "RaggedToVariant"), scope =>
+            {
+                return tf.Context.ExecuteOp("RaggedTensorToVariant", name,
+                    new ExecuteOpArgs(nested_row_splits, flat_values)
+                    {
+                        GetGradientAttrs = op => new
+                        {
+                            RAGGED_RANK = op.get_attr("RAGGED_RANK"),
+                            Tvalues = op.get_attr("Tvalues"),
+                            Tsplits = op.get_attr("Tsplits"),
+                            batched_input = op.get_attr("batched_input")
+                        }
+                    }.SetAttributes(new { batched_input }));
+            });
+
+        Tensor flat_values
+            => _values;
+
+        Tensor[] nested_row_splits
+            => new[] { _row_splits };
+
         public override string ToString()
             => $"tf.RaggedTensor: shape={shape} [{string.Join(", ", _values.StringData().Take(10))}]";
 
         public static implicit operator Tensor(RaggedTensor indexedSlices)
-        {
-            return indexedSlices._values;
-        }
+            => indexedSlices._to_variant();
 
         public static implicit operator RaggedTensor(Tensor tensor)
         {
diff --git a/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs b/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs
index d58226d8..6a52397a 100644
--- a/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs
+++ b/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs
@@ -28,6 +28,7 @@ namespace Tensorflow
     public class RowPartition : CompositeTensor
     {
         Tensor _row_splits;
+        public Tensor row_splits => _row_splits;
         Tensor _row_lengths;
         Tensor _value_rowids;
         Tensor _nrows;
@@ -89,5 +90,14 @@ namespace Tensorflow
                    nrows: nrows);
             });
         }
+
+        public static RowPartition from_row_splits(Tensor row_splits,
+            bool validate = true, TF_DataType preferred_dtype = TF_DataType.DtInvalid)
+        {
+            return tf_with(ops.name_scope(null, "RowPartitionFromRowSplits"), scope =>
+            {
+                return new RowPartition(row_splits);
+            });
+        }
     }
 }
diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs
index dfc99fea..038f419b 100644
--- a/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs
+++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs
@@ -55,10 +55,9 @@ namespace Tensorflow.Keras.Layers
                 if (inputs.shape.ndim > 1)
                     input_tensor = array_ops.squeeze(inputs, axis: new[] { -1 });
                 if (args.Split == "whitespace")
-                    input_tensor = tf.strings.split(inputs);
-
+                    input_tensor = tf.strings.split(input_tensor);
             }
-            return inputs;
+            return input_tensor;
         }
     }
 }
diff --git a/src/TensorFlowNET.Text/Tokenizers/WhitespaceTokenizer.cs b/src/TensorFlowNET.Text/Tokenizers/WhitespaceTokenizer.cs
index a0bbe473..bade6f4a 100644
--- a/src/TensorFlowNET.Text/Tokenizers/WhitespaceTokenizer.cs
+++ b/src/TensorFlowNET.Text/Tokenizers/WhitespaceTokenizer.cs
@@ -1,6 +1,8 @@
-using System;
+using NumSharp;
+using System;
 using System.Collections.Generic;
 using System.Text;
+using static Tensorflow.Binding;
 
 namespace Tensorflow.Text.Tokenizers
 {
@@ -13,7 +15,31 @@ namespace Tensorflow.Text.Tokenizers
         /// </summary>
         public Tensor tokenize(Tensor input)
         {
+            tokenize_with_offsets(input);
             throw new NotImplementedException("");
         }
+
+        Tensor[] tokenize_with_offsets(Tensor input)
+        {
+            tf_with(ops.name_scope(null, "WhitespaceTokenize"), scope =>
+            {
+                _whitespace_tokenize_with_offsets_encode_decode_wrapper(input);
+            });
+            throw new NotImplementedException("");
+        }
+
+        Tensor _whitespace_tokenize_with_offsets_encode_decode_wrapper(Tensor input_tensor)
+        {
+            // Decode the strings and get byte offsets
+            var (codepoints, byte_start_offsets) = tf.strings.unicode_decode_with_offsets(input_tensor, "UTF-8");
+            var byte_end_offsets = array_ops.concat(new Tensor[]
+            {
+                byte_start_offsets[Slice.All, new Slice(1)],
+                math_ops.cast(
+                    array_ops.expand_dims(tf.strings.string_length(input_tensor), 1),
+                    dtypes.int64)
+            }, 1);
+            return input_tensor;
+        }
     }
 }
diff --git a/test/TensorFlowNET.UnitTest/Text/TokenizerTest.cs b/test/TensorFlowNET.UnitTest/Text/TokenizerTest.cs
index 3b8237b9..65c69a3f 100644
--- a/test/TensorFlowNET.UnitTest/Text/TokenizerTest.cs
+++ b/test/TensorFlowNET.UnitTest/Text/TokenizerTest.cs
@@ -10,10 +10,12 @@ namespace TensorFlowNET.UnitTest.Text
     [TestClass]
     public class TokenizerTest
     {
-        [TestMethod]
+        [TestMethod, Ignore]
         public void Tokenize()
         {
             var docs = tf.constant(new[] { "Everything not saved will be lost." });
+            var tokenizer = text.WhitespaceTokenizer();
+            var tokens = tokenizer.tokenize(docs);
         }
     }
 }
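Usage sketch (not part of the patch): a minimal example of the tf.strings surface added above, assuming the TensorFlow.NET eager API and the defaults declared in this change; variable names are illustrative only.

    using static Tensorflow.Binding;

    // A rank-1 string tensor, mirroring the unit test above.
    var docs = tf.constant(new[] { "Everything not saved will be lost." });

    // Per-string length; unit defaults to "BYTE" per the signature added in tf.strings.cs.
    var lengths = tf.strings.string_length(docs);

    // Unicode codepoints plus the byte offset at which each codepoint starts,
    // both returned as RaggedTensors assembled from the op's row_splits output.
    var (codepoints, byte_starts) = tf.strings.unicode_decode_with_offsets(docs, "UTF-8");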
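For context on how _unicode_decode assembles its outputs: RaggedTensor.from_row_splits(values, row_splits) groups a flat values tensor into rows, where row i spans values[row_splits[i] .. row_splits[i+1]). A small sketch with made-up numbers, assuming int32 values and int64 row splits (TensorFlow's default Tsplits):

    // values:     [3, 1, 4, 1, 5, 9, 2, 6]
    // row_splits: [0, 4, 4, 7, 8]
    // rows:       [[3, 1, 4, 1], [], [5, 9, 2], [6]]
    var values = tf.constant(new[] { 3, 1, 4, 1, 5, 9, 2, 6 });
    var row_splits = tf.constant(new long[] { 0, 4, 4, 7, 8 });
    var ragged = RaggedTensor.from_row_splits(values, row_splits, validate: false);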