Browse Source

Merge pull request #1126 from lingbai-kong/parse_imdb

Add pad preprocessing for `imdb` dataset
tags/v0.110.4-Transformer-Model
Haiping GitHub 2 years ago
parent
commit
dfd9dd0d20
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 22 additions and 2 deletions
  1. +22
    -2
      src/TensorFlowNET.Keras/Datasets/Imdb.cs

+ 22
- 2
src/TensorFlowNET.Keras/Datasets/Imdb.cs View File

@@ -40,6 +40,8 @@ namespace Tensorflow.Keras.Datasets
int oov_char= 2, int oov_char= 2,
int index_from = 3) int index_from = 3)
{ {
if (maxlen == -1) throw new InvalidArgumentError("maxlen must be assigned.");
var dst = Download(); var dst = Download();


var lines = File.ReadAllLines(Path.Combine(dst, "imdb_train.txt")); var lines = File.ReadAllLines(Path.Combine(dst, "imdb_train.txt"));
@@ -51,7 +53,7 @@ namespace Tensorflow.Keras.Datasets
x_train_string[i] = lines[i].Substring(2); x_train_string[i] = lines[i].Substring(2);
} }


var x_train = np.array(x_train_string);
var x_train = keras.preprocessing.sequence.pad_sequences(PraseData(x_train_string), maxlen: maxlen);


File.ReadAllLines(Path.Combine(dst, "imdb_test.txt")); File.ReadAllLines(Path.Combine(dst, "imdb_test.txt"));
var x_test_string = new string[lines.Length]; var x_test_string = new string[lines.Length];
@@ -62,7 +64,7 @@ namespace Tensorflow.Keras.Datasets
x_test_string[i] = lines[i].Substring(2); x_test_string[i] = lines[i].Substring(2);
} }


var x_test = np.array(x_test_string);
var x_test = keras.preprocessing.sequence.pad_sequences(PraseData(x_test_string), maxlen: maxlen);


return new DatasetPass return new DatasetPass
{ {
@@ -93,5 +95,23 @@ namespace Tensorflow.Keras.Datasets
return dst; return dst;
// return Path.Combine(dst, file_name); // return Path.Combine(dst, file_name);
} }

/// <summary>
/// Converts each textual integer list (for example <c>"[1, 14, 22]"</c>) into an <c>int[]</c>.
/// </summary>
/// <param name="x">
/// One entry per sample. Square brackets and spaces are stripped from each entry and the
/// remainder is split on commas.
/// </param>
/// <returns>One parsed <c>int[]</c> per element of <paramref name="x"/>, in the same order.</returns>
/// <exception cref="System.FormatException">A non-empty token is not a valid integer.</exception>
/// <remarks>
/// The misspelled method name is kept as-is because existing callers reference it.
/// </remarks>
protected IEnumerable<int[]> PraseData(string[] x)
{
    // Presize: exactly one output array is produced per input string.
    var data_list = new List<int[]>(x.Length);
    foreach (var list_string in x)
    {
        var cleaned_list_string = list_string.Replace("[", "").Replace("]", "").Replace(" ", "");

        // Drop empty tokens so that "[]" (or a trailing comma) yields an empty/shorter array
        // instead of int.Parse("") throwing FormatException.
        string[] number_strings = cleaned_list_string.Split(
            new[] { ',' }, System.StringSplitOptions.RemoveEmptyEntries);

        int[] numbers = new int[number_strings.Length];
        for (int j = 0; j < number_strings.Length; j++)
        {
            // The dataset file is machine-written, so parse independently of the current culture.
            numbers[j] = int.Parse(number_strings[j], System.Globalization.CultureInfo.InvariantCulture);
        }
        data_list.Add(numbers);
    }
    return data_list;
}
} }
} }

Loading…
Cancel
Save