You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Imdb.cs 4.2 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  using System;
  using System.Collections.Generic;
  using System.Globalization;
  using System.IO;
  using System.Linq;
  using System.Text;
  using Tensorflow.Keras.Utils;
  using Tensorflow.NumPy;
  namespace Tensorflow.Keras.Datasets
  {
      /// <summary>
      /// This is a dataset of 25,000 movie reviews from IMDB, labeled by sentiment
      /// (positive/negative). Reviews have been preprocessed, and each review is
      /// encoded as a list of word indexes (integers).
      /// </summary>
      public class Imdb
      {
          readonly string origin_folder = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/";
          readonly string file_name = "imdb.npz";
          readonly string dest_folder = "imdb";

          /// <summary>
          /// Loads the [IMDB dataset](https://ai.stanford.edu/~amaas/data/sentiment/).
          /// </summary>
          /// <remarks>
          /// Each split is read from a text file in the download folder
          /// (<c>imdb_train.txt</c> and <c>imdb_test.txt</c>). Every line must start with
          /// a one-character label. The next character (presumably a separator) is skipped.
          /// The rest of the line is a bracketed, comma-separated list of word indexes,
          /// e.g. <c>1 [1, 14, 22]</c>.
          /// NOTE(review): <c>Download()</c> fetches <c>imdb.npz</c>, but nothing in this
          /// class writes the two <c>.txt</c> files. Confirm they are provisioned in the
          /// same folder by some other step.
          /// </remarks>
          /// <param name="path">Currently unused; the download location is fixed by the class fields.</param>
          /// <param name="num_words">Currently unused.</param>
          /// <param name="skip_top">Currently unused.</param>
          /// <param name="maxlen">
          /// Sequence length passed to <c>pad_sequences</c>. Must be supplied; the default -1 is rejected.
          /// </param>
          /// <param name="seed">Currently unused.</param>
          /// <param name="start_char">Currently unused.</param>
          /// <param name="oov_char">Currently unused.</param>
          /// <param name="index_from">Currently unused.</param>
          /// <returns>
          /// A <c>DatasetPass</c> whose <c>Train</c> and <c>Test</c> are <c>(x, y)</c> pairs.
          /// <c>x</c> holds the padded word-index sequences and <c>y</c> holds the int64 labels.
          /// </returns>
          /// <exception cref="InvalidArgumentError">Thrown when <paramref name="maxlen"/> is -1.</exception>
          public DatasetPass load_data(string path = "imdb.npz",
              int num_words = -1,
              int skip_top = 0,
              int maxlen = -1,
              int seed = 113,
              int start_char = 1,
              int oov_char = 2,
              int index_from = 3)
          {
              if (maxlen == -1) throw new InvalidArgumentError("maxlen must be assigned.");

              var dst = Download();

              // Each split is read from its own file. Previously the test split was
              // built from the training lines because the test file's contents were discarded.
              var train = LoadSplit(Path.Combine(dst, "imdb_train.txt"), maxlen);
              var test = LoadSplit(Path.Combine(dst, "imdb_test.txt"), maxlen);

              return new DatasetPass
              {
                  Train = train,
                  Test = test
              };
          }

          /// <summary>
          /// Parses one split file into padded sequences and int64 labels.
          /// </summary>
          /// <param name="filePath">Text file in the "&lt;label&gt;&lt;sep&gt;[i, j, ...]" line format.</param>
          /// <param name="maxlen">Sequence length forwarded to <c>pad_sequences</c>.</param>
          /// <returns>The padded sequences and the label vector, in that order.</returns>
          (NDArray, NDArray) LoadSplit(string filePath, int maxlen)
          {
              var lines = File.ReadAllLines(filePath);
              var x_string = new string[lines.Length];
              var y = np.zeros(new int[] { lines.Length }, np.int64);
              for (int i = 0; i < lines.Length; i++)
              {
                  // Column 0 is the label; the encoded review starts at column 2.
                  y[i] = long.Parse(lines[i].Substring(0, 1), CultureInfo.InvariantCulture);
                  x_string[i] = lines[i].Substring(2);
              }

              var x = keras.preprocessing.sequence.pad_sequences(PraseData(x_string), maxlen: maxlen);
              return (x, y);
          }

          // Reads the "x_train.npy" / "x_test.npy" entries of an npz archive.
          // Not called by load_data; kept for a possible npz-based loading path.
          (NDArray, NDArray) LoadX(byte[] bytes)
          {
              var y = np.Load_Npz<byte[]>(bytes);
              return (y["x_train.npy"], y["x_test.npy"]);
          }

          // Reads the "y_train.npy" / "y_test.npy" entries of an npz archive.
          // Not called by load_data; kept for a possible npz-based loading path.
          (NDArray, NDArray) LoadY(byte[] bytes)
          {
              var y = np.Load_Npz<long[]>(bytes);
              return (y["y_train.npy"], y["y_test.npy"]);
          }

          /// <summary>
          /// Ensures <c>&lt;temp&gt;/imdb</c> exists and downloads <c>imdb.npz</c> into it.
          /// </summary>
          /// <returns>The destination folder path (not the path of the archive itself).</returns>
          string Download()
          {
              var dst = Path.Combine(Path.GetTempPath(), dest_folder);
              Directory.CreateDirectory(dst);

              Web.Download(origin_folder + file_name, dst, file_name);

              return dst;
          }

          /// <summary>
          /// Converts strings such as <c>"[1, 14, 22]"</c> into integer arrays.
          /// The misspelled name is kept because the method is protected and may be used by subclasses.
          /// </summary>
          /// <param name="x">One bracketed, comma-separated list of integers per element.</param>
          /// <returns>One <c>int[]</c> per input element. <c>"[]"</c> yields an empty array.</returns>
          /// <exception cref="FormatException">Thrown when a list entry is not an integer.</exception>
          protected IEnumerable<int[]> PraseData(string[] x)
          {
              var data_list = new List<int[]>(x.Length);
              foreach (var list_string in x)
              {
                  var cleaned_list_string = list_string.Replace("[", "").Replace("]", "").Replace(" ", "");

                  // RemoveEmptyEntries keeps an empty review "[]" from producing [""],
                  // which int.Parse would reject.
                  string[] number_strings = cleaned_list_string.Split(new[] { ',' }, StringSplitOptions.RemoveEmptyEntries);

                  var numbers = new int[number_strings.Length];
                  for (int j = 0; j < number_strings.Length; j++)
                  {
                      numbers[j] = int.Parse(number_strings[j], CultureInfo.InvariantCulture);
                  }
                  data_list.Add(numbers);
              }
              return data_list;
          }
      }
  }