From 90638a8352072d0948ef89e70f0d782116436e4a Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 14 Feb 2021 11:51:55 -0600 Subject: [PATCH] TextVectorization --- src/TensorFlowNET.Core/Data/MapDataset.cs | 8 +++-- .../Tensorflow.Binding.csproj | 2 +- .../Layers/Preprocessing/TextVectorization.cs | 32 +++++++++++++++++++ .../Preprocessings/Preprocessing.cs | 2 +- 4 files changed, 39 insertions(+), 5 deletions(-) diff --git a/src/TensorFlowNET.Core/Data/MapDataset.cs b/src/TensorFlowNET.Core/Data/MapDataset.cs index 5786a340..1f843e4a 100644 --- a/src/TensorFlowNET.Core/Data/MapDataset.cs +++ b/src/TensorFlowNET.Core/Data/MapDataset.cs @@ -17,9 +17,11 @@ namespace Tensorflow { var func = new ConcreteFunction($"{map_func.Method.Name}_{Guid.NewGuid()}"); func.Enter(); - var input = tf.placeholder(input_dataset.element_spec[0].dtype); - var output = map_func(input); - func.ToGraph(input, output); + var inputs = new Tensors(); + foreach (var input in input_dataset.element_spec) + inputs.Add(tf.placeholder(input.dtype, shape: input.shape)); + var outputs = map_func(inputs); + func.ToGraph(inputs, outputs); func.Exit(); structure = func.OutputStructure; diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj index 139f98dc..26cd5139 100644 --- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj +++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj @@ -86,7 +86,7 @@ tf.net 0.4x.x aligns with TensorFlow v2.4.1 native library. 
- + diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs index a66be94b..c72860a6 100644 --- a/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs @@ -13,8 +13,40 @@ namespace Tensorflow.Keras.Layers public TextVectorization(TextVectorizationArgs args) : base(args) { + this.args = args; args.DType = TF_DataType.TF_STRING; // string standardize = "lower_and_strip_punctuation", } + + /// <summary> + /// Fits the state of the preprocessing layer to the dataset. + /// </summary> + /// <param name="data"></param> + /// <param name="reset_state"></param> + public void adapt(IDatasetV2 data, bool reset_state = true) + { + var shape = data.output_shapes[0]; + if (shape.rank == 1) + data = data.map(tensor => array_ops.expand_dims(tensor, -1)); + build(data.variant_tensor); + var preprocessed_inputs = data.map(_preprocess); + } + + protected override void build(Tensors inputs) + { + base.build(inputs); + } + + Tensors _preprocess(Tensors inputs) + { + if (args.Standardize != null) + inputs = args.Standardize(inputs); + if (!string.IsNullOrEmpty(args.Split)) + { + if (inputs.shape.ndim > 1) + inputs = array_ops.squeeze(inputs, axis: new[] { -1 }); + } + return inputs; + } } } diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs index 6c33e9f5..34aeb211 100644 --- a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs +++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs @@ -11,7 +11,7 @@ namespace Tensorflow.Keras public DatasetUtils dataset_utils => new DatasetUtils(); public TextVectorization TextVectorization(Func standardize = null, - string split = "standardize", + string split = "whitespace", int max_tokens = -1, string output_mode = "int", int output_sequence_length = -1) => new TextVectorization(new TextVectorizationArgs