From 8e3ba22c832e6d34598644686e00182924b08c3a Mon Sep 17 00:00:00 2001 From: lingbai-kong Date: Sat, 26 Aug 2023 16:29:28 +0800 Subject: [PATCH] fix: validate dataset of `Imdb` do not load bug & add: custom `Imdb` path --- src/TensorFlowNET.Keras/Datasets/Imdb.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/TensorFlowNET.Keras/Datasets/Imdb.cs b/src/TensorFlowNET.Keras/Datasets/Imdb.cs index 61ce3947..a62f3f87 100644 --- a/src/TensorFlowNET.Keras/Datasets/Imdb.cs +++ b/src/TensorFlowNET.Keras/Datasets/Imdb.cs @@ -31,7 +31,7 @@ namespace Tensorflow.Keras.Datasets /// /// /// - public DatasetPass load_data(string path = "imdb.npz", + public DatasetPass load_data(string? path = "imdb.npz", int num_words = -1, int skip_top = 0, int maxlen = -1, @@ -42,7 +42,7 @@ namespace Tensorflow.Keras.Datasets { if (maxlen == -1) throw new InvalidArgumentError("maxlen must be assigned."); - var dst = Download(); + var dst = path ?? Download(); var lines = File.ReadAllLines(Path.Combine(dst, "imdb_train.txt")); var x_train_string = new string[lines.Length]; @@ -55,7 +55,7 @@ namespace Tensorflow.Keras.Datasets var x_train = keras.preprocessing.sequence.pad_sequences(PraseData(x_train_string), maxlen: maxlen); - File.ReadAllLines(Path.Combine(dst, "imdb_test.txt")); + lines = File.ReadAllLines(Path.Combine(dst, "imdb_test.txt")); var x_test_string = new string[lines.Length]; var y_test = np.zeros(new int[] { lines.Length }, np.int64); for (int i = 0; i < lines.Length; i++)