/// <summary>
/// Converts a list of texts to a matrix, one row per text.
/// </summary>
/// <param name="texts">A sequence of strings, each containing one or more tokens.</param>
/// <param name="mode">One of "binary", "count", "tfidf", "freq".</param>
/// <returns>A 2-D NDArray with one row per input text and one column per vocabulary word.</returns>
public NDArray texts_to_matrix(IEnumerable<string> texts, string mode = "binary")
    => sequences_to_matrix(texts_to_sequences(texts), mode);

/// <summary>
/// Converts a list of pre-tokenized texts to a matrix, one row per text.
/// </summary>
/// <param name="texts">A sequence of token lists, each element being a single token.</param>
/// <param name="mode">One of "binary", "count", "tfidf", "freq".</param>
/// <returns>A 2-D NDArray with one row per input text and one column per vocabulary word.</returns>
public NDArray texts_to_matrix(IEnumerable<IList<string>> texts, string mode = "binary")
    => sequences_to_matrix(texts_to_sequences(texts), mode);
/// <summary>
/// Converts a list of integer sequences into a matrix, one row per sequence.
/// </summary>
/// <param name="sequences">A sequence of lists of integers, each encoding the tokens of one text.</param>
/// <param name="mode">One of "binary", "count", "tfidf", "freq".</param>
/// <returns>A 2-D NDArray of shape (number of sequences, word count).</returns>
/// <exception cref="InvalidArgumentError">Thrown when <paramref name="mode"/> is not a known vectorization mode.</exception>
/// <exception cref="InvalidOperationException">Thrown when no vocabulary size is available, or "tfidf" is requested before fitting.</exception>
public NDArray sequences_to_matrix(IEnumerable<IList<int>> sequences, string mode = "binary")
{
    if (!modes.Contains(mode)) throw new InvalidArgumentError($"Unknown vectorization mode: {mode}");

    // Number of columns: either the configured 'num_words', or the fitted
    // vocabulary size plus one (index 0 is reserved / never assigned to a word).
    int word_count;
    if (num_words == -1)
    {
        if (word_index == null)
        {
            throw new InvalidOperationException("Specify a dimension ('num_words' argument), or fit on some text data first.");
        }
        word_count = word_index.Count + 1;
    }
    else
    {
        word_count = num_words;
    }

    if (mode == "tfidf" && this.document_count == 0)
    {
        throw new InvalidOperationException("Fit the Tokenizer on some text data before using the 'tfidf' mode.");
    }

    // Materialize once: calling Count()/ElementAt(i) on a lazy IEnumerable in a
    // loop re-enumerates the source and makes the whole pass O(n^2).
    var seqs = sequences.ToArray();
    var x = np.zeros(seqs.Length, word_count);

    for (int i = 0; i < seqs.Length; i++)
    {
        var seq = seqs[i];
        if (seq == null || seq.Count == 0)
            continue;

        var seq_length = seq.Count;

        // Per-row term frequencies; indices beyond the vocabulary are ignored.
        var counts = new Dictionary<int, int>();
        foreach (var j in seq)
        {
            if (j >= word_count)
                continue;
            counts.TryGetValue(j, out var count);
            counts[j] = count + 1;
        }

        foreach (var kv in counts)
        {
            var j = kv.Key;
            var c = kv.Value;
            switch (mode)
            {
                case "count":
                    x[i, j] = c;
                    break;
                case "freq":
                    // Relative frequency of the token within this sequence.
                    x[i, j] = ((double)c) / seq_length;
                    break;
                case "binary":
                    x[i, j] = 1;
                    break;
                case "tfidf":
                    // Matches Keras: tf = 1 + ln(count), idf = ln(1 + D / (1 + df)),
                    // where df is the number of documents containing token j.
                    index_docs.TryGetValue(j, out var df);
                    var tf = 1 + np.log(c);
                    // Cast to double so the document-count ratio is not truncated
                    // by integer division before the logarithm is taken.
                    var idf = np.log(1 + (double)document_count / (1 + df));
                    x[i, j] = tf * idf;
                    break;
            }
        }
    }

    return x;
}

// Supported vectorization modes for sequences_to_matrix()/texts_to_matrix().
private string[] modes = new string[] { "binary", "count", "tfidf", "freq" };
[TestMethod]
public void TextToMatrixBinary()
{
    var tokenizer = keras.preprocessing.text.Tokenizer();
    tokenizer.fit_on_texts(texts);

    Assert.AreEqual(27, tokenizer.word_index.Count);

    var matrix = tokenizer.texts_to_matrix(texts);

    Assert.AreEqual(texts.Length, matrix.shape[0]);

    // Each cell is 1 when the word occurs in the text at all, 0 otherwise.
    CompareLists(new double[] { 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, matrix[0].ToArray<double>());
    CompareLists(new double[] { 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }, matrix[1].ToArray<double>());
}

[TestMethod]
public void TextToMatrixCount()
{
    var tokenizer = keras.preprocessing.text.Tokenizer();
    tokenizer.fit_on_texts(texts);

    Assert.AreEqual(27, tokenizer.word_index.Count);

    var matrix = tokenizer.texts_to_matrix(texts, mode: "count");

    Assert.AreEqual(texts.Length, matrix.shape[0]);

    // Each cell is the raw occurrence count of the word in the text.
    CompareLists(new double[] { 0, 2, 2, 2, 1, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, matrix[0].ToArray<double>());
    CompareLists(new double[] { 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }, matrix[1].ToArray<double>());
}

[TestMethod]
public void TextToMatrixFrequency()
{
    var tokenizer = keras.preprocessing.text.Tokenizer();
    tokenizer.fit_on_texts(texts);

    Assert.AreEqual(27, tokenizer.word_index.Count);

    var matrix = tokenizer.texts_to_matrix(texts, mode: "freq");

    Assert.AreEqual(texts.Length, matrix.shape[0]);

    // Expected values: count / sequence length (first text has 12 tokens,
    // second has 22).
    double t12 = 2.0 / 12.0;
    double o12 = 1.0 / 12.0;
    double t22 = 2.0 / 22.0;
    double o22 = 1.0 / 22.0;

    CompareLists(new double[] { 0, t12, t12, t12, o12, t12, t12, o12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, matrix[0].ToArray<double>());
    CompareLists(new double[] { 0, 0, 0, 0, 0, o22, 0, 0, o22, o22, o22, o22, o22, o22, o22, o22, t22, o22, o22, o22, o22, o22, o22, o22, o22, o22, o22, o22 }, matrix[1].ToArray<double>());
}

[TestMethod]
public void TextToMatrixTfIdf()
{
    var tokenizer = keras.preprocessing.text.Tokenizer();
    tokenizer.fit_on_texts(texts);

    Assert.AreEqual(27, tokenizer.word_index.Count);

    var matrix = tokenizer.texts_to_matrix(texts, mode: "tfidf");

    Assert.AreEqual(texts.Length, matrix.shape[0]);

    // Precomputed tf-idf values: tf = 1 + ln(count), idf = ln(1 + D / (1 + df)).
    double t1 = 1.1736001944781467;
    double t2 = 0.69314718055994529;
    double t3 = 1.860112299086919;
    double t4 = 1.0986122886681098;
    double t5 = 0.69314718055994529;

    CompareLists(new double[] { 0, t1, t1, t1, t2, 0, t1, t2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, matrix[0].ToArray<double>());
    CompareLists(new double[] { 0, 0, 0, 0, 0, 0, 0, 0, t5, t5, t5, t5, t5, t5, t5, t5, t3, t4, t4, t4, t4, t4, t4, t4, t4, t4, t4, t4 }, matrix[1].ToArray<double>());
}

// Element-wise comparison with a small tolerance: exact == on computed
// floating-point values (tf-idf, frequencies) is brittle across platforms.
private void CompareLists(IList<double> expected, IList<double> actual)
{
    Assert.AreEqual(expected.Count, actual.Count);
    for (var i = 0; i < expected.Count; i++)
    {
        Assert.AreEqual(expected[i], actual[i], 1e-12);
    }
}