using NumSharp;
using System;
using System.Diagnostics;

namespace Tensorflow
{
    /// <summary>
    /// Holds a set of images and labels and hands them out in batches,
    /// tracking the position inside the current epoch and the number of
    /// epochs completed so far.
    /// </summary>
    /// <remarks>
    /// <c>Data</c> and <c>Labels</c> are not declared in this class; presumably
    /// they are inherited from <see cref="DataSetBase"/> (verify against the base type).
    /// </remarks>
    public class MnistDataSet : DataSetBase
    {
        /// <summary>Number of examples, taken from the first dimension of the images passed to the constructor.</summary>
        public int NumOfExamples { get; private set; }

        /// <summary>Number of times <see cref="GetNextBatch"/> has wrapped around the end of the data.</summary>
        public int EpochsCompleted { get; private set; }

        /// <summary>Index of the first example that the next batch will start from.</summary>
        public int IndexInEpoch { get; private set; }

        /// <summary>
        /// Builds the data set: flattens each image, converts images and labels to
        /// <paramref name="dataType"/> and scales the image values by 1/255.
        /// </summary>
        /// <param name="images">
        /// Image array. Dimensions 1 and 2 are merged into one, so at least three
        /// dimensions are required.
        /// </param>
        /// <param name="labels">Label array; converted to <paramref name="dataType"/> as well.</param>
        /// <param name="dataType">Element type that both images and labels are converted to.</param>
        /// <param name="reshape">
        /// NOTE(review): currently not read; the images are always flattened.
        /// Kept for signature compatibility.
        /// </param>
        public MnistDataSet(NDArray images, NDArray labels, Type dataType, bool reshape)
        {
            EpochsCompleted = 0;
            IndexInEpoch = 0;

            NumOfExamples = images.shape[0];

            // Flatten every image into a single row: (n, d1, d2, ...) -> (n, d1 * d2).
            images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]);
            images = images.astype(dataType);

            // for debug np.multiply performance
            var sw = new Stopwatch();
            sw.Start();
            images = np.multiply(images, 1.0f / 255.0f);
            sw.Stop();
            Binding.tf_output_redirect.WriteLine($"{sw.ElapsedMilliseconds}ms");
            Data = images;

            labels = labels.astype(dataType);
            Labels = labels;
        }

        /// <summary>
        /// Returns the next <paramref name="batch_size"/> examples. When the batch
        /// runs past the end of the data, the remaining examples of the current epoch
        /// are combined with the first examples of the next epoch.
        /// </summary>
        /// <param name="batch_size">
        /// Number of examples to return; must not be negative and must not exceed
        /// <see cref="NumOfExamples"/>.
        /// </param>
        /// <param name="fake_data">NOTE(review): currently not read; kept for signature compatibility.</param>
        /// <param name="shuffle">
        /// When true, the examples are permuted before the very first batch and again
        /// every time an epoch ends.
        /// </param>
        /// <returns>A tuple of (images, labels) for the batch.</returns>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="batch_size"/> is negative or greater than <see cref="NumOfExamples"/>.
        /// </exception>
        public (NDArray, NDArray) GetNextBatch(int batch_size, bool fake_data = false, bool shuffle = true)
        {
            // Validate before touching any state. An oversized batch would otherwise
            // bump EpochsCompleted, reshuffle and move IndexInEpoch beyond the data
            // before the out-of-range index is ever built; a negative one would move
            // IndexInEpoch backwards. Zero is not rejected so that input behaves as before.
            if (batch_size < 0 || batch_size > NumOfExamples)
            {
                throw new ArgumentOutOfRangeException(
                    nameof(batch_size),
                    batch_size,
                    $"batch_size must be between 0 and {NumOfExamples}.");
            }

            if (IndexInEpoch >= NumOfExamples)
            {
                IndexInEpoch = 0;
            }

            var start = IndexInEpoch;

            // Shuffle for the first epoch
            if (EpochsCompleted == 0 && start == 0 && shuffle)
            {
                ShuffleExamples();
            }

            // Same condition as "start + batch_size > NumOfExamples", written so the
            // addition cannot overflow.
            var remaining = NumOfExamples - start;
            if (batch_size > remaining)
            {
                // Finished epoch
                EpochsCompleted += 1;

                // Take what is left of this epoch before the data is reshuffled.
                var tailIndices = np.arange(start, NumOfExamples);
                var imagesTail = Data[tailIndices];
                var labelsTail = Labels[tailIndices];

                if (shuffle)
                {
                    ShuffleExamples();
                }

                // Fill the batch up from the beginning of the next epoch.
                IndexInEpoch = batch_size - remaining;
                var headIndices = np.arange(0, IndexInEpoch);
                var imagesHead = Data[headIndices];
                var labelsHead = Labels[headIndices];

                return (np.concatenate(new[] { imagesTail, imagesHead }, axis: 0),
                        np.concatenate(new[] { labelsTail, labelsHead }, axis: 0));
            }

            IndexInEpoch += batch_size;
            var batchIndices = np.arange(start, IndexInEpoch);
            return (Data[batchIndices], Labels[batchIndices]);
        }

        /// <summary>Applies one random permutation to both <c>Data</c> and <c>Labels</c> so pairs stay aligned.</summary>
        private void ShuffleExamples()
        {
            var perm = np.arange(NumOfExamples);
            np.random.shuffle(perm);
            Data = Data[perm];
            Labels = Labels[perm];
        }
    }
}