You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

MnistDataSet.cs 3.0 kB

4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. using NumSharp;
  2. using System;
  3. using System.Diagnostics;
  4. namespace Tensorflow
  5. {
  6. public class MnistDataSet : DataSetBase
  7. {
  8. public int NumOfExamples { get; private set; }
  9. public int EpochsCompleted { get; private set; }
  10. public int IndexInEpoch { get; private set; }
  11. public MnistDataSet(NDArray images, NDArray labels, Type dataType, bool reshape)
  12. {
  13. EpochsCompleted = 0;
  14. IndexInEpoch = 0;
  15. NumOfExamples = images.shape[0];
  16. images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]);
  17. images = images.astype(dataType);
  18. // for debug np.multiply performance
  19. var sw = new Stopwatch();
  20. sw.Start();
  21. images = np.multiply(images, 1.0f / 255.0f);
  22. sw.Stop();
  23. Binding.tf_output_redirect.WriteLine($"{sw.ElapsedMilliseconds}ms");
  24. Data = images;
  25. labels = labels.astype(dataType);
  26. Labels = labels;
  27. }
  28. public (NDArray, NDArray) GetNextBatch(int batch_size, bool fake_data = false, bool shuffle = true)
  29. {
  30. if (IndexInEpoch >= NumOfExamples)
  31. IndexInEpoch = 0;
  32. var start = IndexInEpoch;
  33. // Shuffle for the first epoch
  34. if (EpochsCompleted == 0 && start == 0 && shuffle)
  35. {
  36. var perm0 = np.arange(NumOfExamples);
  37. np.random.shuffle(perm0);
  38. Data = Data[perm0];
  39. Labels = Labels[perm0];
  40. }
  41. // Go to the next epoch
  42. if (start + batch_size > NumOfExamples)
  43. {
  44. // Finished epoch
  45. EpochsCompleted += 1;
  46. // Get the rest examples in this epoch
  47. var rest_num_examples = NumOfExamples - start;
  48. var images_rest_part = Data[np.arange(start, NumOfExamples)];
  49. var labels_rest_part = Labels[np.arange(start, NumOfExamples)];
  50. // Shuffle the data
  51. if (shuffle)
  52. {
  53. var perm = np.arange(NumOfExamples);
  54. np.random.shuffle(perm);
  55. Data = Data[perm];
  56. Labels = Labels[perm];
  57. }
  58. start = 0;
  59. IndexInEpoch = batch_size - rest_num_examples;
  60. var end = IndexInEpoch;
  61. var images_new_part = Data[np.arange(start, end)];
  62. var labels_new_part = Labels[np.arange(start, end)];
  63. return (np.concatenate(new[] { images_rest_part, images_new_part }, axis: 0),
  64. np.concatenate(new[] { labels_rest_part, labels_new_part }, axis: 0));
  65. }
  66. else
  67. {
  68. IndexInEpoch += batch_size;
  69. var end = IndexInEpoch;
  70. return (Data[np.arange(start, end)], Labels[np.arange(start, end)]);
  71. }
  72. }
  73. }
  74. }