Browse Source

TextVectorization

tags/v0.40-tf2.4-tstring
Oceania2018 4 years ago
parent
commit
90638a8352
4 changed files with 39 additions and 5 deletions
  1. +5
    -3
      src/TensorFlowNET.Core/Data/MapDataset.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  3. +32
    -0
      src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs
  4. +1
    -1
      src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs

+ 5
- 3
src/TensorFlowNET.Core/Data/MapDataset.cs View File

@@ -17,9 +17,11 @@ namespace Tensorflow
{
var func = new ConcreteFunction($"{map_func.Method.Name}_{Guid.NewGuid()}");
func.Enter();
var input = tf.placeholder(input_dataset.element_spec[0].dtype);
var output = map_func(input);
func.ToGraph(input, output);
var inputs = new Tensors();
foreach (var input in input_dataset.element_spec)
inputs.Add(tf.placeholder(input.dtype, shape: input.shape));
var outputs = map_func(inputs);
func.ToGraph(inputs, outputs);
func.Exit();

structure = func.OutputStructure;


+ 1
- 1
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -86,7 +86,7 @@ tf.net 0.4x.x aligns with TensorFlow v2.4.1 native library.</PackageReleaseNotes
<ItemGroup>
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="5.0.1" />
<PackageReference Include="NumSharp.Lite" Version="0.1.12" />
<PackageReference Include="NumSharp" Version="0.30.0" />
<PackageReference Include="Protobuf.Text" Version="0.5.0" />
<PackageReference Include="Serilog.Sinks.Console" Version="3.1.1" />
</ItemGroup>


+ 32
- 0
src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs View File

@@ -13,8 +13,40 @@ namespace Tensorflow.Keras.Layers
/// <summary>
/// Creates the layer from <paramref name="args"/>, keeps a reference to it and
/// sets its <c>DType</c> to <c>TF_DataType.TF_STRING</c>.
/// </summary>
/// <param name="args">
/// Layer configuration. It is forwarded to the base constructor and then modified
/// in place: its <c>DType</c> is overwritten here.
/// </param>
public TextVectorization(TextVectorizationArgs args)
: base(args)
{
this.args = args;
// NOTE(review): this assignment runs after base(args) has completed, so the base
// constructor sees whatever DType the caller supplied — confirm that is intended.
args.DType = TF_DataType.TF_STRING;
// string standardize = "lower_and_strip_punctuation",
}

/// <summary>
/// Fits the state of the preprocessing layer to the dataset.
/// </summary>
/// <remarks>
/// When the first output shape of <paramref name="data"/> has rank 1, every element
/// is expanded with a trailing axis before the layer is built. The dataset is then
/// mapped through <c>_preprocess</c>; the mapped dataset is not used further here.
/// </remarks>
/// <param name="data">Dataset to adapt to; only its first output shape is inspected.</param>
/// <param name="reset_state">Not read by this method body.</param>
public void adapt(IDatasetV2 data, bool reset_state = true)
{
    // Rank-1 elements get an extra last dimension appended.
    var needsTrailingAxis = data.output_shapes[0].rank == 1;
    if (needsTrailingAxis)
    {
        data = data.map(element => array_ops.expand_dims(element, -1));
    }

    build(data.variant_tensor);

    // The result is currently discarded; the map call itself is kept as in the original flow.
    var preprocessed = data.map(_preprocess);
}

/// <summary>
/// Build hook for this layer; it only forwards <paramref name="inputs"/> to the
/// base implementation and adds no logic of its own.
/// </summary>
/// <param name="inputs">Passed unchanged to <c>base.build</c>.</param>
protected override void build(Tensors inputs)
{
base.build(inputs);
}

/// <summary>
/// Runs the optional standardization callback on <paramref name="inputs"/> and, when
/// a split mode is configured, squeezes the last axis of inputs that have more than
/// one dimension.
/// </summary>
/// <param name="inputs">Tensors to preprocess.</param>
/// <returns>The tensors after the steps that applied; unchanged if none did.</returns>
Tensors _preprocess(Tensors inputs)
{
    var current = inputs;

    // Step 1: caller-supplied standardization, if any.
    if (args.Standardize != null)
    {
        current = args.Standardize(current);
    }

    // Step 2: drop the trailing axis, but only when a split mode is set.
    var splitConfigured = !string.IsNullOrEmpty(args.Split);
    if (splitConfigured && current.shape.ndim > 1)
    {
        current = array_ops.squeeze(current, axis: new[] { -1 });
    }

    return current;
}
}
}

+ 1
- 1
src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs View File

@@ -11,7 +11,7 @@ namespace Tensorflow.Keras
public DatasetUtils dataset_utils => new DatasetUtils();

public TextVectorization TextVectorization(Func<Tensor, Tensor> standardize = null,
string split = "standardize",
string split = "whitespace",
int max_tokens = -1,
string output_mode = "int",
int output_sequence_length = -1) => new TextVectorization(new TextVectorizationArgs


Loading…
Cancel
Save