diff --git a/src/TensorFlowNET.Keras/Preprocessings/Tokenizer.cs b/src/TensorFlowNET.Keras/Preprocessings/Tokenizer.cs
index d3bfeaac..aaca1cb9 100644
--- a/src/TensorFlowNET.Keras/Preprocessings/Tokenizer.cs
+++ b/src/TensorFlowNET.Keras/Preprocessings/Tokenizer.cs
@@ -9,6 +9,12 @@ using System.Text;
namespace Tensorflow.Keras.Text
{
+ /// <summary>
+ /// Text tokenization API.
+ /// This class allows to vectorize a text corpus, by turning each text into either a sequence of integers
+ /// (each integer being the index of a token in a dictionary) or into a vector where the coefficient for
+ /// each token could be binary, based on word count, based on tf-idf...
+ /// </summary>
public class Tokenizer
{
private readonly int num_words;
@@ -47,6 +53,11 @@ namespace Tensorflow.Keras.Text
this.analyzer = analyzer;
}
+ /// <summary>
+ /// Updates internal vocabulary based on a list of texts.
+ /// </summary>
+ /// <param name="texts"></param>
+ /// <remarks>Required before using texts_to_sequences or texts_to_matrix.</remarks>
public void fit_on_texts(IEnumerable<string> texts)
{
foreach (var text in texts)
@@ -81,16 +92,16 @@ namespace Tensorflow.Keras.Text
var wcounts = word_counts.AsEnumerable().ToList();
wcounts.Sort((kv1, kv2) => -kv1.Value.CompareTo(kv2.Value));
- var sorted_voc = (oov_token == null) ? new List<string>() : new List<string>(){oov_token};
+ var sorted_voc = (oov_token == null) ? new List<string>() : new List<string>() { oov_token };
sorted_voc.AddRange(word_counts.Select(kv => kv.Key));
- if (num_words > 0 -1)
+ if (num_words > 0 - 1)
{
sorted_voc = sorted_voc.Take((oov_token == null) ? num_words : num_words + 1).ToList();
}
word_index = new Dictionary<string, int>(sorted_voc.Count);
- index_word = new Dictionary<int,string>(sorted_voc.Count);
+ index_word = new Dictionary<int, string>(sorted_voc.Count);
index_docs = new Dictionary<int, int>(word_docs.Count);
for (int i = 0; i < sorted_voc.Count; i++)
@@ -109,25 +120,98 @@ namespace Tensorflow.Keras.Text
}
}
+ public void fit_on_texts(IEnumerable<IEnumerable<string>> texts)
+ {
+ foreach (var seq in texts)
+ {
+ foreach (var w in seq.Select(s => lower ? s.ToLower() : s))
+ {
+ var count = 0;
+ word_counts.TryGetValue(w, out count);
+ word_counts[w] = count + 1;
+ }
+
+ foreach (var w in new HashSet<string>(word_counts.Keys))
+ {
+ var count = 0;
+ word_docs.TryGetValue(w, out count);
+ word_docs[w] = count + 1;
+ }
+ }
+
+ var wcounts = word_counts.AsEnumerable().ToList();
+ wcounts.Sort((kv1, kv2) => -kv1.Value.CompareTo(kv2.Value));
+
+ var sorted_voc = (oov_token == null) ? new List<string>() : new List<string>() { oov_token };
+ sorted_voc.AddRange(word_counts.Select(kv => kv.Key));
+
+ if (num_words > 0 - 1)
+ {
+ sorted_voc = sorted_voc.Take((oov_token == null) ? num_words : num_words + 1).ToList();
+ }
+
+ word_index = new Dictionary<string, int>(sorted_voc.Count);
+ index_word = new Dictionary<int, string>(sorted_voc.Count);
+ index_docs = new Dictionary<int, int>(word_docs.Count);
+
+ for (int i = 0; i < sorted_voc.Count; i++)
+ {
+ word_index.Add(sorted_voc[i], i + 1);
+ index_word.Add(i + 1, sorted_voc[i]);
+ }
+
+ foreach (var kv in word_docs)
+ {
+ var idx = -1;
+ if (word_index.TryGetValue(kv.Key, out idx))
+ {
+ index_docs.Add(idx, kv.Value);
+ }
+ }
+ }
+
+ /// <summary>
+ /// Updates internal vocabulary based on a list of sequences.
+ /// </summary>
+ /// <param name="sequences"></param>
+ /// <remarks>Required before using sequences_to_matrix (if fit_on_texts was never called).</remarks>
public void fit_on_sequences(IEnumerable<int[]> sequences)
{
throw new NotImplementedException("fit_on_sequences");
}
+ /// <summary>
+ /// Transforms each string in texts to a sequence of integers.
+ /// </summary>
+ /// <param name="texts"></param>
+ /// <returns></returns>
+ /// <remarks>Only top num_words-1 most frequent words will be taken into account. Only words known by the tokenizer will be taken into account.</remarks>
public IList<int[]> texts_to_sequences(IEnumerable<string> texts)
{
return texts_to_sequences_generator(texts).ToArray();
}
+
+ /// <summary>
+ /// Transforms each sequence of tokens in texts to a sequence of integers.
+ /// </summary>
+ /// <param name="texts"></param>
+ /// <returns></returns>
+ /// <remarks>Only top num_words-1 most frequent words will be taken into account. Only words known by the tokenizer will be taken into account.</remarks>
+ public IList<int[]> texts_to_sequences(IEnumerable<IEnumerable<string>> texts)
+ {
+ return texts_to_sequences_generator(texts).ToArray();
+ }
+
public IEnumerable<int[]> texts_to_sequences_generator(IEnumerable<string> texts)
{
int oov_index = -1;
var _ = (oov_token != null) && word_index.TryGetValue(oov_token, out oov_index);
- return texts.Select(text => {
-
+ return texts.Select(text =>
+ {
IEnumerable<string> seq = null;
- if (char_level)
+ if (char_level)
{
throw new NotImplementedException("char_level == true");
}
@@ -135,39 +219,101 @@ namespace Tensorflow.Keras.Text
{
seq = analyzer(lower ? text.ToLower() : text);
}
- var vect = new List<int>();
- foreach (var w in seq)
+
+ return ConvertToSequence(oov_index, seq).ToArray();
+ });
+ }
+
+ private List<int> ConvertToSequence(int oov_index, IEnumerable<string> seq)
+ {
+ var vect = new List<int>();
+ foreach (var w in seq.Select(s => lower ? s.ToLower() : s))
+ {
+ var i = -1;
+ if (word_index.TryGetValue(w, out i))
+ {
+ if (num_words != -1 && i >= num_words)
+ {
+ if (oov_index != -1)
+ {
+ vect.Add(oov_index);
+ }
+ }
+ else
+ {
+ vect.Add(i);
+ }
+ }
+ else if(oov_index != -1)
+ {
+ vect.Add(oov_index);
+ }
+ }
+
+ return vect;
+ }
+
+ public IEnumerable<int[]> texts_to_sequences_generator(IEnumerable<IEnumerable<string>> texts)
+ {
+ int oov_index = -1;
+ var _ = (oov_token != null) && word_index.TryGetValue(oov_token, out oov_index);
+ return texts.Select(seq => ConvertToSequence(oov_index, seq).ToArray());
+ }
+
+ /// <summary>
+ /// Transforms each sequence into a list of text.
+ /// </summary>
+ /// <param name="sequences"></param>
+ /// <returns>A list of texts (strings)</returns>
+ /// <remarks>Only top num_words-1 most frequent words will be taken into account. Only words known by the tokenizer will be taken into account.</remarks>
+ public IList<string> sequences_to_texts(IEnumerable<int[]> sequences)
+ {
+ return sequences_to_texts_generator(sequences).ToArray();
+ }
+
+ public IEnumerable<string> sequences_to_texts_generator(IEnumerable<int[]> sequences)
+ {
+ int oov_index = -1;
+ var _ = (oov_token != null) && word_index.TryGetValue(oov_token, out oov_index);
+
+ return sequences.Select(seq =>
+ {
+
+ var bldr = new StringBuilder();
+ for (var i = 0; i < seq.Length; i++)
{
- var i = -1;
- if (word_index.TryGetValue(w, out i))
+ if (i > 0) bldr.Append(' ');
+
+ string word = null;
+ if (index_word.TryGetValue(seq[i], out word))
{
if (num_words != -1 && i >= num_words)
{
if (oov_index != -1)
{
- vect.Add(oov_index);
+ bldr.Append(oov_token);
}
}
else
{
- vect.Add(i);
+ bldr.Append(word);
}
}
- else
+ else if (oov_index != -1)
{
- vect.Add(oov_index);
+ bldr.Append(oov_token);
}
}
- return vect.ToArray();
+ return bldr.ToString();
});
}
- public IEnumerable<string> sequences_to_texts(IEnumerable<int[]> sequences)
- {
- throw new NotImplementedException("sequences_to_texts");
- }
-
+ /// <summary>
+ /// Converts a list of sequences into a Numpy matrix.
+ /// </summary>
+ /// <param name="sequences"></param>
+ /// <returns></returns>
public NDArray sequences_to_matrix(IEnumerable<int[]> sequences)
{
throw new NotImplementedException("sequences_to_matrix");
diff --git a/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs b/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs
index 0bbb9a8b..ebde87fa 100644
--- a/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs
@@ -19,12 +19,27 @@ namespace TensorFlowNET.Keras.UnitTest
"It was the best of times, it was the worst of times.",
"this is a new dawn, an era to follow the previous era.",
};
+
+ private readonly string[][] tokenized_texts = new string[][] {
+ new string[] {"It","was","the","best","of","times","it","was","the","worst","of","times"},
+ new string[] {"this","is","a","new","dawn","an","era","to","follow","the","previous","era","It","can","not","be","said","to","start","anew" },
+ new string[] {"It","was","the","best","of","times","it","was","the","worst","of","times"},
+ new string[] {"this","is","a","new","dawn","an","era","to","follow","the","previous","era" },
+ };
+
+ private readonly string[] processed_texts = new string[] {
+ "it was the best of times it was the worst of times",
+ "this is a new dawn an era to follow the previous era it can not be said to start anew",
+ "it was the best of times it was the worst of times",
+ "this is a new dawn an era to follow the previous era",
+ };
+
private const string OOV = "<OOV>";
[TestMethod]
public void TokenizeWithNoOOV()
{
- var tokenizer = keras.preprocessing.text.Tokenizer(lower: true);
+ var tokenizer = keras.preprocessing.text.Tokenizer();
tokenizer.fit_on_texts(texts);
Assert.AreEqual(23, tokenizer.word_index.Count);
@@ -34,10 +49,24 @@ namespace TensorFlowNET.Keras.UnitTest
Assert.AreEqual(16, tokenizer.word_index["follow"]);
}
+ [TestMethod]
+ public void TokenizeWithNoOOV_Tkn()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer();
+ // Use the list version, where the tokenization has already been done.
+ tokenizer.fit_on_texts(tokenized_texts);
+
+ Assert.AreEqual(23, tokenizer.word_index.Count);
+
+ Assert.AreEqual(7, tokenizer.word_index["worst"]);
+ Assert.AreEqual(12, tokenizer.word_index["dawn"]);
+ Assert.AreEqual(16, tokenizer.word_index["follow"]);
+ }
+
[TestMethod]
public void TokenizeWithOOV()
{
- var tokenizer = keras.preprocessing.text.Tokenizer(lower: true, oov_token: OOV);
+ var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV);
tokenizer.fit_on_texts(texts);
Assert.AreEqual(24, tokenizer.word_index.Count);
@@ -48,10 +77,161 @@ namespace TensorFlowNET.Keras.UnitTest
Assert.AreEqual(17, tokenizer.word_index["follow"]);
}
+ [TestMethod]
+ public void TokenizeWithOOV_Tkn()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV);
+ // Use the list version, where the tokenization has already been done.
+ tokenizer.fit_on_texts(tokenized_texts);
+
+ Assert.AreEqual(24, tokenizer.word_index.Count);
+
+ Assert.AreEqual(1, tokenizer.word_index[OOV]);
+ Assert.AreEqual(8, tokenizer.word_index["worst"]);
+ Assert.AreEqual(13, tokenizer.word_index["dawn"]);
+ Assert.AreEqual(17, tokenizer.word_index["follow"]);
+ }
+
+ [TestMethod]
+ public void TokenizeTextsToSequences()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer();
+ tokenizer.fit_on_texts(texts);
+
+ var sequences = tokenizer.texts_to_sequences(texts);
+ Assert.AreEqual(4, sequences.Count);
+
+ Assert.AreEqual(tokenizer.word_index["worst"], sequences[0][9]);
+ Assert.AreEqual(tokenizer.word_index["previous"], sequences[1][10]);
+ }
+
+ [TestMethod]
+ public void TokenizeTextsToSequences_Tkn()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer();
+ // Use the list version, where the tokenization has already been done.
+ tokenizer.fit_on_texts(tokenized_texts);
+
+ var sequences = tokenizer.texts_to_sequences(tokenized_texts);
+ Assert.AreEqual(4, sequences.Count);
+
+ Assert.AreEqual(tokenizer.word_index["worst"], sequences[0][9]);
+ Assert.AreEqual(tokenizer.word_index["previous"], sequences[1][10]);
+ }
+
+ [TestMethod]
+ public void TokenizeTextsToSequencesAndBack()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer();
+ tokenizer.fit_on_texts(texts);
+
+ var sequences = tokenizer.texts_to_sequences(texts);
+ Assert.AreEqual(4, sequences.Count);
+
+ var processed = tokenizer.sequences_to_texts(sequences);
+
+ Assert.AreEqual(4, processed.Count);
+
+ for (var i = 0; i < processed.Count; i++)
+ Assert.AreEqual(processed_texts[i], processed[i]);
+ }
+
+ [TestMethod]
+ public void TokenizeTextsToSequencesAndBack_Tkn1()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer();
+ // Use the list version, where the tokenization has already been done.
+ tokenizer.fit_on_texts(tokenized_texts);
+
+ // Use the list version, where the tokenization has already been done.
+ var sequences = tokenizer.texts_to_sequences(tokenized_texts);
+ Assert.AreEqual(4, sequences.Count);
+
+ var processed = tokenizer.sequences_to_texts(sequences);
+
+ Assert.AreEqual(4, processed.Count);
+
+ for (var i = 0; i < processed.Count; i++)
+ Assert.AreEqual(processed_texts[i], processed[i]);
+ }
+
+ [TestMethod]
+ public void TokenizeTextsToSequencesAndBack_Tkn2()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer();
+ // Use the list version, where the tokenization has already been done.
+ tokenizer.fit_on_texts(tokenized_texts);
+
+ var sequences = tokenizer.texts_to_sequences(texts);
+ Assert.AreEqual(4, sequences.Count);
+
+ var processed = tokenizer.sequences_to_texts(sequences);
+
+ Assert.AreEqual(4, processed.Count);
+
+ for (var i = 0; i < processed.Count; i++)
+ Assert.AreEqual(processed_texts[i], processed[i]);
+ }
+
+ [TestMethod]
+ public void TokenizeTextsToSequencesAndBack_Tkn3()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer();
+ tokenizer.fit_on_texts(texts);
+
+ // Use the list version, where the tokenization has already been done.
+ var sequences = tokenizer.texts_to_sequences(tokenized_texts);
+ Assert.AreEqual(4, sequences.Count);
+
+ var processed = tokenizer.sequences_to_texts(sequences);
+
+ Assert.AreEqual(4, processed.Count);
+
+ for (var i = 0; i < processed.Count; i++)
+ Assert.AreEqual(processed_texts[i], processed[i]);
+ }
+ [TestMethod]
+ public void TokenizeTextsToSequencesWithOOV()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV);
+ tokenizer.fit_on_texts(texts);
+
+ var sequences = tokenizer.texts_to_sequences(texts);
+ Assert.AreEqual(4, sequences.Count);
+
+ Assert.AreEqual(tokenizer.word_index["worst"], sequences[0][9]);
+ Assert.AreEqual(tokenizer.word_index["previous"], sequences[1][10]);
+
+ for (var i = 0; i < sequences.Count; i++)
+ for (var j = 0; j < sequences[i].Length; j++)
+ Assert.AreNotEqual(tokenizer.word_index[OOV], sequences[i][j]);
+ }
+
+ [TestMethod]
+ public void TokenizeTextsToSequencesWithOOVPresent()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV, num_words:20);
+ tokenizer.fit_on_texts(texts);
+
+ var sequences = tokenizer.texts_to_sequences(texts);
+ Assert.AreEqual(4, sequences.Count);
+
+ Assert.AreEqual(tokenizer.word_index["worst"], sequences[0][9]);
+ Assert.AreEqual(tokenizer.word_index["previous"], sequences[1][10]);
+
+ var oov_count = 0;
+ for (var i = 0; i < sequences.Count; i++)
+ for (var j = 0; j < sequences[i].Length; j++)
+ if (tokenizer.word_index[OOV] == sequences[i][j])
+ oov_count += 1;
+
+ Assert.AreEqual(5, oov_count);
+ }
+
[TestMethod]
public void PadSequencesWithDefaults()
{
- var tokenizer = keras.preprocessing.text.Tokenizer(lower: true, oov_token: OOV);
+ var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV);
tokenizer.fit_on_texts(texts);
var sequences = tokenizer.texts_to_sequences(texts);
@@ -74,7 +254,7 @@ namespace TensorFlowNET.Keras.UnitTest
[TestMethod]
public void PadSequencesPrePaddingTrunc()
{
- var tokenizer = keras.preprocessing.text.Tokenizer(lower: true, oov_token: OOV);
+ var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV);
tokenizer.fit_on_texts(texts);
var sequences = tokenizer.texts_to_sequences(texts);
@@ -97,7 +277,7 @@ namespace TensorFlowNET.Keras.UnitTest
[TestMethod]
public void PadSequencesPostPaddingTrunc()
{
- var tokenizer = keras.preprocessing.text.Tokenizer(lower: true, oov_token: OOV);
+ var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV);
tokenizer.fit_on_texts(texts);
var sequences = tokenizer.texts_to_sequences(texts);