|
|
@@ -3,12 +3,14 @@ using System.Collections.Generic; |
|
|
|
using System.Text; |
|
|
|
using Tensorflow.Keras.ArgsDefinition; |
|
|
|
using Tensorflow.Keras.Engine; |
|
|
|
using static Tensorflow.Binding; |
|
|
|
|
|
|
|
namespace Tensorflow.Keras.Layers |
|
|
|
{ |
|
|
|
public class TextVectorization : CombinerPreprocessingLayer |
|
|
|
{ |
|
|
|
TextVectorizationArgs args; |
|
|
|
IndexLookup _index_lookup_layer; |
|
|
|
|
|
|
|
/// <summary>
/// Builds the layer from <paramref name="args"/>: stores the arguments, sets
/// <c>args.DType</c> to <c>TF_DataType.TF_STRING</c>, and creates the internal
/// <c>StringLookup</c> layer from <c>args.MaxTokens</c>, a mask token and
/// <c>args.Vocabulary</c>.
/// </summary>
/// <param name="args">Layer arguments; also forwarded to the base constructor. Its <c>DType</c> is overwritten here.</param>
public TextVectorization(TextVectorizationArgs args) |
|
|
|
: base(args) |
|
|
// NOTE(review): the "@@" row below is a diff hunk marker; the line(s) it elides
// (presumably the constructor's opening brace) are not visible in this view.
@@ -16,6 +18,11 @@ namespace Tensorflow.Keras.Layers |
|
|
|
this.args = args; |
|
|
|
args.DType = TF_DataType.TF_STRING; |
|
|
|
// string standardize = "lower_and_strip_punctuation", |
|
|
|
|
|
|
|
// The mask token is the empty string only when OutputMode is "int"; otherwise no mask token (null) is passed.
// NOTE(review): presumably "" is reserved as the padding/mask entry for integer output - confirm against StringLookup.
var mask_token = args.OutputMode == "int" ? "" : null; |
|
|
|
_index_lookup_layer = new StringLookup(max_tokens: args.MaxTokens, |
|
|
|
mask_token: mask_token, |
|
|
|
vocabulary: args.Vocabulary); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
@@ -23,13 +30,14 @@ namespace Tensorflow.Keras.Layers |
|
|
|
/// </summary> |
|
|
|
/// <param name="data">
/// Dataset to adapt on. When its first output shape has rank 1, every element is
/// given a trailing axis (<c>expand_dims(tensor, -1)</c>) before preprocessing.
/// </param>
/// <param name="reset_state">
/// Not read by this implementation; present so the signature matches the overridden method.
/// </param>
public override void adapt(IDatasetV2 data, bool reset_state = true)
{
    // Only the first output shape is inspected.
    var shape = data.output_shapes[0];
    if (shape.rank == 1)
    {
        data = data.map(tensor => array_ops.expand_dims(tensor, -1));
    }

    // Build this layer from the dataset's variant tensor before mapping over it.
    build(data.variant_tensor);

    // Apply this layer's preprocessing to the dataset, then let the internal
    // lookup layer adapt on the preprocessed result.
    var preprocessed_inputs = data.map(_preprocess);
    _index_lookup_layer.adapt(preprocessed_inputs);
}
|
|
|
|
|
|
|
protected override void build(Tensors inputs) |
|
|
@@ -45,6 +53,8 @@ namespace Tensorflow.Keras.Layers |
|
|
|
{ |
|
|
|
if (inputs.shape.ndim > 1) |
|
|
|
inputs = array_ops.squeeze(inputs, axis: new[] { -1 }); |
|
|
|
if (args.Split == "whitespace") |
|
|
|
inputs = tf.strings.split(inputs); |
|
|
|
} |
|
|
|
return inputs; |
|
|
|
} |
|
|
|