diff --git a/src/TensorFlowNET.Keras/Datasets/Imdb.cs b/src/TensorFlowNET.Keras/Datasets/Imdb.cs
index 98769a21..f4f9de5f 100644
--- a/src/TensorFlowNET.Keras/Datasets/Imdb.cs
+++ b/src/TensorFlowNET.Keras/Datasets/Imdb.cs
@@ -5,6 +5,8 @@ using System.Text;
 using Tensorflow.Keras.Utils;
 using NumSharp;
 using System.Linq;
+using NumSharp.Utilities;
+using Tensorflow.Queues;
 
 namespace Tensorflow.Keras.Datasets
 {
@@ -15,8 +17,10 @@ namespace Tensorflow.Keras.Datasets
     /// </summary>
     public class Imdb
     {
-        string origin_folder = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/";
-        string file_name = "imdb.npz";
+        //string origin_folder = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/";
+        string origin_folder = "http://ai.stanford.edu/~amaas/data/sentiment/";
+        //string file_name = "imdb.npz";
+        string file_name = "aclImdb_v1.tar.gz";
         string dest_folder = "imdb";
 
         /// <summary>
@@ -37,38 +41,66 @@ namespace Tensorflow.Keras.Datasets
             int maxlen = -1,
             int seed = 113,
             int start_char = 1,
-            int oov_char= 2,
+            int oov_char = 2,
             int index_from = 3)
         {
             var dst = Download();
 
-            var lines = File.ReadAllLines(Path.Combine(dst, "imdb_train.txt"));
-            var x_train_string = new string[lines.Length];
-            var y_train = np.zeros(new int[] { lines.Length }, NPTypeCode.Int64);
-            for (int i = 0; i < lines.Length; i++)
+            var vocab = BuildVocabulary(Path.Combine(dst, "imdb.vocab"), start_char, oov_char, index_from);
+
+            var (x_train, y_train) = GetDataSet(Path.Combine(dst, "train"));
+            var (x_test, y_test) = GetDataSet(Path.Combine(dst, "test"));
+
+            return new DatasetPass
             {
-                y_train[i] = long.Parse(lines[i].Substring(0, 1));
-                x_train_string[i] = lines[i].Substring(2);
-            }
+                Train = (x_train, y_train),
+                Test = (x_test, y_test)
+            };
+        }
 
-            var x_train = np.array(x_train_string);
+        private static Dictionary<string, int> BuildVocabulary(string path,
+            int start_char,
+            int oov_char,
+            int index_from)
+        {
+            var words = File.ReadAllLines(path);
+            var result = new Dictionary<string, int>();
+            var idx = index_from;
 
-            File.ReadAllLines(Path.Combine(dst, "imdb_test.txt"));
-            var x_test_string = new string[lines.Length];
-            var y_test = np.zeros(new int[] { lines.Length }, NPTypeCode.Int64);
-            for (int i = 0; i < lines.Length; i++)
+            foreach (var word in words)
             {
-                y_test[i] = long.Parse(lines[i].Substring(0, 1));
-                x_test_string[i] = lines[i].Substring(2);
+                result[word] = idx;
+                idx += 1;
             }
 
-            var x_test = np.array(x_test_string);
+            return result;
+        }
 
-            return new DatasetPass
+        private static (NDArray, NDArray) GetDataSet(string path)
+        {
+            var posFiles = Directory.GetFiles(Path.Combine(path, "pos")).Slice(0, 10);
+            var negFiles = Directory.GetFiles(Path.Combine(path, "neg")).Slice(0, 10);
+
+            var x_string = new string[posFiles.Length + negFiles.Length];
+            var y = new int[posFiles.Length + negFiles.Length];
+            var trg = 0;
+            var longest = 0;
+
+            for (int i = 0; i < posFiles.Length; i++, trg++)
             {
-                Train = (x_train, y_train),
-                Test = (x_test, y_test)
-            };
+                y[trg] = 1;
+                x_string[trg] = File.ReadAllText(posFiles[i]);
+                longest = Math.Max(longest, x_string[trg].Length);
+            }
+            for (int i = 0; i < negFiles.Length; i++, trg++)
+            {
+                y[trg] = 0;
+                x_string[trg] = File.ReadAllText(negFiles[i]);
+                longest = Math.Max(longest, x_string[trg].Length);
+            }
+            var x = np.array(x_string);
+
+            return (x, y);
         }
 
         (NDArray, NDArray) LoadX(byte[] bytes)
@@ -90,8 +122,9 @@ namespace Tensorflow.Keras.Datasets
 
             Web.Download(origin_folder + file_name, dst, file_name);
 
-            return dst;
-            // return Path.Combine(dst, file_name);
+            Tensorflow.Keras.Utils.Compress.ExtractTGZ(Path.Combine(dst, file_name), dst);
+
+            return Path.Combine(dst, "aclImdb");
         }
     }
 }
diff --git a/src/TensorFlowNET.Keras/Preprocessings/Tokenizer.cs b/src/TensorFlowNET.Keras/Preprocessings/Tokenizer.cs
index aaca1cb9..8bf7cf38 100644
--- a/src/TensorFlowNET.Keras/Preprocessings/Tokenizer.cs
+++ b/src/TensorFlowNET.Keras/Preprocessings/Tokenizer.cs
@@ -56,7 +56,7 @@ namespace Tensorflow.Keras.Text
         /// <summary>
         /// Updates internal vocabulary based on a list of texts.
         /// </summary>
-        /// <param name="texts"></param>
+        /// <param name="texts">A list of strings, each containing one or more tokens.</param>
         /// <remarks>Required before using texts_to_sequences or texts_to_matrix.</remarks>
         public void fit_on_texts(IEnumerable<string> texts)
         {
@@ -90,7 +90,7 @@ namespace Tensorflow.Keras.Text
             }
 
             var wcounts = word_counts.AsEnumerable().ToList();
-            wcounts.Sort((kv1, kv2) => -kv1.Value.CompareTo(kv2.Value));
+            wcounts.Sort((kv1, kv2) => -kv1.Value.CompareTo(kv2.Value)); // Note: '-' gives us descending order.
 
             var sorted_voc = (oov_token == null) ? new List<string>() : new List<string>() { oov_token };
             sorted_voc.AddRange(word_counts.Select(kv => kv.Key));
@@ -120,7 +120,12 @@ namespace Tensorflow.Keras.Text
             }
         }
 
-        public void fit_on_texts(IEnumerable<IList<string>> texts)
+        /// <summary>
+        /// Updates internal vocabulary based on a list of texts.
+        /// </summary>
+        /// <param name="texts">A list of lists of strings, each containing one token.</param>
+        /// <remarks>Required before using texts_to_sequences or texts_to_matrix.</remarks>
+        public void fit_on_texts(IEnumerable<IEnumerable<string>> texts)
        {
             foreach (var seq in texts)
             {
@@ -197,7 +202,7 @@ namespace Tensorflow.Keras.Text
         /// </summary>
         /// <param name="texts"></param>
         /// <remarks>Only top num_words-1 most frequent words will be taken into account. Only words known by the tokenizer will be taken into account.</remarks>
-        public IList<int[]> texts_to_sequences(IEnumerable<IList<string>> texts)
+        public IList<int[]> texts_to_sequences(IEnumerable<IEnumerable<string>> texts)
         {
             return texts_to_sequences_generator(texts).ToArray();
         }
@@ -224,6 +229,13 @@ namespace Tensorflow.Keras.Text
             });
         }
 
+        public IEnumerable<int[]> texts_to_sequences_generator(IEnumerable<IEnumerable<string>> texts)
+        {
+            int oov_index = -1;
+            var _ = (oov_token != null) && word_index.TryGetValue(oov_token, out oov_index);
+            return texts.Select(seq => ConvertToSequence(oov_index, seq).ToArray());
+        }
+
         private List<int> ConvertToSequence(int oov_index, IEnumerable<string> seq)
         {
             var vect = new List<int>();
@@ -244,7 +256,7 @@ namespace Tensorflow.Keras.Text
                     vect.Add(i);
                 }
             }
-            else if(oov_index != -1)
+            else if (oov_index != -1)
             {
                 vect.Add(oov_index);
             }
@@ -253,13 +265,6 @@ namespace Tensorflow.Keras.Text
             return vect;
         }
 
-        public IEnumerable<int[]> texts_to_sequences_generator(IEnumerable<IList<string>> texts)
-        {
-            int oov_index = -1;
-            var _ = (oov_token != null) && word_index.TryGetValue(oov_token, out oov_index);
-            return texts.Select(seq => ConvertToSequence(oov_index, seq).ToArray());
-        }
-
         /// <summary>
         /// Transforms each sequence into a list of text.
         /// </summary>
@@ -271,7 +276,7 @@ namespace Tensorflow.Keras.Text
             return sequences_to_texts_generator(sequences).ToArray();
         }
 
-        public IEnumerable<string> sequences_to_texts_generator(IEnumerable<int[]> sequences)
+        public IEnumerable<string> sequences_to_texts_generator(IEnumerable<IList<int>> sequences)
         {
             int oov_index = -1;
             var _ = (oov_token != null) && word_index.TryGetValue(oov_token, out oov_index);
@@ -280,7 +285,7 @@ namespace Tensorflow.Keras.Text
             {
                 var bldr = new StringBuilder();
 
-                for (var i = 0; i < seq.Length; i++)
+                for (var i = 0; i < seq.Count; i++)
                 {
                     if (i > 0)
                         bldr.Append(' ');
@@ -314,7 +319,7 @@ namespace Tensorflow.Keras.Text
         /// </summary>
         /// <param name="sequences"></param>
         /// <returns></returns>
-        public NDArray sequences_to_matrix(IEnumerable<int[]> sequences)
+        public NDArray sequences_to_matrix(IEnumerable<IList<int>> sequences)
         {
             throw new NotImplementedException("sequences_to_matrix");
         }
diff --git a/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs b/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs
index ebde87fa..ad4f91bf 100644
--- a/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs
@@ -15,23 +15,23 @@ namespace TensorFlowNET.Keras.UnitTest
     {
         private readonly string[] texts = new string[] {
             "It was the best of times, it was the worst of times.",
-            "this is a new dawn, an era to follow the previous era. It can not be said to start anew.",
+            "Mr and Mrs Dursley of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much.",
             "It was the best of times, it was the worst of times.",
-            "this is a new dawn, an era to follow the previous era.",
+            "Mr and Mrs Dursley of number four, Privet Drive.",
         };
 
         private readonly string[][] tokenized_texts = new string[][] {
             new string[] {"It","was","the","best","of","times","it","was","the","worst","of","times"},
-            new string[] {"this","is","a","new","dawn","an","era","to","follow","the","previous","era","It","can","not","be","said","to","start","anew" },
+            new string[] {"mr","and","mrs","dursley","of","number","four","privet","drive","were","proud","to","say","that","they","were","perfectly","normal","thank","you","very","much"},
             new string[] {"It","was","the","best","of","times","it","was","the","worst","of","times"},
-            new string[] {"this","is","a","new","dawn","an","era","to","follow","the","previous","era" },
+            new string[] {"mr","and","mrs","dursley","of","number","four","privet","drive"},
         };
 
         private readonly string[] processed_texts = new string[] {
             "it was the best of times it was the worst of times",
-            "this is a new dawn an era to follow the previous era it can not be said to start anew",
+            "mr and mrs dursley of number four privet drive were proud to say that they were perfectly normal thank you very much",
             "it was the best of times it was the worst of times",
-            "this is a new dawn an era to follow the previous era",
+            "mr and mrs dursley of number four privet drive",
         };
 
         private const string OOV = "<OOV>";
 
@@ -42,11 +42,11 @@ namespace TensorFlowNET.Keras.UnitTest
             var tokenizer = keras.preprocessing.text.Tokenizer();
             tokenizer.fit_on_texts(texts);
 
-            Assert.AreEqual(23, tokenizer.word_index.Count);
+            Assert.AreEqual(27, tokenizer.word_index.Count);
 
             Assert.AreEqual(7, tokenizer.word_index["worst"]);
-            Assert.AreEqual(12, tokenizer.word_index["dawn"]);
-            Assert.AreEqual(16, tokenizer.word_index["follow"]);
+            Assert.AreEqual(12, tokenizer.word_index["number"]);
+            Assert.AreEqual(16, tokenizer.word_index["were"]);
         }
 
         [TestMethod]
@@ -56,11 +56,11 @@ namespace TensorFlowNET.Keras.UnitTest
             var tokenizer = keras.preprocessing.text.Tokenizer();
             // Use the list version, where the tokenization has already been done.
             tokenizer.fit_on_texts(tokenized_texts);
 
-            Assert.AreEqual(23, tokenizer.word_index.Count);
+            Assert.AreEqual(27, tokenizer.word_index.Count);
 
             Assert.AreEqual(7, tokenizer.word_index["worst"]);
-            Assert.AreEqual(12, tokenizer.word_index["dawn"]);
-            Assert.AreEqual(16, tokenizer.word_index["follow"]);
+            Assert.AreEqual(12, tokenizer.word_index["number"]);
+            Assert.AreEqual(16, tokenizer.word_index["were"]);
         }
 
         [TestMethod]
@@ -69,12 +69,12 @@ namespace TensorFlowNET.Keras.UnitTest
             var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV);
             tokenizer.fit_on_texts(texts);
 
-            Assert.AreEqual(24, tokenizer.word_index.Count);
+            Assert.AreEqual(28, tokenizer.word_index.Count);
 
             Assert.AreEqual(1, tokenizer.word_index[OOV]);
             Assert.AreEqual(8, tokenizer.word_index["worst"]);
-            Assert.AreEqual(13, tokenizer.word_index["dawn"]);
-            Assert.AreEqual(17, tokenizer.word_index["follow"]);
+            Assert.AreEqual(13, tokenizer.word_index["number"]);
+            Assert.AreEqual(17, tokenizer.word_index["were"]);
         }
 
         [TestMethod]
@@ -84,12 +84,12 @@ namespace TensorFlowNET.Keras.UnitTest
             // Use the list version, where the tokenization has already been done.
             tokenizer.fit_on_texts(tokenized_texts);
 
-            Assert.AreEqual(24, tokenizer.word_index.Count);
+            Assert.AreEqual(28, tokenizer.word_index.Count);
 
             Assert.AreEqual(1, tokenizer.word_index[OOV]);
             Assert.AreEqual(8, tokenizer.word_index["worst"]);
-            Assert.AreEqual(13, tokenizer.word_index["dawn"]);
-            Assert.AreEqual(17, tokenizer.word_index["follow"]);
+            Assert.AreEqual(13, tokenizer.word_index["number"]);
+            Assert.AreEqual(17, tokenizer.word_index["were"]);
         }
 
         [TestMethod]
@@ -102,7 +102,7 @@ namespace TensorFlowNET.Keras.UnitTest
 
             Assert.AreEqual(4, sequences.Count);
             Assert.AreEqual(tokenizer.word_index["worst"], sequences[0][9]);
-            Assert.AreEqual(tokenizer.word_index["previous"], sequences[1][10]);
+            Assert.AreEqual(tokenizer.word_index["proud"], sequences[1][10]);
         }
 
         [TestMethod]
@@ -116,7 +116,7 @@ namespace TensorFlowNET.Keras.UnitTest
 
             Assert.AreEqual(4, sequences.Count);
             Assert.AreEqual(tokenizer.word_index["worst"], sequences[0][9]);
-            Assert.AreEqual(tokenizer.word_index["previous"], sequences[1][10]);
+            Assert.AreEqual(tokenizer.word_index["proud"], sequences[1][10]);
         }
 
         [TestMethod]
@@ -200,7 +200,7 @@ namespace TensorFlowNET.Keras.UnitTest
 
             Assert.AreEqual(4, sequences.Count);
             Assert.AreEqual(tokenizer.word_index["worst"], sequences[0][9]);
-            Assert.AreEqual(tokenizer.word_index["previous"], sequences[1][10]);
+            Assert.AreEqual(tokenizer.word_index["proud"], sequences[1][10]);
 
             for (var i = 0; i < sequences.Count; i++)
                 for (var j = 0; j < sequences[i].Length; j++)
@@ -217,7 +217,7 @@ namespace TensorFlowNET.Keras.UnitTest
 
             Assert.AreEqual(4, sequences.Count);
             Assert.AreEqual(tokenizer.word_index["worst"], sequences[0][9]);
-            Assert.AreEqual(tokenizer.word_index["previous"], sequences[1][10]);
+            Assert.AreEqual(tokenizer.word_index["proud"], sequences[1][10]);
 
             var oov_count = 0;
             for (var i = 0; i < sequences.Count; i++)
                 for (var j = 0; j < sequences[i].Length; j++)
                     if (tokenizer.word_index[OOV] == sequences[i][j])
                         oov_count += 1;
 
-            Assert.AreEqual(5, oov_count);
+            Assert.AreEqual(9, oov_count);
         }
 
         [TestMethod]
@@ -238,15 +238,15 @@ namespace TensorFlowNET.Keras.UnitTest
             var padded = keras.preprocessing.sequence.pad_sequences(sequences);
 
             Assert.AreEqual(4, padded.shape[0]);
-            Assert.AreEqual(20, padded.shape[1]);
+            Assert.AreEqual(22, padded.shape[1]);
 
             var firstRow = padded[0];
             var secondRow = padded[1];
 
-            Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 17].GetInt32());
+            Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 19].GetInt32());
             for (var i = 0; i < 8; i++)
                 Assert.AreEqual(0, padded[0, i].GetInt32());
-            Assert.AreEqual(tokenizer.word_index["previous"], padded[1, 10].GetInt32());
+            Assert.AreEqual(tokenizer.word_index["proud"], padded[1, 10].GetInt32());
             for (var i = 0; i < 20; i++)
                 Assert.AreNotEqual(0, padded[1, i].GetInt32());
         }
@@ -269,7 +269,7 @@ namespace TensorFlowNET.Keras.UnitTest
             Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 12].GetInt32());
             for (var i = 0; i < 3; i++)
                 Assert.AreEqual(0, padded[0, i].GetInt32());
-            Assert.AreEqual(tokenizer.word_index["previous"], padded[1, 5].GetInt32());
+            Assert.AreEqual(tokenizer.word_index["proud"], padded[1, 3].GetInt32());
             for (var i = 0; i < 15; i++)
                 Assert.AreNotEqual(0, padded[1, i].GetInt32());
         }
@@ -292,7 +292,7 @@ namespace TensorFlowNET.Keras.UnitTest
             Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 9].GetInt32());
             for (var i = 12; i < 15; i++)
                 Assert.AreEqual(0, padded[0, i].GetInt32());
-            Assert.AreEqual(tokenizer.word_index["previous"], padded[1, 10].GetInt32());
+            Assert.AreEqual(tokenizer.word_index["proud"], padded[1, 10].GetInt32());
             for (var i = 0; i < 15; i++)
                 Assert.AreNotEqual(0, padded[1, i].GetInt32());
         }
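
Reviewer note: a minimal end-to-end sketch of the preprocessing pipeline the updated tests exercise, using only the surface shown in the diff (Tokenizer, fit_on_texts, texts_to_sequences, pad_sequences, and the "<OOV>" fixture). The `using static Tensorflow.KerasApi;` import is an assumption about how the test project exposes the `keras` entry point.

    using static Tensorflow.KerasApi; // assumed import for the `keras` entry point

    var texts = new[] {
        "It was the best of times, it was the worst of times.",
        "Mr and Mrs Dursley of number four, Privet Drive.",
    };

    // Fit a tokenizer with an explicit out-of-vocabulary token, as TokenizeWithOOV
    // does above; the OOV token is always assigned index 1.
    var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: "<OOV>");
    tokenizer.fit_on_texts(texts);

    // Encode to integer sequences, then zero-pad to a common length.
    // pad_sequences defaults to pre-padding (leading zeros), which the
    // PadSequences tests above rely on.
    var sequences = tokenizer.texts_to_sequences(texts);
    var padded = keras.preprocessing.sequence.pad_sequences(sequences);
    // padded: NDArray of shape (2, length_of_longest_sequence)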
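And a hedged sketch of consuming the reworked Imdb loader. The diff only shows the body of the loading method, so the name load_data is assumed from the Keras convention; the tuple deconstruction follows the `Train = (x_train, y_train)` initializer in the first hunk. Two caveats visible in the hunks: GetDataSet currently caps each class at ten files via .Slice(0, 10), and the result of BuildVocabulary is not yet applied, so the returned features are raw review strings rather than index sequences.

    using Tensorflow.Keras.Datasets;

    var imdb = new Imdb();
    // Assumed method name: downloads aclImdb_v1.tar.gz from the Stanford
    // mirror and extracts it into the "aclImdb" folder under dest_folder.
    var dataset = imdb.load_data();
    var (x_train, y_train) = dataset.Train; // x: NDArray of review strings, y: 1 = pos, 0 = neg
    var (x_test, y_test) = dataset.Test;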