Browse Source

Merge pull request #1182 from lingbai-kong/imdbfix

fix: adjust imdb dataset loader for faster loading speed
tags/v0.110.4-Transformer-Model
Haiping GitHub 2 years ago
parent
commit
8e02682637
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 15 deletions
  1. +17
    -12
      src/TensorFlowNET.Keras/Datasets/Imdb.cs
  2. +5
    -3
      src/TensorFlowNET.Keras/Utils/data_utils.cs

+ 17
- 12
src/TensorFlowNET.Keras/Datasets/Imdb.cs View File

@@ -112,35 +112,39 @@ namespace Tensorflow.Keras.Datasets

if (start_char != null)
{
int[,] new_x_train_array = new int[x_train_array.GetLength(0), x_train_array.GetLength(1) + 1];
for (var i = 0; i < x_train_array.GetLength(0); i++)
var (d1, d2) = (x_train_array.GetLength(0), x_train_array.GetLength(1));
int[,] new_x_train_array = new int[d1, d2 + 1];
for (var i = 0; i < d1; i++)
{
new_x_train_array[i, 0] = (int)start_char;
Array.Copy(x_train_array, i * x_train_array.GetLength(1), new_x_train_array, i * new_x_train_array.GetLength(1) + 1, x_train_array.GetLength(1));
Array.Copy(x_train_array, i * d2, new_x_train_array, i * (d2 + 1) + 1, d2);
}
int[,] new_x_test_array = new int[x_test_array.GetLength(0), x_test_array.GetLength(1) + 1];
for (var i = 0; i < x_test_array.GetLength(0); i++)
(d1, d2) = (x_test_array.GetLength(0), x_test_array.GetLength(1));
int[,] new_x_test_array = new int[d1, d2 + 1];
for (var i = 0; i < d1; i++)
{
new_x_test_array[i, 0] = (int)start_char;
Array.Copy(x_test_array, i * x_test_array.GetLength(1), new_x_test_array, i * new_x_test_array.GetLength(1) + 1, x_test_array.GetLength(1));
Array.Copy(x_test_array, i * d2, new_x_test_array, i * (d2 + 1) + 1, d2);
}
x_train_array = new_x_train_array;
x_test_array = new_x_test_array;
}
else if (index_from != 0)
{
for (var i = 0; i < x_train_array.GetLength(0); i++)
var (d1, d2) = (x_train_array.GetLength(0), x_train_array.GetLength(1));
for (var i = 0; i < d1; i++)
{
for (var j = 0; j < x_train_array.GetLength(1); j++)
for (var j = 0; j < d2; j++)
{
if (x_train_array[i, j] == 0)
break;
x_train_array[i, j] += index_from;
}
}
for (var i = 0; i < x_test_array.GetLength(0); i++)
(d1, d2) = (x_test_array.GetLength(0), x_test_array.GetLength(1));
for (var i = 0; i < d1; i++)
{
for (var j = 0; j < x_test_array.GetLength(1); j++)
for (var j = 0; j < d2; j++)
{
if (x_test_array[i, j] == 0)
break;
@@ -169,9 +173,10 @@ namespace Tensorflow.Keras.Datasets

if (num_words == null)
{
var (d1, d2) = (xs_array.GetLength(0), xs_array.GetLength(1));
num_words = 0;
for (var i = 0; i < xs_array.GetLength(0); i++)
for (var j = 0; j < xs_array.GetLength(1); j++)
for (var i = 0; i < d1; i++)
for (var j = 0; j < d2; j++)
num_words = max((int)num_words, (int)xs_array[i, j]);
}



+ 5
- 3
src/TensorFlowNET.Keras/Utils/data_utils.cs View File

@@ -53,15 +53,17 @@ namespace Tensorflow.Keras.Utils
new_seq, new_label: shortened lists for `seq` and `label`.

*/
var nRow = seq.GetLength(0);
var nCol = seq.GetLength(1);
List<int[]> new_seq = new List<int[]>();
List<long> new_label = new List<long>();

for (var i = 0; i < seq.GetLength(0); i++)
for (var i = 0; i < nRow; i++)
{
if (maxlen < seq.GetLength(1) && seq[i, maxlen] != 0)
if (maxlen < nCol && seq[i, maxlen] != 0)
continue;
int[] sentence = new int[maxlen];
for (var j = 0; j < maxlen && j < seq.GetLength(1); j++)
for (var j = 0; j < maxlen && j < nCol; j++)
{
sentence[j] = seq[i, j];
}


Loading…
Cancel
Save