Browse Source

Merge pull request #1126 from lingbai-kong/parse_imdb

Add pad preprocessing for `imdb` dataset
tags/v0.110.4-Transformer-Model
Haiping GitHub 2 years ago
parent
commit
dfd9dd0d20
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 22 additions and 2 deletions
  1. +22
    -2
      src/TensorFlowNET.Keras/Datasets/Imdb.cs

+ 22
- 2
src/TensorFlowNET.Keras/Datasets/Imdb.cs View File

@@ -40,6 +40,8 @@ namespace Tensorflow.Keras.Datasets
int oov_char= 2,
int index_from = 3)
{
if (maxlen == -1) throw new InvalidArgumentError("maxlen must be assigned.");
var dst = Download();

var lines = File.ReadAllLines(Path.Combine(dst, "imdb_train.txt"));
@@ -51,7 +53,7 @@ namespace Tensorflow.Keras.Datasets
x_train_string[i] = lines[i].Substring(2);
}

var x_train = np.array(x_train_string);
var x_train = keras.preprocessing.sequence.pad_sequences(PraseData(x_train_string), maxlen: maxlen);

File.ReadAllLines(Path.Combine(dst, "imdb_test.txt"));
var x_test_string = new string[lines.Length];
@@ -62,7 +64,7 @@ namespace Tensorflow.Keras.Datasets
x_test_string[i] = lines[i].Substring(2);
}

var x_test = np.array(x_test_string);
var x_test = keras.preprocessing.sequence.pad_sequences(PraseData(x_test_string), maxlen: maxlen);

return new DatasetPass
{
@@ -93,5 +95,23 @@ namespace Tensorflow.Keras.Datasets
return dst;
// return Path.Combine(dst, file_name);
}

protected IEnumerable<int[]> PraseData(string[] x)
{
var data_list = new List<int[]>();
for (int i = 0; i < len(x); i++)
{
var list_string = x[i];
var cleaned_list_string = list_string.Replace("[", "").Replace("]", "").Replace(" ", "");
string[] number_strings = cleaned_list_string.Split(',');
int[] numbers = new int[number_strings.Length];
for (int j = 0; j < number_strings.Length; j++)
{
numbers[j] = int.Parse(number_strings[j]);
}
data_list.Add(numbers);
}
return data_list;
}
}
}

Loading…
Cancel
Save