Browse Source

Started implementation of Tokenizer(), as well as pad_sequences().

pull/756/head
Niklas Gustafsson 4 years ago
parent
commit
606dba816e
6 changed files with 364 additions and 6 deletions
  1. +3
    -0
      src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs
  2. +176
    -0
      src/TensorFlowNET.Keras/Preprocessings/Tokenizer.cs
  3. +21
    -6
      src/TensorFlowNET.Keras/Sequence.cs
  4. +42
    -0
      src/TensorFlowNET.Keras/TextApi.cs
  5. +120
    -0
      test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs
  6. +2
    -0
      test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj

+ 3
- 0
src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs View File

@@ -6,5 +6,8 @@ namespace Tensorflow.Keras
{
public Sequence sequence => new Sequence();
public DatasetUtils dataset_utils => new DatasetUtils();
public TextApi text => _text;

private static TextApi _text = new TextApi();
}
}

+ 176
- 0
src/TensorFlowNET.Keras/Preprocessings/Tokenizer.cs View File

@@ -0,0 +1,176 @@
using NumSharp;
using Serilog.Debugging;
using System;
using System.Collections.Generic;
using System.Collections.Specialized;
using System.Linq;
using System.Net.Sockets;
using System.Text;

namespace Tensorflow.Keras.Text
{
public class Tokenizer
{
private readonly int num_words;
private readonly string filters;
private readonly bool lower;
private readonly char split;
private readonly bool char_level;
private readonly string oov_token;
private readonly Func<string, IEnumerable<string>> analyzer;

private int document_count = 0;

private Dictionary<string, int> word_docs = new Dictionary<string, int>();
private Dictionary<string, int> word_counts = new Dictionary<string, int>();

public Dictionary<string, int> word_index = null;
public Dictionary<int, string> index_word = null;

private Dictionary<int, int> index_docs = null;

public Tokenizer(
int num_words = -1,
string filters = "!\"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n",
bool lower = true,
char split = ' ',
bool char_level = false,
string oov_token = null,
Func<string, IEnumerable<string>> analyzer = null)
{
this.num_words = num_words;
this.filters = filters;
this.lower = lower;
this.split = split;
this.char_level = char_level;
this.oov_token = oov_token;
this.analyzer = analyzer;
}

public void fit_on_texts(IEnumerable<string> texts)
{
foreach (var text in texts)
{
IEnumerable<string> seq = null;

document_count += 1;
if (char_level)
{
throw new NotImplementedException("char_level == true");
}
else
{
seq = analyzer(lower ? text.ToLower() : text);
}

foreach (var w in seq)
{
var count = 0;
word_counts.TryGetValue(w, out count);
word_counts[w] = count + 1;
}

foreach (var w in new HashSet<string>(seq))
{
var count = 0;
word_docs.TryGetValue(w, out count);
word_docs[w] = count + 1;
}
}

var wcounts = word_counts.AsEnumerable().ToList();
wcounts.Sort((kv1, kv2) => -kv1.Value.CompareTo(kv2.Value));

var sorted_voc = (oov_token == null) ? new List<string>() : new List<string>(){oov_token};
sorted_voc.AddRange(word_counts.Select(kv => kv.Key));

if (num_words > 0 -1)
{
sorted_voc = sorted_voc.Take<string>((oov_token == null) ? num_words : num_words + 1).ToList();
}

word_index = new Dictionary<string, int>(sorted_voc.Count);
index_word = new Dictionary<int,string>(sorted_voc.Count);
index_docs = new Dictionary<int, int>(word_docs.Count);

for (int i = 0; i < sorted_voc.Count; i++)
{
word_index.Add(sorted_voc[i], i + 1);
index_word.Add(i + 1, sorted_voc[i]);
}

foreach (var kv in word_docs)
{
var idx = -1;
if (word_index.TryGetValue(kv.Key, out idx))
{
index_docs.Add(idx, kv.Value);
}
}
}

public void fit_on_sequences(IEnumerable<int[]> sequences)
{
throw new NotImplementedException("fit_on_sequences");
}

public IList<int[]> texts_to_sequences(IEnumerable<string> texts)
{
return texts_to_sequences_generator(texts).ToArray();
}
public IEnumerable<int[]> texts_to_sequences_generator(IEnumerable<string> texts)
{
int oov_index = -1;
var _ = (oov_token != null) && word_index.TryGetValue(oov_token, out oov_index);

return texts.Select(text => {

IEnumerable<string> seq = null;

if (char_level)
{
throw new NotImplementedException("char_level == true");
}
else
{
seq = analyzer(lower ? text.ToLower() : text);
}
var vect = new List<int>();
foreach (var w in seq)
{
var i = -1;
if (word_index.TryGetValue(w, out i))
{
if (num_words != -1 && i >= num_words)
{
if (oov_index != -1)
{
vect.Add(oov_index);
}
}
else
{
vect.Add(i);
}
}
else
{
vect.Add(oov_index);
}
}

return vect.ToArray();
});
}

public IEnumerable<string> sequences_to_texts(IEnumerable<int[]> sequences)
{
throw new NotImplementedException("sequences_to_texts");
}

public NDArray sequences_to_matrix(IEnumerable<int[]> sequences)
{
throw new NotImplementedException("sequences_to_matrix");
}
}
}

+ 21
- 6
src/TensorFlowNET.Keras/Sequence.cs View File

@@ -15,7 +15,9 @@
******************************************************************************/

using NumSharp;
using NumSharp.Utilities;
using System;
using System.Collections.Generic;
using System.Linq;

namespace Tensorflow.Keras
@@ -34,14 +36,18 @@ namespace Tensorflow.Keras
/// <param name="truncating">String, 'pre' or 'post'</param>
/// <param name="value">Float or String, padding value.</param>
/// <returns></returns>
public NDArray pad_sequences(NDArray sequences,
public NDArray pad_sequences(IEnumerable<int[]> sequences,
int? maxlen = null,
string dtype = "int32",
string padding = "pre",
string truncating = "pre",
object value = null)
{
int[] length = new int[sequences.size];
if (value != null) throw new NotImplementedException("padding with a specific value.");
if (padding != "pre" && padding != "post") throw new InvalidArgumentError("padding must be 'pre' or 'post'.");
if (truncating != "pre" && truncating != "post") throw new InvalidArgumentError("truncating must be 'pre' or 'post'.");

var length = sequences.Select(s => s.Length);

if (maxlen == null)
maxlen = length.Max();
@@ -49,19 +55,28 @@ namespace Tensorflow.Keras
if (value == null)
value = 0f;

var nd = new NDArray(np.int32, new Shape(sequences.size, maxlen.Value));
var type = getNPType(dtype);
var nd = new NDArray(type, new Shape(length.Count(), maxlen.Value),true);

#pragma warning disable CS0162 // Unreachable code detected
for (int i = 0; i < nd.shape[0]; i++)
#pragma warning restore CS0162 // Unreachable code detected
{
switch (sequences[i])
var s = sequences.ElementAt(i);
if (s.Length > maxlen.Value)
{
default:
throw new NotImplementedException("pad_sequences");
s = (truncating == "pre") ? s.Slice(s.Length - maxlen.Value, s.Length) : s.Slice(0, maxlen.Value);
}
var sliceString = (padding == "pre") ? (i.ToString() + "," + (maxlen-s.Length).ToString() + ":") : (i.ToString() + ",:" + s.Length);
nd[sliceString] = np.array(s);
}

return nd;
}

private Type getNPType(string typeName)
{
return System.Type.GetType("NumSharp.np,NumSharp.Lite").GetField(typeName).GetValue(null) as Type;
}
}
}

+ 42
- 0
src/TensorFlowNET.Keras/TextApi.cs View File

@@ -0,0 +1,42 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Keras.Text;

namespace Tensorflow.Keras
{
public class TextApi
{
public Tensorflow.Keras.Text.Tokenizer Tokenizer(
int num_words = -1,
string filters = DefaultFilter,
bool lower = true,
char split = ' ',
bool char_level = false,
string oov_token = null,
Func<string, IEnumerable<string>> analyzer = null)
{
if (analyzer != null)
{
return new Keras.Text.Tokenizer(num_words, filters, lower, split, char_level, oov_token, analyzer);
}
else
{
return new Keras.Text.Tokenizer(num_words, filters, lower, split, char_level, oov_token, (text) => text_to_word_sequence(text, filters, lower, split));
}
}

public static IEnumerable<string> text_to_word_sequence(string text, string filters = DefaultFilter, bool lower = true, char split = ' ')
{
if (lower)
{
text = text.ToLower();
}
var newText = new String(text.Where(c => !filters.Contains(c)).ToArray());
return newText.Split(split);
}

private const string DefaultFilter = "!\"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n";
}
}

+ 120
- 0
test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs View File

@@ -0,0 +1,120 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Linq;
using System.Collections.Generic;
using System.Text;
using NumSharp;
using static Tensorflow.KerasApi;
using Tensorflow;
using Tensorflow.Keras.Datasets;

namespace TensorFlowNET.Keras.UnitTest
{
[TestClass]
public class PreprocessingTests : EagerModeTestBase
{
private readonly string[] texts = new string[] {
"It was the best of times, it was the worst of times.",
"this is a new dawn, an era to follow the previous era. It can not be said to start anew.",
"It was the best of times, it was the worst of times.",
"this is a new dawn, an era to follow the previous era.",
};
private const string OOV = "<OOV>";

[TestMethod]
public void TokenizeWithNoOOV()
{
var tokenizer = keras.preprocessing.text.Tokenizer(lower: true);
tokenizer.fit_on_texts(texts);

Assert.AreEqual(23, tokenizer.word_index.Count);

Assert.AreEqual(7, tokenizer.word_index["worst"]);
Assert.AreEqual(12, tokenizer.word_index["dawn"]);
Assert.AreEqual(16, tokenizer.word_index["follow"]);
}

[TestMethod]
public void TokenizeWithOOV()
{
var tokenizer = keras.preprocessing.text.Tokenizer(lower: true, oov_token: OOV);
tokenizer.fit_on_texts(texts);

Assert.AreEqual(24, tokenizer.word_index.Count);

Assert.AreEqual(1, tokenizer.word_index[OOV]);
Assert.AreEqual(8, tokenizer.word_index["worst"]);
Assert.AreEqual(13, tokenizer.word_index["dawn"]);
Assert.AreEqual(17, tokenizer.word_index["follow"]);
}

[TestMethod]
public void PadSequencesWithDefaults()
{
var tokenizer = keras.preprocessing.text.Tokenizer(lower: true, oov_token: OOV);
tokenizer.fit_on_texts(texts);

var sequences = tokenizer.texts_to_sequences(texts);
var padded = keras.preprocessing.sequence.pad_sequences(sequences);

Assert.AreEqual(4, padded.shape[0]);
Assert.AreEqual(20, padded.shape[1]);

var firstRow = padded[0];
var secondRow = padded[1];

Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 17].GetInt32());
for (var i = 0; i < 8; i++)
Assert.AreEqual(0, padded[0, i].GetInt32());
Assert.AreEqual(tokenizer.word_index["previous"], padded[1, 10].GetInt32());
for (var i = 0; i < 20; i++)
Assert.AreNotEqual(0, padded[1, i].GetInt32());
}

[TestMethod]
public void PadSequencesPrePaddingTrunc()
{
var tokenizer = keras.preprocessing.text.Tokenizer(lower: true, oov_token: OOV);
tokenizer.fit_on_texts(texts);

var sequences = tokenizer.texts_to_sequences(texts);
var padded = keras.preprocessing.sequence.pad_sequences(sequences,maxlen:15);

Assert.AreEqual(4, padded.shape[0]);
Assert.AreEqual(15, padded.shape[1]);

var firstRow = padded[0];
var secondRow = padded[1];

Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 12].GetInt32());
for (var i = 0; i < 3; i++)
Assert.AreEqual(0, padded[0, i].GetInt32());
Assert.AreEqual(tokenizer.word_index["previous"], padded[1, 5].GetInt32());
for (var i = 0; i < 15; i++)
Assert.AreNotEqual(0, padded[1, i].GetInt32());
}

[TestMethod]
public void PadSequencesPostPaddingTrunc()
{
var tokenizer = keras.preprocessing.text.Tokenizer(lower: true, oov_token: OOV);
tokenizer.fit_on_texts(texts);

var sequences = tokenizer.texts_to_sequences(texts);
var padded = keras.preprocessing.sequence.pad_sequences(sequences, maxlen: 15, padding: "post", truncating: "post");

Assert.AreEqual(4, padded.shape[0]);
Assert.AreEqual(15, padded.shape[1]);

var firstRow = padded[0];
var secondRow = padded[1];

Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 9].GetInt32());
for (var i = 12; i < 15; i++)
Assert.AreEqual(0, padded[0, i].GetInt32());
Assert.AreEqual(tokenizer.word_index["previous"], padded[1, 10].GetInt32());
for (var i = 0; i < 15; i++)
Assert.AreNotEqual(0, padded[1, i].GetInt32());
}
}
}

+ 2
- 0
test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj View File

@@ -6,6 +6,8 @@
<IsPackable>false</IsPackable>

<Platforms>AnyCPU;x64</Platforms>

<RootNamespace>TensorflowNET.Keras.UnitTest</RootNamespace>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">


Loading…
Cancel
Save