
Merge pull request #1175 from lingbai-kong/ndarrayload

optimize: temporal complexity of Imdb dataset loader
tag: v0.110.4-Transformer-Model
Haiping (GitHub) committed 2 years ago
commit c814fe121a
2 changed files with 27 additions and 35 deletions:
  1. src/TensorFlowNET.Keras/Datasets/Imdb.cs (+21, -27)
  2. src/TensorFlowNET.Keras/Utils/data_utils.cs (+6, -8)

src/TensorFlowNET.Keras/Datasets/Imdb.cs (+21, -27)

@@ -116,23 +116,13 @@ namespace Tensorflow.Keras.Datasets
                 for (var i = 0; i < x_train_array.GetLength(0); i++)
                 {
                     new_x_train_array[i, 0] = (int)start_char;
-                    for (var j = 0; j < x_train_array.GetLength(1); j++)
-                    {
-                        if (x_train_array[i, j] == 0)
-                            break;
-                        new_x_train_array[i, j + 1] = x_train_array[i, j];
-                    }
+                    Array.Copy(x_train_array, i * x_train_array.GetLength(1), new_x_train_array, i * new_x_train_array.GetLength(1) + 1, x_train_array.GetLength(1));
                 }
                 int[,] new_x_test_array = new int[x_test_array.GetLength(0), x_test_array.GetLength(1) + 1];
                 for (var i = 0; i < x_test_array.GetLength(0); i++)
                 {
                     new_x_test_array[i, 0] = (int)start_char;
-                    for (var j = 0; j < x_test_array.GetLength(1); j++)
-                    {
-                        if (x_test_array[i, j] == 0)
-                            break;
-                        new_x_test_array[i, j + 1] = x_test_array[i, j];
-                    }
+                    Array.Copy(x_test_array, i * x_test_array.GetLength(1), new_x_test_array, i * new_x_test_array.GetLength(1) + 1, x_test_array.GetLength(1));
                 }
                 x_train_array = new_x_train_array;
                 x_test_array = new_x_test_array;
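The hunk above works because `Array.Copy` treats a rectangular `int[,]` as one contiguous row-major buffer: a single block copy per row replaces the removed element-by-element inner loop, and the trailing zeros the old loop `break`-ed on come for free since the destination array is zero-initialized. A minimal standalone sketch of that equivalence (toy values, not the IMDB data):

    using System;
    using System.Linq;

    class RowShiftDemo
    {
        static void Main()
        {
            // Toy zero-padded sequences: 2 rows, width 4.
            int[,] src = { { 5, 6, 0, 0 }, { 7, 8, 9, 0 } };
            int rows = src.GetLength(0), width = src.GetLength(1);

            // One extra column up front for the start_char slot.
            const int start_char = 1;
            int[,] dst = new int[rows, width + 1];

            for (var i = 0; i < rows; i++)
            {
                dst[i, 0] = start_char;
                // int[,] is one contiguous row-major buffer, so a single
                // block copy per row lands at column 1 of row i. Copying
                // the trailing zeros is harmless: the destination is
                // zero-initialized anyway.
                Array.Copy(src, i * width, dst, i * (width + 1) + 1, width);
            }

            // Prints 1,5,6,0,0,1,7,8,9,0
            Console.WriteLine(string.Join(",", dst.Cast<int>()));
        }
    }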
@@ -163,15 +153,19 @@ namespace Tensorflow.Keras.Datasets
             {
                 maxlen = max(x_train_array.GetLength(1), x_test_array.GetLength(1));
             }
-            (x_train, labels_train) = data_utils._remove_long_seq((int)maxlen, x_train_array, labels_train_array);
-            (x_test, labels_test) = data_utils._remove_long_seq((int)maxlen, x_test_array, labels_test_array);
-            if (x_train.size == 0 || x_test.size == 0)
+            (x_train_array, labels_train_array) = data_utils._remove_long_seq((int)maxlen, x_train_array, labels_train_array);
+            (x_test_array, labels_test_array) = data_utils._remove_long_seq((int)maxlen, x_test_array, labels_test_array);
+            if (x_train_array.Length == 0 || x_test_array.Length == 0)
                 throw new ValueError("After filtering for sequences shorter than maxlen=" +
                                 $"{maxlen}, no sequence was kept. Increase maxlen.");

-            var xs = np.concatenate(new[] { x_train, x_test });
-            var labels = np.concatenate(new[] { labels_train, labels_test });
-            var xs_array = (int[,])xs.ToMultiDimArray<int>();
+            int[,] xs_array = new int[x_train_array.GetLength(0) + x_test_array.GetLength(0), (int)maxlen];
+            Array.Copy(x_train_array, xs_array, x_train_array.Length);
+            Array.Copy(x_test_array, 0, xs_array, x_train_array.Length, x_train_array.Length);
+
+            long[] labels_array = new long[labels_train_array.Length + labels_test_array.Length];
+            Array.Copy(labels_train_array, labels_array, labels_train_array.Length);
+            Array.Copy(labels_test_array, 0, labels_array, labels_train_array.Length, labels_test_array.Length);

             if (num_words == null)
             {
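The new concatenation builds one flat buffer with two `Array.Copy` calls instead of going through `np.concatenate` and an NDArray-to-array conversion. A small sketch of the pattern; note each copy's length here is taken from its own source block, whereas the hunk above passes `x_train_array.Length` for the second copy, which matches the test block's size only when both splits keep the same number of rows (as in the stock IMDB split):

    using System;
    using System.Linq;

    class ConcatDemo
    {
        static void Main()
        {
            // Toy blocks with equal width (row counts may differ).
            int[,] train = { { 1, 2 }, { 3, 4 } };
            int[,] test  = { { 5, 6 } };

            // Stack the blocks vertically in one allocation. Array.Copy
            // sees each int[,] as a flat buffer, so the test block starts
            // at flat offset train.Length (total element count, not rows).
            int[,] xs = new int[train.GetLength(0) + test.GetLength(0), train.GetLength(1)];
            Array.Copy(train, xs, train.Length);
            Array.Copy(test, 0, xs, train.Length, test.Length);

            // Prints 1,2,3,4,5,6
            Console.WriteLine(string.Join(",", xs.Cast<int>()));
        }
    }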
@@ -197,7 +191,7 @@ namespace Tensorflow.Keras.Datasets
                         new_xs_array[i, j] = (int)oov_char;
                     }
                 }
-                xs = new NDArray(new_xs_array);
+                xs_array = new_xs_array;
             }
             else
             {
@@ -211,19 +205,19 @@ namespace Tensorflow.Keras.Datasets
                         new_xs_array[i, k++] = xs_array[i, j];
                     }
                 }
-                xs = new NDArray(new_xs_array);
+                xs_array = new_xs_array;
             }

-            var idx = len(x_train);
-            x_train = xs[$"0:{idx}"];
-            x_test = xs[$"{idx}:"];
-            var y_train = labels[$"0:{idx}"];
-            var y_test = labels[$"{idx}:"];
+            Array.Copy(xs_array, x_train_array, x_train_array.Length);
+            Array.Copy(xs_array, x_train_array.Length, x_test_array, 0, x_train_array.Length);
+            Array.Copy(labels_array, labels_train_array, labels_train_array.Length);
+            Array.Copy(labels_array, labels_train_array.Length, labels_test_array, 0, labels_test_array.Length);
+
             return new DatasetPass
             {
-                Train = (x_train, y_train),
-                Test = (x_test, y_test)
+                Train = (x_train_array, labels_train_array),
+                Test = (x_test_array, labels_test_array)
             };
         }
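The tail of the method reverses the earlier concatenation: two flat copies split `xs_array` back into the per-split buffers via the source-offset overload of `Array.Copy`, replacing the removed NDArray slice expressions. A minimal sketch with toy values:

    using System;
    using System.Linq;

    class SplitDemo
    {
        static void Main()
        {
            // Combined toy buffer: first 2 rows are "train", last is "test".
            int[,] xs = { { 1, 2 }, { 3, 4 }, { 5, 6 } };
            int[,] train = new int[2, 2];
            int[,] test = new int[1, 2];

            // Reverse of the concatenation step: the source-offset
            // overload reads the tail of xs straight into the test buffer.
            Array.Copy(xs, train, train.Length);
            Array.Copy(xs, train.Length, test, 0, test.Length);

            // Prints 1,2,3,4 | 5,6
            Console.WriteLine(string.Join(",", train.Cast<int>()) + " | " +
                              string.Join(",", test.Cast<int>()));
        }
    }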




src/TensorFlowNET.Keras/Utils/data_utils.cs (+6, -8)

@@ -40,7 +40,7 @@ namespace Tensorflow.Keras.Utils
             return datadir;
         }

-        public static (NDArray, NDArray) _remove_long_seq(int maxlen, NDArray seq, NDArray label)
+        public static (int[,], long[]) _remove_long_seq(int maxlen, int[,] seq, long[] label)
         {
             /*Removes sequences that exceed the maximum length.

@@ -56,19 +56,17 @@ namespace Tensorflow.Keras.Utils
             List<int[]> new_seq = new List<int[]>();
             List<long> new_label = new List<long>();

-            var seq_array = (int[,])seq.ToMultiDimArray<int>();
-            var label_array = (long[])label.ToArray<long>();
-            for (var i = 0; i < seq_array.GetLength(0); i++)
+            for (var i = 0; i < seq.GetLength(0); i++)
             {
-                if (maxlen < seq_array.GetLength(1) && seq_array[i,maxlen] != 0)
+                if (maxlen < seq.GetLength(1) && seq[i, maxlen] != 0)
                     continue;
                 int[] sentence = new int[maxlen];
-                for (var j = 0; j < maxlen && j < seq_array.GetLength(1); j++)
+                for (var j = 0; j < maxlen && j < seq.GetLength(1); j++)
                 {
-                    sentence[j] = seq_array[i, j];
+                    sentence[j] = seq[i, j];
                 }
                 new_seq.Add(sentence);
-                new_label.Add(label_array[i]);
+                new_label.Add(label[i]);
             }

             int[,] new_seq_array = new int[new_seq.Count, maxlen];
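With the retyped `_remove_long_seq`, callers pass plain arrays and get plain arrays back, so no NDArray is materialized on either side of the call. A minimal usage sketch with toy data (assumes the caller can reference `Tensorflow.Keras.Utils.data_utils`, where the method is declared public static):

    using System;
    using Tensorflow.Keras.Utils;

    class RemoveLongSeqDemo
    {
        static void Main()
        {
            // Toy zero-padded sequences, width 4. With maxlen = 3, row 0
            // is kept (seq[0, 3] == 0) and row 1 is dropped (seq[1, 3] != 0).
            int[,] seq = { { 9, 8, 0, 0 }, { 1, 2, 3, 4 } };
            long[] labels = { 0, 1 };

            // Plain arrays in, plain arrays out: no NDArray round-trip.
            var (kept, keptLabels) = data_utils._remove_long_seq(3, seq, labels);

            // kept is { { 9, 8, 0 } }; keptLabels is { 0 }.
            Console.WriteLine($"{kept.GetLength(0)} row(s), label {keptLabels[0]}");
        }
    }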

