From 606dba816e3cb71e979ab1fb8e84b20c8b2ef308 Mon Sep 17 00:00:00 2001 From: Niklas Gustafsson Date: Fri, 29 Jan 2021 09:19:08 -0800 Subject: [PATCH] Started implementation of Tokenizer(), as well as pad_sequences(). --- .../Preprocessings/Preprocessing.cs | 3 + .../Preprocessings/Tokenizer.cs | 176 ++++++++++++++++++ src/TensorFlowNET.Keras/Sequence.cs | 27 ++- src/TensorFlowNET.Keras/TextApi.cs | 42 +++++ .../PreprocessingTests.cs | 120 ++++++++++++ .../Tensorflow.Keras.UnitTest.csproj | 2 + 6 files changed, 364 insertions(+), 6 deletions(-) create mode 100644 src/TensorFlowNET.Keras/Preprocessings/Tokenizer.cs create mode 100644 src/TensorFlowNET.Keras/TextApi.cs create mode 100644 test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs index 2d418509..b49d49f2 100644 --- a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs +++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs @@ -6,5 +6,8 @@ namespace Tensorflow.Keras { public Sequence sequence => new Sequence(); public DatasetUtils dataset_utils => new DatasetUtils(); + public TextApi text => _text; + + private static TextApi _text = new TextApi(); } } diff --git a/src/TensorFlowNET.Keras/Preprocessings/Tokenizer.cs b/src/TensorFlowNET.Keras/Preprocessings/Tokenizer.cs new file mode 100644 index 00000000..d3bfeaac --- /dev/null +++ b/src/TensorFlowNET.Keras/Preprocessings/Tokenizer.cs @@ -0,0 +1,176 @@ +using NumSharp; +using Serilog.Debugging; +using System; +using System.Collections.Generic; +using System.Collections.Specialized; +using System.Linq; +using System.Net.Sockets; +using System.Text; + +namespace Tensorflow.Keras.Text +{ + public class Tokenizer + { + private readonly int num_words; + private readonly string filters; + private readonly bool lower; + private readonly char split; + private readonly bool char_level; + private readonly string oov_token; + private readonly Func> analyzer; + + private int document_count = 0; + + private Dictionary word_docs = new Dictionary(); + private Dictionary word_counts = new Dictionary(); + + public Dictionary word_index = null; + public Dictionary index_word = null; + + private Dictionary index_docs = null; + + public Tokenizer( + int num_words = -1, + string filters = "!\"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n", + bool lower = true, + char split = ' ', + bool char_level = false, + string oov_token = null, + Func> analyzer = null) + { + this.num_words = num_words; + this.filters = filters; + this.lower = lower; + this.split = split; + this.char_level = char_level; + this.oov_token = oov_token; + this.analyzer = analyzer; + } + + public void fit_on_texts(IEnumerable texts) + { + foreach (var text in texts) + { + IEnumerable seq = null; + + document_count += 1; + if (char_level) + { + throw new NotImplementedException("char_level == true"); + } + else + { + seq = analyzer(lower ? text.ToLower() : text); + } + + foreach (var w in seq) + { + var count = 0; + word_counts.TryGetValue(w, out count); + word_counts[w] = count + 1; + } + + foreach (var w in new HashSet(seq)) + { + var count = 0; + word_docs.TryGetValue(w, out count); + word_docs[w] = count + 1; + } + } + + var wcounts = word_counts.AsEnumerable().ToList(); + wcounts.Sort((kv1, kv2) => -kv1.Value.CompareTo(kv2.Value)); + + var sorted_voc = (oov_token == null) ? new List() : new List(){oov_token}; + sorted_voc.AddRange(word_counts.Select(kv => kv.Key)); + + if (num_words > 0 -1) + { + sorted_voc = sorted_voc.Take((oov_token == null) ? num_words : num_words + 1).ToList(); + } + + word_index = new Dictionary(sorted_voc.Count); + index_word = new Dictionary(sorted_voc.Count); + index_docs = new Dictionary(word_docs.Count); + + for (int i = 0; i < sorted_voc.Count; i++) + { + word_index.Add(sorted_voc[i], i + 1); + index_word.Add(i + 1, sorted_voc[i]); + } + + foreach (var kv in word_docs) + { + var idx = -1; + if (word_index.TryGetValue(kv.Key, out idx)) + { + index_docs.Add(idx, kv.Value); + } + } + } + + public void fit_on_sequences(IEnumerable sequences) + { + throw new NotImplementedException("fit_on_sequences"); + } + + public IList texts_to_sequences(IEnumerable texts) + { + return texts_to_sequences_generator(texts).ToArray(); + } + public IEnumerable texts_to_sequences_generator(IEnumerable texts) + { + int oov_index = -1; + var _ = (oov_token != null) && word_index.TryGetValue(oov_token, out oov_index); + + return texts.Select(text => { + + IEnumerable seq = null; + + if (char_level) + { + throw new NotImplementedException("char_level == true"); + } + else + { + seq = analyzer(lower ? text.ToLower() : text); + } + var vect = new List(); + foreach (var w in seq) + { + var i = -1; + if (word_index.TryGetValue(w, out i)) + { + if (num_words != -1 && i >= num_words) + { + if (oov_index != -1) + { + vect.Add(oov_index); + } + } + else + { + vect.Add(i); + } + } + else + { + vect.Add(oov_index); + } + } + + return vect.ToArray(); + }); + } + + public IEnumerable sequences_to_texts(IEnumerable sequences) + { + throw new NotImplementedException("sequences_to_texts"); + } + + public NDArray sequences_to_matrix(IEnumerable sequences) + { + throw new NotImplementedException("sequences_to_matrix"); + } + } +} diff --git a/src/TensorFlowNET.Keras/Sequence.cs b/src/TensorFlowNET.Keras/Sequence.cs index a428a568..9567325e 100644 --- a/src/TensorFlowNET.Keras/Sequence.cs +++ b/src/TensorFlowNET.Keras/Sequence.cs @@ -15,7 +15,9 @@ ******************************************************************************/ using NumSharp; +using NumSharp.Utilities; using System; +using System.Collections.Generic; using System.Linq; namespace Tensorflow.Keras @@ -34,14 +36,18 @@ namespace Tensorflow.Keras /// String, 'pre' or 'post' /// Float or String, padding value. /// - public NDArray pad_sequences(NDArray sequences, + public NDArray pad_sequences(IEnumerable sequences, int? maxlen = null, string dtype = "int32", string padding = "pre", string truncating = "pre", object value = null) { - int[] length = new int[sequences.size]; + if (value != null) throw new NotImplementedException("padding with a specific value."); + if (padding != "pre" && padding != "post") throw new InvalidArgumentError("padding must be 'pre' or 'post'."); + if (truncating != "pre" && truncating != "post") throw new InvalidArgumentError("truncating must be 'pre' or 'post'."); + + var length = sequences.Select(s => s.Length); if (maxlen == null) maxlen = length.Max(); @@ -49,19 +55,28 @@ namespace Tensorflow.Keras if (value == null) value = 0f; - var nd = new NDArray(np.int32, new Shape(sequences.size, maxlen.Value)); + var type = getNPType(dtype); + var nd = new NDArray(type, new Shape(length.Count(), maxlen.Value),true); + #pragma warning disable CS0162 // Unreachable code detected for (int i = 0; i < nd.shape[0]; i++) #pragma warning restore CS0162 // Unreachable code detected { - switch (sequences[i]) + var s = sequences.ElementAt(i); + if (s.Length > maxlen.Value) { - default: - throw new NotImplementedException("pad_sequences"); + s = (truncating == "pre") ? s.Slice(s.Length - maxlen.Value, s.Length) : s.Slice(0, maxlen.Value); } + var sliceString = (padding == "pre") ? (i.ToString() + "," + (maxlen-s.Length).ToString() + ":") : (i.ToString() + ",:" + s.Length); + nd[sliceString] = np.array(s); } return nd; } + + private Type getNPType(string typeName) + { + return System.Type.GetType("NumSharp.np,NumSharp.Lite").GetField(typeName).GetValue(null) as Type; + } } } diff --git a/src/TensorFlowNET.Keras/TextApi.cs b/src/TensorFlowNET.Keras/TextApi.cs new file mode 100644 index 00000000..2e62e25b --- /dev/null +++ b/src/TensorFlowNET.Keras/TextApi.cs @@ -0,0 +1,42 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Keras.Text; + +namespace Tensorflow.Keras +{ + public class TextApi + { + public Tensorflow.Keras.Text.Tokenizer Tokenizer( + int num_words = -1, + string filters = DefaultFilter, + bool lower = true, + char split = ' ', + bool char_level = false, + string oov_token = null, + Func> analyzer = null) + { + if (analyzer != null) + { + return new Keras.Text.Tokenizer(num_words, filters, lower, split, char_level, oov_token, analyzer); + } + else + { + return new Keras.Text.Tokenizer(num_words, filters, lower, split, char_level, oov_token, (text) => text_to_word_sequence(text, filters, lower, split)); + } + } + + public static IEnumerable text_to_word_sequence(string text, string filters = DefaultFilter, bool lower = true, char split = ' ') + { + if (lower) + { + text = text.ToLower(); + } + var newText = new String(text.Where(c => !filters.Contains(c)).ToArray()); + return newText.Split(split); + } + + private const string DefaultFilter = "!\"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n"; + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs b/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs new file mode 100644 index 00000000..0bbb9a8b --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs @@ -0,0 +1,120 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Linq; +using System.Collections.Generic; +using System.Text; +using NumSharp; +using static Tensorflow.KerasApi; +using Tensorflow; +using Tensorflow.Keras.Datasets; + +namespace TensorFlowNET.Keras.UnitTest +{ + [TestClass] + public class PreprocessingTests : EagerModeTestBase + { + private readonly string[] texts = new string[] { + "It was the best of times, it was the worst of times.", + "this is a new dawn, an era to follow the previous era. It can not be said to start anew.", + "It was the best of times, it was the worst of times.", + "this is a new dawn, an era to follow the previous era.", + }; + private const string OOV = ""; + + [TestMethod] + public void TokenizeWithNoOOV() + { + var tokenizer = keras.preprocessing.text.Tokenizer(lower: true); + tokenizer.fit_on_texts(texts); + + Assert.AreEqual(23, tokenizer.word_index.Count); + + Assert.AreEqual(7, tokenizer.word_index["worst"]); + Assert.AreEqual(12, tokenizer.word_index["dawn"]); + Assert.AreEqual(16, tokenizer.word_index["follow"]); + } + + [TestMethod] + public void TokenizeWithOOV() + { + var tokenizer = keras.preprocessing.text.Tokenizer(lower: true, oov_token: OOV); + tokenizer.fit_on_texts(texts); + + Assert.AreEqual(24, tokenizer.word_index.Count); + + Assert.AreEqual(1, tokenizer.word_index[OOV]); + Assert.AreEqual(8, tokenizer.word_index["worst"]); + Assert.AreEqual(13, tokenizer.word_index["dawn"]); + Assert.AreEqual(17, tokenizer.word_index["follow"]); + } + + [TestMethod] + public void PadSequencesWithDefaults() + { + var tokenizer = keras.preprocessing.text.Tokenizer(lower: true, oov_token: OOV); + tokenizer.fit_on_texts(texts); + + var sequences = tokenizer.texts_to_sequences(texts); + var padded = keras.preprocessing.sequence.pad_sequences(sequences); + + Assert.AreEqual(4, padded.shape[0]); + Assert.AreEqual(20, padded.shape[1]); + + var firstRow = padded[0]; + var secondRow = padded[1]; + + Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 17].GetInt32()); + for (var i = 0; i < 8; i++) + Assert.AreEqual(0, padded[0, i].GetInt32()); + Assert.AreEqual(tokenizer.word_index["previous"], padded[1, 10].GetInt32()); + for (var i = 0; i < 20; i++) + Assert.AreNotEqual(0, padded[1, i].GetInt32()); + } + + [TestMethod] + public void PadSequencesPrePaddingTrunc() + { + var tokenizer = keras.preprocessing.text.Tokenizer(lower: true, oov_token: OOV); + tokenizer.fit_on_texts(texts); + + var sequences = tokenizer.texts_to_sequences(texts); + var padded = keras.preprocessing.sequence.pad_sequences(sequences,maxlen:15); + + Assert.AreEqual(4, padded.shape[0]); + Assert.AreEqual(15, padded.shape[1]); + + var firstRow = padded[0]; + var secondRow = padded[1]; + + Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 12].GetInt32()); + for (var i = 0; i < 3; i++) + Assert.AreEqual(0, padded[0, i].GetInt32()); + Assert.AreEqual(tokenizer.word_index["previous"], padded[1, 5].GetInt32()); + for (var i = 0; i < 15; i++) + Assert.AreNotEqual(0, padded[1, i].GetInt32()); + } + + [TestMethod] + public void PadSequencesPostPaddingTrunc() + { + var tokenizer = keras.preprocessing.text.Tokenizer(lower: true, oov_token: OOV); + tokenizer.fit_on_texts(texts); + + var sequences = tokenizer.texts_to_sequences(texts); + var padded = keras.preprocessing.sequence.pad_sequences(sequences, maxlen: 15, padding: "post", truncating: "post"); + + Assert.AreEqual(4, padded.shape[0]); + Assert.AreEqual(15, padded.shape[1]); + + var firstRow = padded[0]; + var secondRow = padded[1]; + + Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 9].GetInt32()); + for (var i = 12; i < 15; i++) + Assert.AreEqual(0, padded[0, i].GetInt32()); + Assert.AreEqual(tokenizer.word_index["previous"], padded[1, 10].GetInt32()); + for (var i = 0; i < 15; i++) + Assert.AreNotEqual(0, padded[1, i].GetInt32()); + } + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj b/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj index 95ead0c5..38cfbd67 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj +++ b/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj @@ -6,6 +6,8 @@ false AnyCPU;x64 + + TensorflowNET.Keras.UnitTest