You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

Imdb.cs 11 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. using System;
  2. using System.Collections.Generic;
  3. using System.IO;
  4. using System.Text;
  5. using Tensorflow.Keras.Utils;
  6. namespace Tensorflow.Keras.Datasets
  7. {
  8. /// <summary>
  9. /// This is a dataset of 25,000 movies reviews from IMDB, labeled by sentiment
  10. /// (positive/negative). Reviews have been preprocessed, and each review is
  11. /// encoded as a list of word indexes(integers).
  12. /// For convenience, words are indexed by overall frequency in the dataset,
  13. /// so that for instance the integer "3" encodes the 3rd most frequent word in
  14. /// the data. This allows for quick filtering operations such as:
  15. /// "only consider the top 10,000 most
  16. /// common words, but eliminate the top 20 most common words".
  17. /// As a convention, "0" does not stand for a specific word, but instead is used
  18. /// to encode the pad token.
  19. /// Args:
  20. /// path: where to cache the data (relative to %TEMP%/imdb/imdb.npz).
  21. /// num_words: integer or None. Words are
  22. /// ranked by how often they occur (in the training set) and only
  23. /// the `num_words` most frequent words are kept. Any less frequent word
  24. /// will appear as `oov_char` value in the sequence data. If None,
  25. /// all words are kept. Defaults to `None`.
  26. /// skip_top: skip the top N most frequently occurring words
  27. /// (which may not be informative). These words will appear as
  28. /// `oov_char` value in the dataset.When 0, no words are
  29. /// skipped. Defaults to `0`.
  30. /// maxlen: int or None.Maximum sequence length.
  31. /// Any longer sequence will be truncated. None, means no truncation.
  32. /// Defaults to `None`.
  33. /// seed: int. Seed for reproducible data shuffling.
  34. /// start_char: int. The start of a sequence will be marked with this
  35. /// character. 0 is usually the padding character. Defaults to `1`.
  36. /// oov_char: int. The out-of-vocabulary character.
  37. /// Words that were cut out because of the `num_words` or
  38. /// `skip_top` limits will be replaced with this character.
  39. /// index_from: int. Index actual words with this index and higher.
  40. /// Returns:
  41. /// Tuple of Numpy arrays: `(x_train, labels_train), (x_test, labels_test)`.
  42. ///
  43. /// ** x_train, x_test**: lists of sequences, which are lists of indexes
  44. /// (integers). If the num_words argument was specified, the maximum
  45. /// possible index value is `num_words - 1`. If the `maxlen` argument was
  46. /// specified, the largest possible sequence length is `maxlen`.
  47. ///
  48. /// ** labels_train, labels_test**: lists of integer labels(1 or 0).
  49. ///
  50. /// Raises:
  51. /// ValueError: in case `maxlen` is so low
  52. /// that no input sequence could be kept.
  53. /// Note that the 'out of vocabulary' character is only used for
  54. /// words that were present in the training set but are not included
  55. /// because they're not making the `num_words` cut here.
  56. /// Words that were not seen in the training set but are in the test set
  57. /// have simply been skipped.
  58. /// </summary>
  59. /// Loads the [IMDB dataset](https://ai.stanford.edu/~amaas/data/sentiment/).
  60. public class Imdb
  61. {
  62. string origin_folder = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/";
  63. string dest_folder = "imdb";
  64. /// <summary>
  65. /// Loads the [IMDB dataset](https://ai.stanford.edu/~amaas/data/sentiment/).
  66. /// </summary>
  67. /// <param name="path"></param>
  68. /// <param name="num_words"></param>
  69. /// <param name="skip_top"></param>
  70. /// <param name="maxlen"></param>
  71. /// <param name="seed"></param>
  72. /// <param name="start_char"></param>
  73. /// <param name="oov_char"></param>
  74. /// <param name="index_from"></param>
  75. /// <returns></returns>
  76. public DatasetPass load_data(
  77. string path = "imdb.npz",
  78. int? num_words = null,
  79. int skip_top = 0,
  80. int? maxlen = null,
  81. int seed = 113,
  82. int? start_char = 1,
  83. int? oov_char = 2,
  84. int index_from = 3)
  85. {
  86. path = data_utils.get_file(
  87. path,
  88. origin: Path.Combine(origin_folder, "imdb.npz"),
  89. file_hash: "69664113be75683a8fe16e3ed0ab59fda8886cb3cd7ada244f7d9544e4676b9f"
  90. );
  91. path = Path.Combine(path, "imdb.npz");
  92. var fileBytes = File.ReadAllBytes(path);
  93. var (x_train, x_test) = LoadX(fileBytes);
  94. var (labels_train, labels_test) = LoadY(fileBytes);
  95. var indices = np.arange<int>(len(x_train));
  96. np.random.shuffle(indices, seed);
  97. x_train = x_train[indices];
  98. labels_train = labels_train[indices];
  99. indices = np.arange<int>(len(x_test));
  100. np.random.shuffle(indices, seed);
  101. x_test = x_test[indices];
  102. labels_test = labels_test[indices];
  103. var x_train_array = (int[,])x_train.ToMultiDimArray<int>();
  104. var x_test_array = (int[,])x_test.ToMultiDimArray<int>();
  105. var labels_train_array = (long[])labels_train.ToArray<long>();
  106. var labels_test_array = (long[])labels_test.ToArray<long>();
  107. if (start_char != null)
  108. {
  109. int[,] new_x_train_array = new int[x_train_array.GetLength(0), x_train_array.GetLength(1) + 1];
  110. for (var i = 0; i < x_train_array.GetLength(0); i++)
  111. {
  112. new_x_train_array[i, 0] = (int)start_char;
  113. Array.Copy(x_train_array, i * x_train_array.GetLength(1), new_x_train_array, i * new_x_train_array.GetLength(1) + 1, x_train_array.GetLength(1));
  114. }
  115. int[,] new_x_test_array = new int[x_test_array.GetLength(0), x_test_array.GetLength(1) + 1];
  116. for (var i = 0; i < x_test_array.GetLength(0); i++)
  117. {
  118. new_x_test_array[i, 0] = (int)start_char;
  119. Array.Copy(x_test_array, i * x_test_array.GetLength(1), new_x_test_array, i * new_x_test_array.GetLength(1) + 1, x_test_array.GetLength(1));
  120. }
  121. x_train_array = new_x_train_array;
  122. x_test_array = new_x_test_array;
  123. }
  124. else if (index_from != 0)
  125. {
  126. for (var i = 0; i < x_train_array.GetLength(0); i++)
  127. {
  128. for (var j = 0; j < x_train_array.GetLength(1); j++)
  129. {
  130. if (x_train_array[i, j] == 0)
  131. break;
  132. x_train_array[i, j] += index_from;
  133. }
  134. }
  135. for (var i = 0; i < x_test_array.GetLength(0); i++)
  136. {
  137. for (var j = 0; j < x_test_array.GetLength(1); j++)
  138. {
  139. if (x_test_array[i, j] == 0)
  140. break;
  141. x_test[i, j] += index_from;
  142. }
  143. }
  144. }
  145. if (maxlen == null)
  146. {
  147. maxlen = max(x_train_array.GetLength(1), x_test_array.GetLength(1));
  148. }
  149. (x_train_array, labels_train_array) = data_utils._remove_long_seq((int)maxlen, x_train_array, labels_train_array);
  150. (x_test_array, labels_test_array) = data_utils._remove_long_seq((int)maxlen, x_test_array, labels_test_array);
  151. if (x_train_array.Length == 0 || x_test_array.Length == 0)
  152. throw new ValueError("After filtering for sequences shorter than maxlen=" +
  153. $"{maxlen}, no sequence was kept. Increase maxlen.");
  154. int[,] xs_array = new int[x_train_array.GetLength(0) + x_test_array.GetLength(0), (int)maxlen];
  155. Array.Copy(x_train_array, xs_array, x_train_array.Length);
  156. Array.Copy(x_test_array, 0, xs_array, x_train_array.Length, x_train_array.Length);
  157. long[] labels_array = new long[labels_train_array.Length + labels_test_array.Length];
  158. Array.Copy(labels_train_array, labels_array, labels_train_array.Length);
  159. Array.Copy(labels_test_array, 0, labels_array, labels_train_array.Length, labels_test_array.Length);
  160. if (num_words == null)
  161. {
  162. num_words = 0;
  163. for (var i = 0; i < xs_array.GetLength(0); i++)
  164. for (var j = 0; j < xs_array.GetLength(1); j++)
  165. num_words = max((int)num_words, (int)xs_array[i, j]);
  166. }
  167. // by convention, use 2 as OOV word
  168. // reserve 'index_from' (=3 by default) characters:
  169. // 0 (padding), 1 (start), 2 (OOV)
  170. if (oov_char != null)
  171. {
  172. int[,] new_xs_array = new int[xs_array.GetLength(0), xs_array.GetLength(1)];
  173. for (var i = 0; i < xs_array.GetLength(0); i++)
  174. {
  175. for (var j = 0; j < xs_array.GetLength(1); j++)
  176. {
  177. if (xs_array[i, j] == 0 || skip_top <= xs_array[i, j] && xs_array[i, j] < num_words)
  178. new_xs_array[i, j] = xs_array[i, j];
  179. else
  180. new_xs_array[i, j] = (int)oov_char;
  181. }
  182. }
  183. xs_array = new_xs_array;
  184. }
  185. else
  186. {
  187. int[,] new_xs_array = new int[xs_array.GetLength(0), xs_array.GetLength(1)];
  188. for (var i = 0; i < xs_array.GetLength(0); i++)
  189. {
  190. int k = 0;
  191. for (var j = 0; j < xs_array.GetLength(1); j++)
  192. {
  193. if (xs_array[i, j] == 0 || skip_top <= xs_array[i, j] && xs_array[i, j] < num_words)
  194. new_xs_array[i, k++] = xs_array[i, j];
  195. }
  196. }
  197. xs_array = new_xs_array;
  198. }
  199. Array.Copy(xs_array, x_train_array, x_train_array.Length);
  200. Array.Copy(xs_array, x_train_array.Length, x_test_array, 0, x_train_array.Length);
  201. Array.Copy(labels_array, labels_train_array, labels_train_array.Length);
  202. Array.Copy(labels_array, labels_train_array.Length, labels_test_array, 0, labels_test_array.Length);
  203. return new DatasetPass
  204. {
  205. Train = (x_train_array, labels_train_array),
  206. Test = (x_test_array, labels_test_array)
  207. };
  208. }
  209. (NDArray, NDArray) LoadX(byte[] bytes)
  210. {
  211. var x = np.Load_Npz<int[,]>(bytes);
  212. return (x["x_train.npy"], x["x_test.npy"]);
  213. }
  214. (NDArray, NDArray) LoadY(byte[] bytes)
  215. {
  216. var y = np.Load_Npz<long[]>(bytes);
  217. return (y["y_train.npy"], y["y_test.npy"]);
  218. }
  219. }
  220. }