Browse Source

IndexLookup, Accumulator.

tags/v0.40-tf2.4-tstring
Oceania2018 4 years ago
parent
commit
bbc2e98a51
16 changed files with 202 additions and 6 deletions
  1. +3
    -0
      src/TensorFlowNET.Core/APIs/tf.strings.cs
  2. +11
    -0
      src/TensorFlowNET.Core/Data/DatasetV2.cs
  3. +2
    -0
      src/TensorFlowNET.Core/Data/IDatasetV2.cs
  4. +1
    -0
      src/TensorFlowNET.Core/Data/OwnedIterator.cs
  5. +1
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs
  6. +5
    -0
      src/TensorFlowNET.Core/Operations/string_ops.cs
  7. +13
    -1
      src/TensorFlowNET.Keras/Engine/CombinerPreprocessingLayer.cs
  8. +10
    -0
      src/TensorFlowNET.Keras/Engine/Interfaces/IAccumulator.cs
  9. +19
    -0
      src/TensorFlowNET.Keras/Engine/Interfaces/ICombiner.cs
  10. +30
    -0
      src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookup.cs
  11. +16
    -0
      src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookupAccumulator.cs
  12. +55
    -0
      src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookupCombiner.cs
  13. +23
    -0
      src/TensorFlowNET.Keras/Layers/Preprocessing/StringLookup.cs
  14. +11
    -1
      src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs
  15. +2
    -0
      src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs
  16. +0
    -4
      src/TensorFlowNET.Keras/Tensorflow.Keras.csproj

+ 3
- 0
src/TensorFlowNET.Core/APIs/tf.strings.cs View File

@@ -64,6 +64,9 @@ namespace Tensorflow
public Tensor substr(string input, int pos, int len,
string name = null, string @uint = "BYTE")
=> ops.substr(input, pos, len, @uint: @uint, name: name);

public Tensor split(Tensor input, string sep = "", int maxsplit = -1, string name = null)
=> ops.string_split_v2(input, sep: sep, maxsplit : maxsplit, name : name);
}
}
}

+ 11
- 0
src/TensorFlowNET.Core/Data/DatasetV2.cs View File

@@ -68,6 +68,17 @@ namespace Tensorflow
public IDatasetV2 map(Func<Tensors, Tensors> map_func, int num_parallel_calls)
=> new ParallelMapDataset(this, map_func, num_parallel_calls: num_parallel_calls);

public OwnedIterator make_one_shot_iterator()
{
if (tf.Context.executing_eagerly())
{
// with ops.colocate_with(self._variant_tensor)
return new OwnedIterator(this);
}

throw new NotImplementedException("");
}

public IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func)
=> new FlatMapDataset(this, map_func);



+ 2
- 0
src/TensorFlowNET.Core/Data/IDatasetV2.cs View File

@@ -72,6 +72,8 @@ namespace Tensorflow
IDatasetV2 map(Func<Tensors, Tensors> map_func,
int num_parallel_calls);

OwnedIterator make_one_shot_iterator();

IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func);

IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget);


+ 1
- 0
src/TensorFlowNET.Core/Data/OwnedIterator.cs View File

@@ -26,6 +26,7 @@ namespace Tensorflow
dataset = dataset.apply_options();
_dataset = dataset;
_element_spec = dataset.element_spec;
// _flat_output_types =
(_iterator_resource, _deleter) = ops.anonymous_iterator_v2(_dataset.output_types, _dataset.output_shapes);
ops.make_iterator(dataset.variant_tensor, _iterator_resource);
}


+ 1
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs View File

@@ -11,5 +11,6 @@ namespace Tensorflow.Keras.ArgsDefinition
public int MaxTokens { get; set; } = -1;
public string OutputMode { get; set; } = "int";
public int OutputSequenceLength { get; set; } = -1;
public string[] Vocabulary { get; set; }
}
}

+ 5
- 0
src/TensorFlowNET.Core/Operations/string_ops.cs View File

@@ -41,5 +41,10 @@ namespace Tensorflow
string @uint = "BYTE", string name = null)
=> tf.Context.ExecuteOp("Substr", name, new ExecuteOpArgs(input, pos, len)
.SetAttributes(new { unit = @uint }));

public Tensor string_split_v2(Tensor input, string sep = "", int maxsplit = -1, string name = null)
{
return null;
}
}
}

+ 13
- 1
src/TensorFlowNET.Keras/Engine/CombinerPreprocessingLayer.cs View File

@@ -8,11 +8,23 @@ namespace Tensorflow.Keras.Engine
public class CombinerPreprocessingLayer : Layer
{
PreprocessingLayerArgs args;
protected ICombiner combiner;
protected bool _previously_updated;

public CombinerPreprocessingLayer(PreprocessingLayerArgs args)
: base(args)
{
_previously_updated = false;
}

public virtual void adapt(IDatasetV2 data, bool reset_state = true)
{
IAccumulator accumulator;
if (!reset_state)
accumulator = combiner.Restore();

var next_data = data.make_one_shot_iterator();
var data_element = next_data.next();
}
}
}

+ 10
- 0
src/TensorFlowNET.Keras/Engine/Interfaces/IAccumulator.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Engine
{
public interface IAccumulator
{
}
}

+ 19
- 0
src/TensorFlowNET.Keras/Engine/Interfaces/ICombiner.cs View File

@@ -0,0 +1,19 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Engine
{
/// <summary>
/// Functional object that defines a shardable computation.
/// </summary>
public interface ICombiner
{
void Compute(Tensor values, IAccumulator accumulator = null);
void Merge();
void Extract();
IAccumulator Restore();
void Serialize();
void Deserialize();
}
}

+ 30
- 0
src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookup.cs View File

@@ -0,0 +1,30 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras.Layers
{
public class IndexLookup : CombinerPreprocessingLayer
{
public IndexLookup(int max_tokens = -1,
int num_oov_indices = 1,
string mask_token = "",
string oov_token = "[UNK]",
string encoding = "utf-8",
bool invert = false) : base(new PreprocessingLayerArgs())
{
var num_mask_tokens = mask_token == null ? 0 : 1;
var vocab_size = max_tokens - (num_oov_indices + num_mask_tokens);
combiner = new IndexLookupCombiner(vocab_size, mask_token);
}

public override void adapt(IDatasetV2 data, bool reset_state = true)
{
if (!reset_state)
throw new ValueError("IndexLookup does not support streaming adapts.");
base.adapt(data, reset_state);
}
}
}

+ 16
- 0
src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookupAccumulator.cs View File

@@ -0,0 +1,16 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras.Layers
{
public class IndexLookupAccumulator : IAccumulator
{
public Dictionary<string, int> CountDict { get; set; }
public IndexLookupAccumulator()
{
CountDict = new Dictionary<string, int>();
}
}
}

+ 55
- 0
src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookupCombiner.cs View File

@@ -0,0 +1,55 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras.Layers
{
/// <summary>
/// Combiner for the IndexLookup preprocessing layer.
/// </summary>
public class IndexLookupCombiner : ICombiner
{
int _vocab_size;
string _mask_value;

public IndexLookupCombiner(int vocab_size = -1, string mask_value = null)
{
_vocab_size = vocab_size;
_mask_value = mask_value;
}

public void Compute(Tensor values, IAccumulator accumulator = null)
{
if(accumulator == null)
{
accumulator = new IndexLookupAccumulator();
}
}

public void Deserialize()
{
throw new NotImplementedException();
}

public void Extract()
{
throw new NotImplementedException();
}

public void Merge()
{
throw new NotImplementedException();
}

public IAccumulator Restore()
{
throw new NotImplementedException();
}

public void Serialize()
{
throw new NotImplementedException();
}
}
}

+ 23
- 0
src/TensorFlowNET.Keras/Layers/Preprocessing/StringLookup.cs View File

@@ -0,0 +1,23 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Layers
{
/// <summary>
/// Maps strings from a vocabulary to integer indices.
/// </summary>
class StringLookup : IndexLookup
{
public StringLookup(int max_tokens = -1,
int num_oov_indices = 1,
string mask_token = "",
string[] vocabulary = null,
string oov_token = "[UNK]",
string encoding = "utf-8",
bool invert = false)
{

}
}
}

+ 11
- 1
src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs View File

@@ -3,12 +3,14 @@ using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Layers
{
public class TextVectorization : CombinerPreprocessingLayer
{
TextVectorizationArgs args;
IndexLookup _index_lookup_layer;

public TextVectorization(TextVectorizationArgs args)
: base(args)
@@ -16,6 +18,11 @@ namespace Tensorflow.Keras.Layers
this.args = args;
args.DType = TF_DataType.TF_STRING;
// string standardize = "lower_and_strip_punctuation",

var mask_token = args.OutputMode == "int" ? "" : null;
_index_lookup_layer = new StringLookup(max_tokens: args.MaxTokens,
mask_token: mask_token,
vocabulary: args.Vocabulary);
}

/// <summary>
@@ -23,13 +30,14 @@ namespace Tensorflow.Keras.Layers
/// </summary>
/// <param name="data"></param>
/// <param name="reset_state"></param>
public void adapt(IDatasetV2 data, bool reset_state = true)
public override void adapt(IDatasetV2 data, bool reset_state = true)
{
var shape = data.output_shapes[0];
if (shape.rank == 1)
data = data.map(tensor => array_ops.expand_dims(tensor, -1));
build(data.variant_tensor);
var preprocessed_inputs = data.map(_preprocess);
_index_lookup_layer.adapt(preprocessed_inputs);
}

protected override void build(Tensors inputs)
@@ -45,6 +53,8 @@ namespace Tensorflow.Keras.Layers
{
if (inputs.shape.ndim > 1)
inputs = array_ops.squeeze(inputs, axis: new[] { -1 });
if (args.Split == "whitespace")
inputs = tf.strings.split(inputs);
}
return inputs;
}


+ 2
- 0
src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs View File

@@ -1,4 +1,5 @@
using NumSharp;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
@@ -60,6 +61,7 @@ namespace Tensorflow.Keras.Preprocessings
}
}

Console.WriteLine($"Found {return_file_paths.Length} files belonging to {class_names.Length} classes.");
return (return_file_paths, return_labels, class_names);
}
}


+ 0
- 4
src/TensorFlowNET.Keras/Tensorflow.Keras.csproj View File

@@ -63,10 +63,6 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
</None>
</ItemGroup>

<ItemGroup>
<Folder Include="Engine\Interfaces\" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\TensorFlowNET.Core\Tensorflow.Binding.csproj" />
</ItemGroup>


Loading…
Cancel
Save