Browse Source

tried to implement MnistModelLoader

tags/v0.10
Kerry Jiang 6 years ago
parent
commit
6fe5a6c572
3 changed files with 152 additions and 2 deletions
  1. +141
    -2
      src/TensorFlowHub/MnistModelLoader.cs
  2. +1
    -0
      src/TensorFlowHub/ModelLoadSetting.cs
  3. +10
    -0
      src/TensorFlowHub/Utils.cs

+ 141
- 2
src/TensorFlowHub/MnistModelLoader.cs View File

@@ -2,15 +2,154 @@
using System.Threading.Tasks;
using System.Collections.Generic;
using System.Text;
using System.IO;
using NumSharp;

namespace Tensorflow.Hub
{
public class MnistModelLoader : IModelLoader<MnistDataSet>
{
public Task<Datasets<MnistDataSet>> LoadAsync(ModelLoadSetting setting)
private const string DEFAULT_SOURCE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/";
private const string TRAIN_IMAGES = "train-images-idx3-ubyte.gz";
private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz";
private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz";
private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz";

public async Task<Datasets<MnistDataSet>> LoadAsync(ModelLoadSetting setting)
{
if (setting.TrainSize.HasValue && setting.ValidationSize >= setting.TrainSize.Value)
throw new ArgumentException("Validation set should be smaller than training set");

var sourceUrl = setting.SourceUrl;

if (string.IsNullOrEmpty(sourceUrl))
sourceUrl = DEFAULT_SOURCE_URL;

// load train images
await this.DownloadAsync(sourceUrl + TRAIN_IMAGES, setting.TrainDir, TRAIN_IMAGES)
.ShowProgressInConsole(setting.ShowProgressInConsole);

await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_IMAGES), setting.TrainDir)
.ShowProgressInConsole(setting.ShowProgressInConsole);

var trainImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_IMAGES)), limit: setting.TrainSize);

// load train labels
await this.DownloadAsync(sourceUrl + TRAIN_LABELS, setting.TrainDir, TRAIN_LABELS)
.ShowProgressInConsole(setting.ShowProgressInConsole);

await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_LABELS), setting.TrainDir)
.ShowProgressInConsole(setting.ShowProgressInConsole);

var trainLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_LABELS)), one_hot: setting.OneHot, limit: setting.TrainSize);

// load test images
await this.DownloadAsync(sourceUrl + TEST_IMAGES, setting.TrainDir, TEST_IMAGES)
.ShowProgressInConsole(setting.ShowProgressInConsole);

await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_IMAGES), setting.TrainDir)
.ShowProgressInConsole(setting.ShowProgressInConsole);

var testImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_IMAGES)), limit: setting.TestSize);

// load test labels
await this.DownloadAsync(sourceUrl + TEST_LABELS, setting.TrainDir, TEST_LABELS)
.ShowProgressInConsole(setting.ShowProgressInConsole);

await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_LABELS), setting.TrainDir)
.ShowProgressInConsole(setting.ShowProgressInConsole);

var testLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_LABELS)), one_hot: setting.OneHot, limit: setting.TestSize);

var end = trainImages.shape[0];

var validationSize = setting.ValidationSize;

var validationImages = trainImages[np.arange(validationSize)];
var validationLabels = trainLabels[np.arange(validationSize)];
trainImages = trainImages[np.arange(validationSize, end)];
trainLabels = trainLabels[np.arange(validationSize, end)];

var dtype = setting.DtType;
var reshape = setting.ReShape;

var train = new MnistDataSet(trainImages, trainLabels, dtype, reshape);
var validation = new MnistDataSet(validationImages, validationLabels, dtype, reshape);
var test = new MnistDataSet(trainImages, trainLabels, dtype, reshape);

return new Datasets<MnistDataSet>(train, validation, test);
}

private NDArray ExtractImages(string file, int? limit = null)
{
using (var bytestream = new FileStream(file, FileMode.Open))
{
var magic = Read32(bytestream);
if (magic != 2051)
throw new ValueError($"Invalid magic number {magic} in MNIST image file: {file}");
var num_images = Read32(bytestream);
num_images = limit == null ? num_images : Math.Min(num_images, (uint)limit);

var rows = Read32(bytestream);
var cols = Read32(bytestream);

var buf = new byte[rows * cols * num_images];

bytestream.Read(buf, 0, buf.Length);

var data = np.frombuffer(buf, np.uint8);
data = data.reshape((int)num_images, (int)rows, (int)cols, 1);

return data;
}
}

private NDArray ExtractLabels(string file, bool one_hot = false, int num_classes = 10, int? limit = null)
{
using (var bytestream = new FileStream(file, FileMode.Open))
{
var magic = Read32(bytestream);
if (magic != 2049)
throw new ValueError($"Invalid magic number {magic} in MNIST label file: {file}");
var num_items = Read32(bytestream);
num_items = limit == null ? num_items : Math.Min(num_items, (uint)limit);
var buf = new byte[num_items];

bytestream.Read(buf, 0, buf.Length);
var labels = np.frombuffer(buf, np.uint8);

if (one_hot)
return DenseToOneHot(labels, num_classes);
return labels;
}
}

private NDArray DenseToOneHot(NDArray labels_dense, int num_classes)
{
var num_labels = labels_dense.shape[0];
var index_offset = np.arange(num_labels) * num_classes;
var labels_one_hot = np.zeros(num_labels, num_classes);

for(int row = 0; row < num_labels; row++)
{
var col = labels_dense.Data<byte>(row);
labels_one_hot.SetData(1.0, row, col);
}

return labels_one_hot;
}

private uint Read32(FileStream bytestream)
{
throw new NotImplementedException();
var buffer = new byte[sizeof(uint)];
var count = bytestream.Read(buffer, 0, 4);
return np.frombuffer(buffer, ">u4").Data<uint>(0);
}
}
}

+ 1
- 0
src/TensorFlowHub/ModelLoadSetting.cs View File

@@ -15,5 +15,6 @@ namespace Tensorflow.Hub
public int? TrainSize { get; set; }
public int? TestSize { get; set; }
public string SourceUrl { get; set; }
public bool ShowProgressInConsole { get; set; }
}
}

+ 10
- 0
src/TensorFlowHub/Utils.cs View File

@@ -71,6 +71,16 @@ namespace Tensorflow.Hub

public static async Task ShowProgressInConsole(this Task task)
{
await ShowProgressInConsole(task, true);
}

public static async Task ShowProgressInConsole(this Task task, bool enable)
{
if (!enable)
{
await task;
}

var cts = new CancellationTokenSource();
var showProgressTask = ShowProgressInConsole(cts);


Loading…
Cancel
Save