From 6fe5a6c5723c8d7ac6adbeaee1dc3247eb913d3e Mon Sep 17 00:00:00 2001
From: Kerry Jiang
Date: Sun, 21 Jul 2019 19:25:41 -0700
Subject: [PATCH] tried to implement MnistModelLoader

---
Review notes (this section is not part of the commit message):
- the test data set is now built from the test arrays instead of the training arrays;
- ShowProgressInConsole(task, false) returns right after awaiting the task;
- file reads loop until the buffer is full (ReadFully) and files are opened read-only;
- generic type arguments that were missing from the patch text were restored from the
  surrounding code - confirm them against the tree before applying;
- blob ids on the "index" lines are those of the original commit.

 src/TensorFlowHub/MnistModelLoader.cs | 187 +++++++++++++++++++++++++-
 src/TensorFlowHub/ModelLoadSetting.cs |   1 +
 src/TensorFlowHub/Utils.cs            |  15 ++
 3 files changed, 201 insertions(+), 2 deletions(-)

diff --git a/src/TensorFlowHub/MnistModelLoader.cs b/src/TensorFlowHub/MnistModelLoader.cs
index 9e612a11..72a4dcbb 100644
--- a/src/TensorFlowHub/MnistModelLoader.cs
+++ b/src/TensorFlowHub/MnistModelLoader.cs
@@ -2,15 +2,198 @@
 using System.Threading.Tasks;
 using System.Collections.Generic;
 using System.Text;
+using System.IO;
 using NumSharp;
 
 namespace Tensorflow.Hub
 {
     public class MnistModelLoader : IModelLoader<MnistDataSet>
     {
-        public Task<Datasets<MnistDataSet>> LoadAsync(ModelLoadSetting setting)
+        private const string DEFAULT_SOURCE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/";
+        private const string TRAIN_IMAGES = "train-images-idx3-ubyte.gz";
+        private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz";
+        private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz";
+        private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz";
+
+        /// <summary>
+        /// Runs DownloadAsync and UnzipAsync for each of the four MNIST archives, using
+        /// <c>setting.TrainDir</c> as the target directory, then builds the data sets from them.
+        /// </summary>
+        /// <param name="setting">Source URL, target directory, size limits and output options.</param>
+        /// <returns>The train, validation and test data sets.</returns>
+        /// <exception cref="ArgumentException">ValidationSize is not smaller than TrainSize.</exception>
+        public async Task<Datasets<MnistDataSet>> LoadAsync(ModelLoadSetting setting)
+        {
+            if (setting.TrainSize.HasValue && setting.ValidationSize >= setting.TrainSize.Value)
+                throw new ArgumentException("Validation set should be smaller than training set");
+
+            var sourceUrl = setting.SourceUrl;
+
+            if (string.IsNullOrEmpty(sourceUrl))
+                sourceUrl = DEFAULT_SOURCE_URL;
+
+            // load train images
+            await this.DownloadAsync(sourceUrl + TRAIN_IMAGES, setting.TrainDir, TRAIN_IMAGES)
+                .ShowProgressInConsole(setting.ShowProgressInConsole);
+
+            await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_IMAGES), setting.TrainDir)
+                .ShowProgressInConsole(setting.ShowProgressInConsole);
+
+            var trainImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_IMAGES)), limit: setting.TrainSize);
+
+            // load train labels
+            await this.DownloadAsync(sourceUrl + TRAIN_LABELS, setting.TrainDir, TRAIN_LABELS)
+                .ShowProgressInConsole(setting.ShowProgressInConsole);
+
+            await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_LABELS), setting.TrainDir)
+                .ShowProgressInConsole(setting.ShowProgressInConsole);
+
+            var trainLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_LABELS)), one_hot: setting.OneHot, limit: setting.TrainSize);
+
+            // load test images
+            await this.DownloadAsync(sourceUrl + TEST_IMAGES, setting.TrainDir, TEST_IMAGES)
+                .ShowProgressInConsole(setting.ShowProgressInConsole);
+
+            await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_IMAGES), setting.TrainDir)
+                .ShowProgressInConsole(setting.ShowProgressInConsole);
+
+            var testImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_IMAGES)), limit: setting.TestSize);
+
+            // load test labels
+            await this.DownloadAsync(sourceUrl + TEST_LABELS, setting.TrainDir, TEST_LABELS)
+                .ShowProgressInConsole(setting.ShowProgressInConsole);
+
+            await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_LABELS), setting.TrainDir)
+                .ShowProgressInConsole(setting.ShowProgressInConsole);
+
+            var testLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_LABELS)), one_hot: setting.OneHot, limit: setting.TestSize);
+
+            var end = trainImages.shape[0];
+
+            var validationSize = setting.ValidationSize;
+
+            // the first validationSize training samples are held out for validation
+            var validationImages = trainImages[np.arange(validationSize)];
+            var validationLabels = trainLabels[np.arange(validationSize)];
+
+            trainImages = trainImages[np.arange(validationSize, end)];
+            trainLabels = trainLabels[np.arange(validationSize, end)];
+
+            var dtype = setting.DtType;
+            var reshape = setting.ReShape;
+
+            var train = new MnistDataSet(trainImages, trainLabels, dtype, reshape);
+            var validation = new MnistDataSet(validationImages, validationLabels, dtype, reshape);
+            var test = new MnistDataSet(testImages, testLabels, dtype, reshape);
+
+            return new Datasets<MnistDataSet>(train, validation, test);
+        }
+
+        /// <summary>
+        /// Reads an unpacked MNIST image file (magic number 2051) into a uint8 array of shape
+        /// (count, rows, cols, 1).
+        /// </summary>
+        /// <param name="file">Path of the unpacked image file.</param>
+        /// <param name="limit">Maximum number of images to read; null reads all of them.</param>
+        private NDArray ExtractImages(string file, int? limit = null)
+        {
+            using (var bytestream = new FileStream(file, FileMode.Open, FileAccess.Read))
+            {
+                var magic = Read32(bytestream);
+                if (magic != 2051)
+                    throw new ValueError($"Invalid magic number {magic} in MNIST image file: {file}");
+
+                var num_images = Read32(bytestream);
+                num_images = limit == null ? num_images : Math.Min(num_images, (uint)limit);
+
+                var rows = Read32(bytestream);
+                var cols = Read32(bytestream);
+
+                var buf = new byte[rows * cols * num_images];
+
+                ReadFully(bytestream, buf);
+
+                var data = np.frombuffer(buf, np.uint8);
+                data = data.reshape((int)num_images, (int)rows, (int)cols, 1);
+
+                return data;
+            }
+        }
+
+        /// <summary>
+        /// Reads an unpacked MNIST label file (magic number 2049) as uint8 class indices,
+        /// or as one-hot rows when <paramref name="one_hot"/> is true.
+        /// </summary>
+        /// <param name="file">Path of the unpacked label file.</param>
+        /// <param name="one_hot">Whether to one-hot encode the labels.</param>
+        /// <param name="num_classes">Width of a one-hot row.</param>
+        /// <param name="limit">Maximum number of labels to read; null reads all of them.</param>
+        private NDArray ExtractLabels(string file, bool one_hot = false, int num_classes = 10, int? limit = null)
+        {
+            using (var bytestream = new FileStream(file, FileMode.Open, FileAccess.Read))
+            {
+                var magic = Read32(bytestream);
+                if (magic != 2049)
+                    throw new ValueError($"Invalid magic number {magic} in MNIST label file: {file}");
+
+                var num_items = Read32(bytestream);
+                num_items = limit == null ? num_items : Math.Min(num_items, (uint)limit);
+
+                var buf = new byte[num_items];
+
+                ReadFully(bytestream, buf);
+
+                var labels = np.frombuffer(buf, np.uint8);
+
+                if (one_hot)
+                    return DenseToOneHot(labels, num_classes);
+
+                return labels;
+            }
+        }
+
+        /// <summary>
+        /// Converts class indices into a (num_labels, num_classes) array holding 1.0 at each
+        /// row's class column and zeros elsewhere.
+        /// </summary>
+        private NDArray DenseToOneHot(NDArray labels_dense, int num_classes)
+        {
+            var num_labels = labels_dense.shape[0];
+            var labels_one_hot = np.zeros(num_labels, num_classes);
+
+            for(int row = 0; row < num_labels; row++)
+            {
+                var col = labels_dense.Data<byte>(row);
+                labels_one_hot.SetData(1.0, row, col);
+            }
+
+            return labels_one_hot;
+        }
+
+        /// <summary>
+        /// Fills <paramref name="buffer"/> completely from <paramref name="stream"/>. A single
+        /// Stream.Read call may return fewer bytes than requested, so it is repeated until done.
+        /// </summary>
+        /// <exception cref="EndOfStreamException">The stream ended before the buffer was full.</exception>
+        private static void ReadFully(Stream stream, byte[] buffer)
+        {
+            var offset = 0;
+            while (offset < buffer.Length)
+            {
+                var read = stream.Read(buffer, offset, buffer.Length - offset);
+                if (read <= 0)
+                    throw new EndOfStreamException($"Unexpected end of file: expected {buffer.Length} bytes but only {offset} were available");
+
+                offset += read;
+            }
+        }
+
+        /// <summary>Reads one big-endian unsigned 32-bit integer from the stream.</summary>
+        private uint Read32(FileStream bytestream)
         {
-            throw new NotImplementedException();
+            var buffer = new byte[sizeof(uint)];
+            ReadFully(bytestream, buffer);
+            return np.frombuffer(buffer, ">u4").Data<uint>(0);
         }
     }
 }
diff --git a/src/TensorFlowHub/ModelLoadSetting.cs b/src/TensorFlowHub/ModelLoadSetting.cs
index 95bbaa48..91b4059c 100644
--- a/src/TensorFlowHub/ModelLoadSetting.cs
+++ b/src/TensorFlowHub/ModelLoadSetting.cs
@@ -15,5 +15,6 @@ namespace Tensorflow.Hub
         public int? TrainSize { get; set; }
         public int? TestSize { get; set; }
         public string SourceUrl { get; set; }
+        public bool ShowProgressInConsole { get; set; }
     }
 }
diff --git a/src/TensorFlowHub/Utils.cs b/src/TensorFlowHub/Utils.cs
index 56251035..9f71c61d 100644
--- a/src/TensorFlowHub/Utils.cs
+++ b/src/TensorFlowHub/Utils.cs
@@ -71,6 +71,21 @@ namespace Tensorflow.Hub
 
         public static async Task ShowProgressInConsole(this Task task)
         {
+            await ShowProgressInConsole(task, true);
+        }
+
+        /// <summary>
+        /// Awaits <paramref name="task"/>; the console progress display is only started when
+        /// <paramref name="enable"/> is true.
+        /// </summary>
+        public static async Task ShowProgressInConsole(this Task task, bool enable)
+        {
+            if (!enable)
+            {
+                await task;
+                return;
+            }
+
             var cts = new CancellationTokenSource();
             var showProgressTask = ShowProgressInConsole(cts);
 