You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

MnistModelLoader.cs 8.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. using System;
  2. using System.Threading.Tasks;
  3. using System.Collections.Generic;
  4. using System.Text;
  5. using System.IO;
  6. using NumSharp;
  7. namespace Tensorflow.Hub
  8. {
  /// <summary>
  /// Downloads the four gzip-compressed MNIST files, unzips them into the configured
  /// directory, parses them into NumSharp arrays and wraps the result in
  /// train / validation / test <see cref="MnistDataSet"/> instances.
  /// </summary>
  public class MnistModelLoader : IModelLoader<MnistDataSet>
  {
      // Base URL used when the supplied setting carries no SourceUrl.
      private const string DEFAULT_SOURCE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/";

      // Names of the compressed files, relative to the source URL and to TrainDir.
      private const string TRAIN_IMAGES = "train-images-idx3-ubyte.gz";
      private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz";
      private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz";
      private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz";

      // Magic numbers expected as the first 32-bit header word of each file kind.
      private const uint IMAGE_FILE_MAGIC = 2051;
      private const uint LABEL_FILE_MAGIC = 2049;

      /// <summary>
      /// Convenience entry point: builds a <see cref="ModelLoadSetting"/> from the
      /// arguments and delegates to <see cref="LoadAsync(ModelLoadSetting)"/>.
      /// </summary>
      /// <param name="trainDir">Directory the files are downloaded to and unzipped in.</param>
      /// <param name="oneHot">When true, labels are returned one-hot encoded.</param>
      /// <param name="trainSize">Optional cap on the number of training examples read; the setting's own value is kept when null.</param>
      /// <param name="validationSize">Optional validation-set size; the setting's own value is kept when null.</param>
      /// <param name="testSize">Optional cap on the number of test examples read; the setting's own value is kept when null.</param>
      /// <param name="showProgressInConsole">Forwarded to the download/unzip helpers.</param>
      /// <returns>The train, validation and test data sets.</returns>
      public static async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null, bool showProgressInConsole = false)
      {
          var loader = new MnistModelLoader();

          var setting = new ModelLoadSetting
          {
              TrainDir = trainDir,
              OneHot = oneHot,
              ShowProgressInConsole = showProgressInConsole
          };

          // Only override the sizes the caller actually supplied.
          if (trainSize.HasValue)
              setting.TrainSize = trainSize.Value;

          if (validationSize.HasValue)
              setting.ValidationSize = validationSize.Value;

          if (testSize.HasValue)
              setting.TestSize = testSize.Value;

          return await loader.LoadAsync(setting);
      }

      /// <summary>
      /// Downloads, unzips and parses the MNIST files described by <paramref name="setting"/>.
      /// The leading <c>ValidationSize</c> training examples are split off as the
      /// validation set; the remainder forms the training set.
      /// </summary>
      /// <param name="setting">Source URL, target directory, size limits and output options.</param>
      /// <returns>The train, validation and test data sets.</returns>
      /// <exception cref="ArgumentNullException"><paramref name="setting"/> is null.</exception>
      /// <exception cref="ArgumentException">
      /// The validation size is negative, is not smaller than the requested training size,
      /// or exceeds the number of training examples actually loaded.
      /// </exception>
      /// <exception cref="InvalidDataException">A file has an unexpected magic number or an out-of-range label.</exception>
      /// <exception cref="EndOfStreamException">A file ends before its declared content.</exception>
      public async Task<Datasets<MnistDataSet>> LoadAsync(ModelLoadSetting setting)
      {
          if (setting == null)
              throw new ArgumentNullException(nameof(setting));

          if (setting.TrainSize.HasValue && setting.ValidationSize >= setting.TrainSize.Value)
              throw new ArgumentException("Validation set should be smaller than training set");

          // Fail before any download is started rather than when the split is taken.
          if (setting.ValidationSize < 0)
              throw new ArgumentException("Validation set size must not be negative");

          var sourceUrl = setting.SourceUrl;
          if (string.IsNullOrEmpty(sourceUrl))
              sourceUrl = DEFAULT_SOURCE_URL;

          // Each file is fetched and parsed before the next one is requested,
          // in the order: train images, train labels, test images, test labels.
          var trainImagesFile = await DownloadAndUnzipAsync(sourceUrl, TRAIN_IMAGES, setting);
          var trainImages = ExtractImages(trainImagesFile, limit: setting.TrainSize);

          var trainLabelsFile = await DownloadAndUnzipAsync(sourceUrl, TRAIN_LABELS, setting);
          var trainLabels = ExtractLabels(trainLabelsFile, one_hot: setting.OneHot, limit: setting.TrainSize);

          var testImagesFile = await DownloadAndUnzipAsync(sourceUrl, TEST_IMAGES, setting);
          var testImages = ExtractImages(testImagesFile, limit: setting.TestSize);

          var testLabelsFile = await DownloadAndUnzipAsync(sourceUrl, TEST_LABELS, setting);
          var testLabels = ExtractLabels(testLabelsFile, one_hot: setting.OneHot, limit: setting.TestSize);

          var end = trainImages.shape[0];
          var validationSize = setting.ValidationSize;

          // The up-front check only runs when TrainSize is set; here the real number of
          // loaded examples is known, so an impossible split is reported clearly instead
          // of surfacing as an indexing failure below.
          if (validationSize > end)
              throw new ArgumentException($"Validation set size {validationSize} exceeds the {end} training examples that were loaded");

          // First validationSize examples -> validation, the rest -> training.
          var validationImages = trainImages[np.arange(validationSize)];
          var validationLabels = trainLabels[np.arange(validationSize)];
          trainImages = trainImages[np.arange(validationSize, end)];
          trainLabels = trainLabels[np.arange(validationSize, end)];

          var dtype = setting.DataType;
          var reshape = setting.ReShape;

          var train = new MnistDataSet(trainImages, trainLabels, dtype, reshape);
          var validation = new MnistDataSet(validationImages, validationLabels, dtype, reshape);
          var test = new MnistDataSet(testImages, testLabels, dtype, reshape);

          return new Datasets<MnistDataSet>(train, validation, test);
      }

      /// <summary>
      /// Downloads <paramref name="fileName"/> from <paramref name="sourceUrl"/> into the
      /// setting's TrainDir, unzips it there and returns the path of the unzipped file
      /// (the file name with its last extension removed).
      /// </summary>
      private async Task<string> DownloadAndUnzipAsync(string sourceUrl, string fileName, ModelLoadSetting setting)
      {
          await this.DownloadAsync(sourceUrl + fileName, setting.TrainDir, fileName, showProgressInConsole: setting.ShowProgressInConsole)
              .ShowProgressInConsole(setting.ShowProgressInConsole);

          await this.UnzipAsync(Path.Combine(setting.TrainDir, fileName), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole)
              .ShowProgressInConsole(setting.ShowProgressInConsole);

          return Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(fileName));
      }

      /// <summary>
      /// Parses an unzipped MNIST image file: a header of four 32-bit words
      /// (magic, image count, rows, columns) followed by one byte per pixel.
      /// </summary>
      /// <param name="file">Path of the file; a relative path is resolved against <see cref="AppContext.BaseDirectory"/>.</param>
      /// <param name="limit">Optional cap on the number of images read.</param>
      /// <returns>A byte array reshaped to (images, rows, columns, 1).</returns>
      /// <exception cref="InvalidDataException">The magic number is not 2051.</exception>
      /// <exception cref="EndOfStreamException">The file is shorter than its header declares.</exception>
      /// <exception cref="OverflowException">The size implied by the header does not fit a single array.</exception>
      private NDArray ExtractImages(string file, int? limit = null)
      {
          if (!Path.IsPathRooted(file))
              file = Path.Combine(AppContext.BaseDirectory, file);

          // Read-only access: the file is never written, and requesting write access
          // would fail needlessly on read-only files.
          using (var bytestream = new FileStream(file, FileMode.Open, FileAccess.Read, FileShare.Read))
          {
              var magic = Read32(bytestream);
              if (magic != IMAGE_FILE_MAGIC)
                  throw new InvalidDataException($"Invalid magic number {magic} in MNIST image file: {file}");

              var num_images = Read32(bytestream);
              // A limit only ever shrinks the count declared in the header.
              // NOTE(review): a negative limit is not rejected; under the default unchecked
              // conversion it wraps to a large uint and leaves the header count in force.
              if (limit != null)
                  num_images = Math.Min(num_images, (uint)limit.Value);

              var rows = Read32(bytestream);
              var cols = Read32(bytestream);

              // Computed in checked 64-bit arithmetic so a corrupt header raises
              // OverflowException instead of silently wrapping to a wrong buffer size.
              var byteCount = checked((int)((long)rows * cols * num_images));
              var buf = new byte[byteCount];
              ReadFully(bytestream, buf);

              var data = np.frombuffer(buf, np.uint8);
              data = data.reshape((int)num_images, (int)rows, (int)cols, 1);
              return data;
          }
      }

      /// <summary>
      /// Parses an unzipped MNIST label file: a header of two 32-bit words
      /// (magic, item count) followed by one byte per label.
      /// </summary>
      /// <param name="file">Path of the file; a relative path is resolved against <see cref="AppContext.BaseDirectory"/>.</param>
      /// <param name="one_hot">When true the labels are expanded with <see cref="DenseToOneHot"/>.</param>
      /// <param name="num_classes">Number of columns of the one-hot encoding.</param>
      /// <param name="limit">Optional cap on the number of labels read.</param>
      /// <returns>The raw label bytes, or their one-hot encoding.</returns>
      /// <exception cref="InvalidDataException">The magic number is not 2049, or a label is not below <paramref name="num_classes"/>.</exception>
      /// <exception cref="EndOfStreamException">The file is shorter than its header declares.</exception>
      private NDArray ExtractLabels(string file, bool one_hot = false, int num_classes = 10, int? limit = null)
      {
          if (!Path.IsPathRooted(file))
              file = Path.Combine(AppContext.BaseDirectory, file);

          // Read-only access: the file is never written.
          using (var bytestream = new FileStream(file, FileMode.Open, FileAccess.Read, FileShare.Read))
          {
              var magic = Read32(bytestream);
              if (magic != LABEL_FILE_MAGIC)
                  throw new InvalidDataException($"Invalid magic number {magic} in MNIST label file: {file}");

              var num_items = Read32(bytestream);
              // Same limit handling as for images, so images and labels stay aligned.
              if (limit != null)
                  num_items = Math.Min(num_items, (uint)limit.Value);

              var buf = new byte[num_items];
              ReadFully(bytestream, buf);

              var labels = np.frombuffer(buf, np.uint8);
              if (one_hot)
                  return DenseToOneHot(labels, num_classes);

              return labels;
          }
      }

      /// <summary>
      /// Expands a vector of class indices into a (labels, <paramref name="num_classes"/>)
      /// array created by <c>np.zeros</c>, with 1.0 written at each row's class column.
      /// </summary>
      /// <param name="labels_dense">Class indices, read as bytes.</param>
      /// <param name="num_classes">Number of columns of the result.</param>
      /// <exception cref="InvalidDataException">A label is not below <paramref name="num_classes"/>.</exception>
      private NDArray DenseToOneHot(NDArray labels_dense, int num_classes)
      {
          var num_labels = labels_dense.shape[0];
          var labels_one_hot = np.zeros(num_labels, num_classes);
          var labels = labels_dense.Data<byte>();

          for (int row = 0; row < num_labels; row++)
          {
              var col = labels[row];

              // Reject a corrupt label here instead of relying on how SetData treats
              // a column index outside the row.
              if (col >= num_classes)
                  throw new InvalidDataException($"Label {col} at index {row} is outside the {num_classes} expected classes");

              labels_one_hot.SetData(1.0, row, col);
          }

          return labels_one_hot;
      }

      /// <summary>
      /// Reads the next four bytes of the stream and decodes them with the
      /// <c>"&gt;u4"</c> dtype (big-endian unsigned 32-bit in NumPy dtype notation).
      /// </summary>
      /// <exception cref="EndOfStreamException">Fewer than four bytes remain.</exception>
      private uint Read32(FileStream bytestream)
      {
          var buffer = new byte[sizeof(uint)];
          ReadFully(bytestream, buffer);
          return np.frombuffer(buffer, ">u4").Data<uint>()[0];
      }

      /// <summary>
      /// Fills <paramref name="buffer"/> completely from <paramref name="stream"/>.
      /// A single <see cref="Stream.Read(byte[], int, int)"/> call may return fewer bytes
      /// than requested, so the call is repeated until the buffer is full.
      /// </summary>
      /// <exception cref="EndOfStreamException">The stream ends before the buffer is full.</exception>
      private static void ReadFully(Stream stream, byte[] buffer)
      {
          var offset = 0;
          while (offset < buffer.Length)
          {
              var read = stream.Read(buffer, offset, buffer.Length - offset);
              if (read == 0)
                  throw new EndOfStreamException($"Unexpected end of MNIST file: got {offset} of {buffer.Length} expected bytes");

              offset += read;
          }
      }
  }
  136. }