You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

MnistDataSet.cs 5.5 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. using NumSharp;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.IO;
  5. using System.Linq;
  6. using System.Text;
  7. using Tensorflow;
  8. namespace TensorFlowNET.Examples.Utility
  9. {
  /// <summary>
  /// Fetches the four MNIST archive files, decodes their IDX headers and payloads into
  /// <see cref="NDArray"/> instances, and splits them into train / validation / test sets.
  /// </summary>
  public class MnistDataSet
  {
      private const string DEFAULT_SOURCE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/";
      private const string TRAIN_IMAGES = "train-images-idx3-ubyte.gz";
      private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz";
      private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz";
      private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz";

      // Magic numbers expected at the start of the image and label files respectively.
      private const uint IMAGE_FILE_MAGIC = 2051;
      private const uint LABEL_FILE_MAGIC = 2049;

      /// <summary>
      /// Downloads and extracts the MNIST files into <paramref name="train_dir"/>, loads them,
      /// and carves the first <paramref name="validation_size"/> training samples off as the
      /// validation set.
      /// </summary>
      /// <param name="train_dir">Directory the archives are downloaded to and extracted in.</param>
      /// <param name="one_hot">Forwarded to <see cref="extract_labels"/>; when true, labels are one-hot encoded.</param>
      /// <param name="dtype">Forwarded unchanged to each <c>DataSet</c> constructor.</param>
      /// <param name="reshape">Forwarded unchanged to each <c>DataSet</c> constructor.</param>
      /// <param name="validation_size">Number of leading training samples moved into the validation set.</param>
      /// <param name="train_size">Optional cap on the number of training samples read from disk.</param>
      /// <param name="test_size">Optional cap on the number of test samples read from disk.</param>
      /// <param name="source_url">Base URL the archive file names are appended to.</param>
      /// <returns>A <c>Datasets</c> holding the train, validation and test sets, in that order.</returns>
      /// <exception cref="ArgumentException">
      /// <paramref name="validation_size"/> is negative, or is not smaller than the number of
      /// training samples requested or actually loaded.
      /// </exception>
      public static Datasets read_data_sets(string train_dir,
          bool one_hot = false,
          TF_DataType dtype = TF_DataType.TF_FLOAT,
          bool reshape = true,
          int validation_size = 5000,
          int? train_size = null,
          int? test_size = null,
          string source_url = DEFAULT_SOURCE_URL)
      {
          // Fail fast, before any download, when the request is inconsistent on its face.
          if (validation_size < 0)
          {
              throw new ArgumentException("Validation set size must not be negative");
          }

          if (train_size != null && validation_size >= train_size)
          {
              throw new ArgumentException("Validation set should be smaller than training set");
          }

          var train_images = extract_images(_fetch_and_extract(source_url, train_dir, TRAIN_IMAGES), limit: train_size);
          var train_labels = extract_labels(_fetch_and_extract(source_url, train_dir, TRAIN_LABELS), one_hot: one_hot, limit: train_size);
          var test_images = extract_images(_fetch_and_extract(source_url, train_dir, TEST_IMAGES), limit: test_size);
          var test_labels = extract_labels(_fetch_and_extract(source_url, train_dir, TEST_LABELS), one_hot: one_hot, limit: test_size);

          int end = train_images.shape[0];

          // The up-front check only sees the requested train_size. The file may hold fewer
          // samples than requested (or train_size may be null), so re-check against what
          // was actually loaded before slicing.
          if (validation_size >= end)
          {
              throw new ArgumentException("Validation set should be smaller than training set");
          }

          var validation_images = train_images[np.arange(validation_size)];
          var validation_labels = train_labels[np.arange(validation_size)];
          train_images = train_images[np.arange(validation_size, end)];
          train_labels = train_labels[np.arange(validation_size, end)];

          var train = new DataSet(train_images, train_labels, dtype, reshape);
          var validation = new DataSet(validation_images, validation_labels, dtype, reshape);
          var test = new DataSet(test_images, test_labels, dtype, reshape);

          return new Datasets(train, validation, test);
      }

      /// <summary>
      /// Reads an extracted MNIST image file: a header of four big-endian 32-bit values
      /// (magic, image count, rows, columns) followed by one byte per pixel.
      /// </summary>
      /// <param name="file">Path of the extracted (non-gzipped) image file.</param>
      /// <param name="limit">Optional cap on the number of images read; null reads all of them.</param>
      /// <returns>A uint8 array reshaped to (images, rows, columns, 1).</returns>
      /// <exception cref="ArgumentOutOfRangeException"><paramref name="limit"/> is negative.</exception>
      /// <exception cref="ValueError">The file does not start with the image magic number.</exception>
      /// <exception cref="EndOfStreamException">The file ends before the announced data was read.</exception>
      public static NDArray extract_images(string file, int? limit = null)
      {
          // A negative limit used to wrap to a huge uint and be silently ignored.
          if (limit < 0)
          {
              throw new ArgumentOutOfRangeException(nameof(limit), "limit must not be negative");
          }

          // Open read-only: the FileStream(path, mode) overload asks for read/write access,
          // which fails on read-only dataset files.
          using (var bytestream = new FileStream(file, FileMode.Open, FileAccess.Read, FileShare.Read))
          {
              var magic = _read32(bytestream);
              if (magic != IMAGE_FILE_MAGIC)
              {
                  throw new ValueError($"Invalid magic number {magic} in MNIST image file: {file}");
              }

              var num_images = _read32(bytestream);
              if (limit != null)
              {
                  num_images = Math.Min(num_images, (uint)limit.Value);
              }

              var rows = _read32(bytestream);
              var cols = _read32(bytestream);

              // checked: a corrupt header must not wrap around into a small, wrong buffer size.
              var buf = new byte[checked(rows * cols * num_images)];
              _read_exact(bytestream, buf);

              var data = np.frombuffer(buf, np.uint8);
              data = data.reshape((int)num_images, (int)rows, (int)cols, 1);
              return data;
          }
      }

      /// <summary>
      /// Reads an extracted MNIST label file: a header of two big-endian 32-bit values
      /// (magic, item count) followed by one byte per label.
      /// </summary>
      /// <param name="file">Path of the extracted (non-gzipped) label file.</param>
      /// <param name="one_hot">When true, the labels are converted to a one-hot matrix.</param>
      /// <param name="num_classes">Number of columns of the one-hot matrix; only used when <paramref name="one_hot"/> is true.</param>
      /// <param name="limit">Optional cap on the number of labels read; null reads all of them.</param>
      /// <returns>The uint8 label vector, or its one-hot encoding when requested.</returns>
      /// <exception cref="ArgumentOutOfRangeException"><paramref name="limit"/> is negative.</exception>
      /// <exception cref="ValueError">The file does not start with the label magic number.</exception>
      /// <exception cref="EndOfStreamException">The file ends before the announced data was read.</exception>
      public static NDArray extract_labels(string file, bool one_hot = false, int num_classes = 10, int? limit = null)
      {
          if (limit < 0)
          {
              throw new ArgumentOutOfRangeException(nameof(limit), "limit must not be negative");
          }

          using (var bytestream = new FileStream(file, FileMode.Open, FileAccess.Read, FileShare.Read))
          {
              var magic = _read32(bytestream);
              if (magic != LABEL_FILE_MAGIC)
              {
                  throw new ValueError($"Invalid magic number {magic} in MNIST label file: {file}");
              }

              var num_items = _read32(bytestream);
              if (limit != null)
              {
                  num_items = Math.Min(num_items, (uint)limit.Value);
              }

              var buf = new byte[num_items];
              _read_exact(bytestream, buf);

              var labels = np.frombuffer(buf, np.uint8);
              return one_hot ? dense_to_one_hot(labels, num_classes) : labels;
          }
      }

      /// <summary>
      /// Downloads one archive, extracts it, and returns the path this class opens afterwards.
      /// </summary>
      /// <remarks>
      /// The returned path is <paramref name="train_dir"/> joined with the archive name up to
      /// its first '.'. NOTE(review): this assumes Compress.ExtractGZip writes the extracted
      /// file under exactly that name — confirm against that helper.
      /// </remarks>
      private static string _fetch_and_extract(string source_url, string train_dir, string archive_name)
      {
          Web.Download(source_url + archive_name, train_dir, archive_name);
          Compress.ExtractGZip(Path.Join(train_dir, archive_name), train_dir);
          return Path.Join(train_dir, archive_name.Split('.')[0]);
      }

      /// <summary>
      /// Builds a (labels, num_classes) matrix of zeros with a 1.0 in each row at the column
      /// given by that row's label.
      /// </summary>
      private static NDArray dense_to_one_hot(NDArray labels_dense, int num_classes)
      {
          var num_labels = labels_dense.shape[0];
          var labels_one_hot = np.zeros(num_labels, num_classes);

          for (int row = 0; row < num_labels; row++)
          {
              var col = labels_dense.Data<byte>(row);
              labels_one_hot.SetData(1.0, row, col);
          }

          return labels_one_hot;
      }

      /// <summary>
      /// Reads the next four bytes of the stream as a big-endian unsigned 32-bit integer.
      /// </summary>
      /// <exception cref="EndOfStreamException">Fewer than four bytes remain.</exception>
      private static uint _read32(FileStream bytestream)
      {
          var buffer = new byte[sizeof(uint)];
          _read_exact(bytestream, buffer);

          // Most significant byte first, matching the ">u4" dtype the headers are stored in.
          return ((uint)buffer[0] << 24)
               | ((uint)buffer[1] << 16)
               | ((uint)buffer[2] << 8)
               | buffer[3];
      }

      /// <summary>
      /// Fills <paramref name="buffer"/> completely from the stream.
      /// </summary>
      /// <remarks>
      /// Stream.Read may return fewer bytes than requested, so a single call is not enough
      /// to guarantee a full buffer.
      /// </remarks>
      /// <exception cref="EndOfStreamException">The stream ends before the buffer is full.</exception>
      private static void _read_exact(FileStream bytestream, byte[] buffer)
      {
          int filled = 0;
          while (filled < buffer.Length)
          {
              int read = bytestream.Read(buffer, filled, buffer.Length - filled);
              if (read == 0)
              {
                  throw new EndOfStreamException(
                      $"Unexpected end of file in {bytestream.Name}: expected {buffer.Length} bytes, got {filled}");
              }

              filled += read;
          }
      }
  }
  104. }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。