diff --git a/src/TensorFlowNET.Core/Data/MapDataset.cs b/src/TensorFlowNET.Core/Data/MapDataset.cs
index 5786a340..1f843e4a 100644
--- a/src/TensorFlowNET.Core/Data/MapDataset.cs
+++ b/src/TensorFlowNET.Core/Data/MapDataset.cs
@@ -17,9 +17,11 @@ namespace Tensorflow
{
var func = new ConcreteFunction($"{map_func.Method.Name}_{Guid.NewGuid()}");
func.Enter();
- var input = tf.placeholder(input_dataset.element_spec[0].dtype);
- var output = map_func(input);
- func.ToGraph(input, output);
+ var inputs = new Tensors();
+ foreach (var input in input_dataset.element_spec)
+ inputs.Add(tf.placeholder(input.dtype, shape: input.shape));
+ var outputs = map_func(inputs);
+ func.ToGraph(inputs, outputs);
func.Exit();
structure = func.OutputStructure;
diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
index 139f98dc..26cd5139 100644
--- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
+++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
@@ -86,7 +86,7 @@ tf.net 0.4x.x aligns with TensorFlow v2.4.1 native library.
-
+
diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs
index a66be94b..c72860a6 100644
--- a/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs
+++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs
@@ -13,8 +13,40 @@ namespace Tensorflow.Keras.Layers
public TextVectorization(TextVectorizationArgs args)
: base(args)
{
+ this.args = args; // keep the options on the instance; presumably the same `args` that _preprocess reads Standardize/Split from (confirm)
args.DType = TF_DataType.TF_STRING;
// string standardize = "lower_and_strip_punctuation",
}
+
+ /// <summary>
+ /// Fits the state of the preprocessing layer to the dataset.
+ /// </summary>
+ /// <param name="data">Dataset to adapt to. When its first output shape has rank 1, every element gets a trailing axis appended before the layer is built.</param>
+ /// <param name="reset_state">Not read by the current body. NOTE(review): presumably meant to control resetting of previously adapted state; confirm before relying on it.</param>
+ public void adapt(IDatasetV2 data, bool reset_state = true)
+ {
+ var shape = data.output_shapes[0]; // only the first output shape is inspected
+ if (shape.rank == 1)
+ data = data.map(tensor => array_ops.expand_dims(tensor, -1)); // append a trailing axis to each element
+ build(data.variant_tensor);
+ var preprocessed_inputs = data.map(_preprocess); // NOTE(review): result is not used after this line in the visible code
+ }
+
+ protected override void build(Tensors inputs)
+ {
+ base.build(inputs); // pure pass-through: no layer-specific setup is added by this override
+ }
+
+ Tensors _preprocess(Tensors inputs) // applies the optional standardize callback, then squeezes axis -1 when a split mode is configured
+ {
+ if (args.Standardize != null) // the standardization callback is optional
+ inputs = args.Standardize(inputs);
+ if (!string.IsNullOrEmpty(args.Split)) // NOTE(review): this branch only squeezes; no tokenization by the split mode happens here
+ {
+ if (inputs.shape.ndim > 1)
+ inputs = array_ops.squeeze(inputs, axis: new[] { -1 }); // squeeze is requested on the last axis only
+ }
+ return inputs;
+ }
}
}
diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs
index 6c33e9f5..34aeb211 100644
--- a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs
+++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs
@@ -11,7 +11,7 @@ namespace Tensorflow.Keras
public DatasetUtils dataset_utils => new DatasetUtils();
public TextVectorization TextVectorization(Func standardize = null,
- string split = "standardize",
+ string split = "whitespace",
int max_tokens = -1,
string output_mode = "int",
int output_sequence_length = -1) => new TextVectorization(new TextVectorizationArgs