
started to use MnistModelLoader in Tensorflow.Hub (#330)

tags/v0.12
Kerry Jiang Haiping 6 years ago
commit 85319e0feb
12 changed files with 69 additions and 356 deletions
  1. +4 -0    .gitignore
  2. +20 -8   test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs
  3. +13 -13  test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs
  4. +5 -5    test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs
  5. +8 -8    test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs
  6. +9 -31   test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs
  7. +9 -9    test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs
  8. +1 -0    test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
  9. +0 -95   test/TensorFlowNET.Examples/Utility/DataSetMnist.cs
  10. +0 -46  test/TensorFlowNET.Examples/Utility/Datasets.cs
  11. +0 -10  test/TensorFlowNET.Examples/Utility/IDataSet.cs
  12. +0 -131 test/TensorFlowNET.Examples/Utility/MNIST.cs
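
The changes all follow one pattern: the hand-rolled MNIST utilities under test/TensorFlowNET.Examples/Utility/ are deleted, the examples load data through MnistModelLoader from the new Tensorflow.Hub project instead, member casing moves from lowercase (mnist.train.data) to PascalCase (mnist.Train.Data), and downloaded artifacts move under the newly ignored .resources/ directory. A minimal before/after sketch of the call-site change, assembled from the diffs below (the surrounding example class is elided):

// Before: local Utility helpers, lowercase members
// using TensorFlowNET.Examples.Utility;
// Datasets<DataSetMnist> mnist = MNIST.read_data_sets("mnist", one_hot: true);
// var (batch_xs, batch_ys) = mnist.train.next_batch(100);

// After: Tensorflow.Hub loader, PascalCase members, data under .resources/
using Tensorflow.Hub;

Datasets<MnistDataSet> mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result;
var (batch_xs, batch_ys) = mnist.Train.GetNextBatch(100);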

+ 4 - 0  .gitignore

@@ -332,3 +332,7 @@ src/TensorFlowNET.Native/bazel-*
src/TensorFlowNET.Native/c_api.h
/.vscode
test/TensorFlowNET.Examples/mnist


+ # training model resources
+ .resources

+ 20 - 8  test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs

@@ -18,7 +18,7 @@ using NumSharp;
using System;
using System.Diagnostics;
using Tensorflow;
- using TensorFlowNET.Examples.Utility;
+ using Tensorflow.Hub;
using static Tensorflow.Python;

namespace TensorFlowNET.Examples
@@ -39,7 +39,7 @@ namespace TensorFlowNET.Examples
public int? test_size = null;
public int batch_size = 1024; // The number of samples per batch

- Datasets<DataSetMnist> mnist;
+ Datasets<MnistDataSet> mnist;
NDArray full_data_x;
int num_steps = 20; // Total steps to train
int k = 25; // The number of clusters
@@ -62,19 +62,31 @@ namespace TensorFlowNET.Examples

public void PrepareData()
{
- mnist = MNIST.read_data_sets("mnist", one_hot: true, train_size: train_size, validation_size:validation_size, test_size:test_size);
- full_data_x = mnist.train.data;
+ var loader = new MnistModelLoader();
+
+ var setting = new ModelLoadSetting
+ {
+ TrainDir = ".resources/mnist",
+ OneHot = true,
+ TrainSize = train_size,
+ ValidationSize = validation_size,
+ TestSize = test_size
+ };
+
+ mnist = loader.LoadAsync(setting).Result;
+
+ full_data_x = mnist.Train.Data;

// download graph meta data
string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/kmeans.meta";
- Web.Download(url, "graph", "kmeans.meta");
+ loader.DownloadAsync(url, ".resources/graph", "kmeans.meta").Wait();
}

public Graph ImportGraph()
{
var graph = tf.Graph().as_default();

tf.train.import_meta_graph("graph/kmeans.meta");
tf.train.import_meta_graph(".resources/graph/kmeans.meta");

return graph;
}
@@ -132,7 +144,7 @@ namespace TensorFlowNET.Examples
sw.Start();
foreach (var i in range(idx.Length))
{
- var x = mnist.train.labels[i];
+ var x = mnist.Train.Labels[i];
counts[idx[i]] += x;
}

@@ -153,7 +165,7 @@ namespace TensorFlowNET.Examples
var accuracy_op = tf.reduce_mean(cast);

// Test Model
- var (test_x, test_y) = (mnist.test.data, mnist.test.labels);
+ var (test_x, test_y) = (mnist.Test.Data, mnist.Test.Labels);
result = sess.run(accuracy_op, new FeedItem(X, test_x), new FeedItem(Y, test_y));
accuray_test = result;
print($"Test Accuracy: {accuray_test}");


+ 13 - 13  test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs

@@ -19,7 +19,7 @@ using System;
using System.Diagnostics;
using System.IO;
using Tensorflow;
- using TensorFlowNET.Examples.Utility;
+ using Tensorflow.Hub;
using static Tensorflow.Python;

namespace TensorFlowNET.Examples
@@ -45,7 +45,7 @@ namespace TensorFlowNET.Examples
private float learning_rate = 0.01f;
private int display_step = 1;

- Datasets<DataSetMnist> mnist;
+ Datasets<MnistDataSet> mnist;

public bool Run()
{
@@ -84,11 +84,11 @@ namespace TensorFlowNET.Examples
sw.Start();

var avg_cost = 0.0f;
- var total_batch = mnist.train.num_examples / batch_size;
+ var total_batch = mnist.Train.NumOfExamples / batch_size;
// Loop over all batches
foreach (var i in range(total_batch))
{
- var (batch_xs, batch_ys) = mnist.train.next_batch(batch_size);
+ var (batch_xs, batch_ys) = mnist.Train.GetNextBatch(batch_size);
// Run optimization op (backprop) and cost op (to get loss value)
var result = sess.run(new object[] { optimizer, cost },
new FeedItem(x, batch_xs),
@@ -115,7 +115,7 @@ namespace TensorFlowNET.Examples
var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1));
// Calculate accuracy
var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32));
- float acc = accuracy.eval(new FeedItem(x, mnist.test.data), new FeedItem(y, mnist.test.labels));
+ float acc = accuracy.eval(new FeedItem(x, mnist.Test.Data), new FeedItem(y, mnist.Test.Labels));
print($"Accuracy: {acc.ToString("F4")}");

return acc > 0.9;
@@ -124,23 +124,23 @@ namespace TensorFlowNET.Examples

public void PrepareData()
{
- mnist = MNIST.read_data_sets("mnist", one_hot: true, train_size: train_size, validation_size: validation_size, test_size: test_size);
+ mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: train_size, validationSize: validation_size, testSize: test_size).Result;
}

public void SaveModel(Session sess)
{
var saver = tf.train.Saver();
var save_path = saver.save(sess, "logistic_regression/model.ckpt");
tf.train.write_graph(sess.graph, "logistic_regression", "model.pbtxt", as_text: true);
var save_path = saver.save(sess, ".resources/logistic_regression/model.ckpt");
tf.train.write_graph(sess.graph, ".resources/logistic_regression", "model.pbtxt", as_text: true);

FreezeGraph.freeze_graph(input_graph: "logistic_regression/model.pbtxt",
FreezeGraph.freeze_graph(input_graph: ".resources/logistic_regression/model.pbtxt",
input_saver: "",
input_binary: false,
input_checkpoint: "logistic_regression/model.ckpt",
input_checkpoint: ".resources/logistic_regression/model.ckpt",
output_node_names: "Softmax",
restore_op_name: "save/restore_all",
filename_tensor_name: "save/Const:0",
output_graph: "logistic_regression/model.pb",
output_graph: ".resources/logistic_regression/model.pb",
clear_devices: true,
initializer_nodes: "");
}
@@ -148,7 +148,7 @@ namespace TensorFlowNET.Examples
public void Predict(Session sess)
{
var graph = new Graph().as_default();
graph.Import(Path.Join("logistic_regression", "model.pb"));
graph.Import(Path.Join(".resources/logistic_regression", "model.pb"));

// restoring the model
// var saver = tf.train.import_meta_graph("logistic_regression/tensorflowModel.ckpt.meta");
@@ -159,7 +159,7 @@ namespace TensorFlowNET.Examples
var input = x.outputs[0];

// predict
- var (batch_xs, batch_ys) = mnist.train.next_batch(10);
+ var (batch_xs, batch_ys) = mnist.Train.GetNextBatch(10);
var results = sess.run(output, new FeedItem(input, batch_xs[np.arange(1)]));

if (results.argmax() == (batch_ys[0] as NDArray).argmax())
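
For orientation, here is the save/freeze/reload round trip this file performs, condensed into one sequence. Every call and path appears in the diff above; only the step comments are added, and an active Session sess plus the file's existing usings are assumed:

// 1. Save a checkpoint and a text GraphDef under the new .resources layout.
var saver = tf.train.Saver();
var save_path = saver.save(sess, ".resources/logistic_regression/model.ckpt");
tf.train.write_graph(sess.graph, ".resources/logistic_regression", "model.pbtxt", as_text: true);

// 2. Freeze: bake the checkpoint variables into constants, producing model.pb.
FreezeGraph.freeze_graph(input_graph: ".resources/logistic_regression/model.pbtxt",
    input_saver: "",
    input_binary: false,
    input_checkpoint: ".resources/logistic_regression/model.ckpt",
    output_node_names: "Softmax",
    restore_op_name: "save/restore_all",
    filename_tensor_name: "save/Const:0",
    output_graph: ".resources/logistic_regression/model.pb",
    clear_devices: true,
    initializer_nodes: "");

// 3. Reload the frozen graph for prediction.
var graph = new Graph().as_default();
graph.Import(Path.Join(".resources/logistic_regression", "model.pb"));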


+ 5 - 5  test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs

@@ -17,7 +17,7 @@
using NumSharp;
using System;
using Tensorflow;
- using TensorFlowNET.Examples.Utility;
+ using Tensorflow.Hub;
using static Tensorflow.Python;

namespace TensorFlowNET.Examples
@@ -31,7 +31,7 @@ namespace TensorFlowNET.Examples
{
public bool Enabled { get; set; } = true;
public string Name => "Nearest Neighbor";
- Datasets<DataSetMnist> mnist;
+ Datasets<MnistDataSet> mnist;
NDArray Xtr, Ytr, Xte, Yte;
public int? TrainSize = null;
public int ValidationSize = 5000;
@@ -84,10 +84,10 @@ namespace TensorFlowNET.Examples

public void PrepareData()
{
- mnist = MNIST.read_data_sets("mnist", one_hot: true, train_size: TrainSize, validation_size:ValidationSize, test_size:TestSize);
+ mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: TrainSize, validationSize: ValidationSize, testSize: TestSize).Result;
// In this example, we limit mnist data
- (Xtr, Ytr) = mnist.train.next_batch(TrainSize==null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates)
- (Xte, Yte) = mnist.test.next_batch(TestSize==null ? 200 : TestSize.Value / 100); // 200 for testing
+ (Xtr, Ytr) = mnist.Train.GetNextBatch(TrainSize == null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates)
+ (Xte, Yte) = mnist.Test.GetNextBatch(TestSize == null ? 200 : TestSize.Value / 100); // 200 for testing
}

public Graph ImportGraph()
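
Note how the two batch sizes scale: with the default null limits the example keeps 5000 training candidates and 200 test points, while an explicitly supplied limit is divided by 100. An illustrative trace (the sample values are hypothetical):

int? TrainSize = 20000, TestSize = 2000;
var nTrain = TrainSize == null ? 5000 : TrainSize.Value / 100; // 200 candidates
var nTest = TestSize == null ? 200 : TestSize.Value / 100;     // 20 test points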


+ 8 - 8  test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs

@@ -18,7 +18,7 @@ using NumSharp;
using System;
using System.Diagnostics;
using Tensorflow;
- using TensorFlowNET.Examples.Utility;
+ using Tensorflow.Hub;
using static Tensorflow.Python;

namespace TensorFlowNET.Examples.ImageProcess
@@ -46,7 +46,7 @@ namespace TensorFlowNET.Examples.ImageProcess
int epochs = 5; // accuracy > 98%
int batch_size = 100;
float learning_rate = 0.001f;
- Datasets<DataSetMnist> mnist;
+ Datasets<MnistDataSet> mnist;

// Network configuration
// 1st Convolutional Layer
@@ -310,14 +310,14 @@ namespace TensorFlowNET.Examples.ImageProcess
public void PrepareData()
{
- mnist = MNIST.read_data_sets("mnist", one_hot: true);
- (x_train, y_train) = Reformat(mnist.train.data, mnist.train.labels);
- (x_valid, y_valid) = Reformat(mnist.validation.data, mnist.validation.labels);
- (x_test, y_test) = Reformat(mnist.test.data, mnist.test.labels);
+ mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result;
+ (x_train, y_train) = Reformat(mnist.Train.Data, mnist.Train.Labels);
+ (x_valid, y_valid) = Reformat(mnist.Validation.Data, mnist.Validation.Labels);
+ (x_test, y_test) = Reformat(mnist.Test.Data, mnist.Test.Labels);

print("Size of:");
print($"- Training-set:\t\t{len(mnist.train.data)}");
print($"- Validation-set:\t{len(mnist.validation.data)}");
print($"- Training-set:\t\t{len(mnist.Train.Data)}");
print($"- Validation-set:\t{len(mnist.Validation.Data)}");
}

/// <summary>


+ 9 - 31  test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs

@@ -17,7 +17,7 @@
using NumSharp;
using System;
using Tensorflow;
- using TensorFlowNET.Examples.Utility;
+ using Tensorflow.Hub;
using static Tensorflow.Python;

namespace TensorFlowNET.Examples.ImageProcess
@@ -44,7 +44,7 @@ namespace TensorFlowNET.Examples.ImageProcess
int batch_size = 100;
float learning_rate = 0.001f;
int h1 = 200; // number of nodes in the 1st hidden layer
- Datasets<DataSetMnist> mnist;
+ Datasets<MnistDataSet> mnist;

Tensor x, y;
Tensor loss, accuracy;
@@ -121,13 +121,13 @@ namespace TensorFlowNET.Examples.ImageProcess
public void PrepareData()
{
- mnist = MNIST.read_data_sets("mnist", one_hot: true);
+ mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result;
}

public void Train(Session sess)
{
// Number of training iterations in each epoch
- var num_tr_iter = mnist.train.labels.len / batch_size;
+ var num_tr_iter = mnist.Train.Labels.len / batch_size;

var init = tf.global_variables_initializer();
sess.run(init);
@@ -139,13 +139,13 @@ namespace TensorFlowNET.Examples.ImageProcess
{
print($"Training epoch: {epoch + 1}");
// Randomly shuffle the training data at the beginning of each epoch
- var (x_train, y_train) = randomize(mnist.train.data, mnist.train.labels);
+ var (x_train, y_train) = mnist.Randomize(mnist.Train.Data, mnist.Train.Labels);

foreach (var iteration in range(num_tr_iter))
{
var start = iteration * batch_size;
var end = (iteration + 1) * batch_size;
- var (x_batch, y_batch) = get_next_batch(x_train, y_train, start, end);
+ var (x_batch, y_batch) = mnist.GetNextBatch(x_train, y_train, start, end);

// Run optimization op (backprop)
sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch));
@@ -161,7 +161,8 @@ namespace TensorFlowNET.Examples.ImageProcess
}

// Run validation after every epoch
- var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.validation.data), new FeedItem(y, mnist.validation.labels));
+ var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.Validation.Data), new FeedItem(y, mnist.Validation.Labels));
loss_val = results1[0];
accuracy_val = results1[1];
print("---------------------------------------------------------");
@@ -172,35 +173,12 @@ namespace TensorFlowNET.Examples.ImageProcess

public void Test(Session sess)
{
- var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.test.data), new FeedItem(y, mnist.test.labels));
+ var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.Test.Data), new FeedItem(y, mnist.Test.Labels));
loss_test = result[0];
accuracy_test = result[1];
print("---------------------------------------------------------");
print($"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}");
print("---------------------------------------------------------");
}

- private (NDArray, NDArray) randomize(NDArray x, NDArray y)
- {
- var perm = np.random.permutation(y.shape[0]);
-
- np.random.shuffle(perm);
- return (mnist.train.data[perm], mnist.train.labels[perm]);
- }
-
- /// <summary>
- /// selects a few number of images determined by the batch_size variable (if you don't know why, read about Stochastic Gradient Method)
- /// </summary>
- /// <param name="x"></param>
- /// <param name="y"></param>
- /// <param name="start"></param>
- /// <param name="end"></param>
- /// <returns></returns>
- private (NDArray, NDArray) get_next_batch(NDArray x, NDArray y, int start, int end)
- {
- var x_batch = x[$"{start}:{end}"];
- var y_batch = y[$"{start}:{end}"];
- return (x_batch, y_batch);
- }
}
}
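
One detail worth flagging in this file: the deleted private randomize ignored its x and y parameters and always permuted mnist.train.data, so the removal is also a latent bug fix. Compare the two implementations side by side (the second is the shared helper as it appears in the deleted Utility/Datasets.cs further down; the Tensorflow.Hub version is assumed equivalent):

// Deleted private helper: the body indexes the training set, not x and y.
private (NDArray, NDArray) randomize(NDArray x, NDArray y)
{
    var perm = np.random.permutation(y.shape[0]);
    np.random.shuffle(perm);
    return (mnist.train.data[perm], mnist.train.labels[perm]); // x, y unused
}

// Shared helper: permutes whatever inputs it is given.
public (NDArray, NDArray) Randomize(NDArray x, NDArray y)
{
    var perm = np.random.permutation(y.shape[0]);
    np.random.shuffle(perm);
    return (x[perm], y[perm]);
}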

+ 9 - 9  test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs

@@ -17,7 +17,7 @@
using NumSharp;
using System;
using Tensorflow;
- using TensorFlowNET.Examples.Utility;
+ using Tensorflow.Hub;
using static Tensorflow.Python;

namespace TensorFlowNET.Examples.ImageProcess
@@ -45,7 +45,7 @@ namespace TensorFlowNET.Examples.ImageProcess
int n_inputs = 28;
int n_outputs = 10;

- Datasets<DataSetMnist> mnist;
+ Datasets<MnistDataSet> mnist;

Tensor x, y;
Tensor loss, accuracy, cls_prediction;
@@ -143,15 +143,15 @@ namespace TensorFlowNET.Examples.ImageProcess

public void PrepareData()
{
- mnist = MNIST.read_data_sets("mnist", one_hot: true);
- (x_train, y_train) = (mnist.train.data, mnist.train.labels);
- (x_valid, y_valid) = (mnist.validation.data, mnist.validation.labels);
- (x_test, y_test) = (mnist.test.data, mnist.test.labels);
+ mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result;
+ (x_train, y_train) = (mnist.Train.Data, mnist.Train.Labels);
+ (x_valid, y_valid) = (mnist.Validation.Data, mnist.Validation.Labels);
+ (x_test, y_test) = (mnist.Test.Data, mnist.Test.Labels);

print("Size of:");
print($"- Training-set:\t\t{len(mnist.train.data)}");
print($"- Validation-set:\t{len(mnist.validation.data)}");
print($"- Test-set:\t\t{len(mnist.test.data)}");
print($"- Training-set:\t\t{len(mnist.Train.Data)}");
print($"- Validation-set:\t{len(mnist.Validation.Data)}");
print($"- Test-set:\t\t{len(mnist.Test.Data)}");
}

public Graph ImportGraph() => throw new NotImplementedException();


+ 1 - 0  test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj

@@ -18,5 +18,6 @@
<ProjectReference Include="..\..\src\KerasNET.Core\Keras.Core.csproj" />
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" />
<ProjectReference Include="..\..\src\TensorFlowText\TensorFlowText.csproj" />
<ProjectReference Include="..\..\src\TensorFlowHub\TensorFlowHub.csproj" />
</ItemGroup>
</Project>

+ 0 - 95  test/TensorFlowNET.Examples/Utility/DataSetMnist.cs (file deleted)

@@ -1,95 +0,0 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using NumSharp;
using Tensorflow;

namespace TensorFlowNET.Examples.Utility
{
public class DataSetMnist : IDataSet
{
public int num_examples { get; }

public int epochs_completed { get; private set; }
public int index_in_epoch { get; private set; }
public NDArray data { get; private set; }
public NDArray labels { get; private set; }

public DataSetMnist(NDArray images, NDArray labels, TF_DataType dtype, bool reshape)
{
num_examples = images.shape[0];
images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]);
images.astype(dtype.as_numpy_datatype());
images = np.multiply(images, 1.0f / 255.0f);

labels.astype(dtype.as_numpy_datatype());

data = images;
this.labels = labels;
epochs_completed = 0;
index_in_epoch = 0;
}

public (NDArray, NDArray) next_batch(int batch_size, bool fake_data = false, bool shuffle = true)
{
var start = index_in_epoch;
// Shuffle for the first epoch
if(epochs_completed == 0 && start == 0 && shuffle)
{
var perm0 = np.arange(num_examples);
np.random.shuffle(perm0);
data = data[perm0];
labels = labels[perm0];
}

// Go to the next epoch
if (start + batch_size > num_examples)
{
// Finished epoch
epochs_completed += 1;

// Get the rest examples in this epoch
var rest_num_examples = num_examples - start;
//var images_rest_part = _images[np.arange(start, _num_examples)];
//var labels_rest_part = _labels[np.arange(start, _num_examples)];
// Shuffle the data
if (shuffle)
{
var perm = np.arange(num_examples);
np.random.shuffle(perm);
data = data[perm];
labels = labels[perm];
}

start = 0;
index_in_epoch = batch_size - rest_num_examples;
var end = index_in_epoch;
var images_new_part = data[np.arange(start, end)];
var labels_new_part = labels[np.arange(start, end)];

/*return (np.concatenate(new float[][] { images_rest_part.Data<float>(), images_new_part.Data<float>() }, axis: 0),
np.concatenate(new float[][] { labels_rest_part.Data<float>(), labels_new_part.Data<float>() }, axis: 0));*/
return (images_new_part, labels_new_part);
}
else
{
index_in_epoch += batch_size;
var end = index_in_epoch;
return (data[np.arange(start, end)], labels[np.arange(start, end)]);
}
}
}
}
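
The wrap-around branch in next_batch above is easiest to read with concrete numbers; a short trace with hypothetical values:

int num_examples = 55000, start = 54950, batch_size = 100;
int rest_num_examples = num_examples - start;        // 50 examples left in the old epoch
int index_in_epoch = batch_size - rest_num_examples; // the new epoch resumes at 50
// Because the concatenation is commented out, the batch returned at the epoch
// boundary holds only these 50 new-epoch examples rather than a full 100.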

+ 0 - 46  test/TensorFlowNET.Examples/Utility/Datasets.cs (file deleted)

@@ -1,46 +0,0 @@
using NumSharp;

namespace TensorFlowNET.Examples.Utility
{
public class Datasets<T> where T : IDataSet
{
private T _train;
public T train => _train;

private T _validation;
public T validation => _validation;

private T _test;
public T test => _test;

public Datasets(T train, T validation, T test)
{
_train = train;
_validation = validation;
_test = test;
}

public (NDArray, NDArray) Randomize(NDArray x, NDArray y)
{
var perm = np.random.permutation(y.shape[0]);

np.random.shuffle(perm);
return (x[perm], y[perm]);
}

/// <summary>
/// selects a few number of images determined by the batch_size variable (if you don't know why, read about Stochastic Gradient Method)
/// </summary>
/// <param name="x"></param>
/// <param name="y"></param>
/// <param name="start"></param>
/// <param name="end"></param>
/// <returns></returns>
public (NDArray, NDArray) GetNextBatch(NDArray x, NDArray y, int start, int end)
{
var x_batch = x[$"{start}:{end}"];
var y_batch = y[$"{start}:{end}"];
return (x_batch, y_batch);
}
}
}

+ 0 - 10  test/TensorFlowNET.Examples/Utility/IDataSet.cs (file deleted)

@@ -1,10 +0,0 @@
using NumSharp;

namespace TensorFlowNET.Examples.Utility
{
public interface IDataSet
{
NDArray data { get; }
NDArray labels { get; }
}
}

+ 0 - 131  test/TensorFlowNET.Examples/Utility/MNIST.cs (file deleted)

@@ -1,131 +0,0 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using NumSharp;
using System;
using System.IO;
using Tensorflow;

namespace TensorFlowNET.Examples.Utility
{
public class MNIST
{
private const string DEFAULT_SOURCE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/";
private const string TRAIN_IMAGES = "train-images-idx3-ubyte.gz";
private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz";
private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz";
private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz";
public static Datasets<DataSetMnist> read_data_sets(string train_dir,
bool one_hot = false,
TF_DataType dtype = TF_DataType.TF_FLOAT,
bool reshape = true,
int validation_size = 5000,
int? train_size = null,
int? test_size = null,
string source_url = DEFAULT_SOURCE_URL)
{
if (train_size!=null && validation_size >= train_size)
throw new ArgumentException("Validation set should be smaller than training set");

Web.Download(source_url + TRAIN_IMAGES, train_dir, TRAIN_IMAGES);
Compress.ExtractGZip(Path.Join(train_dir, TRAIN_IMAGES), train_dir);
var train_images = extract_images(Path.Join(train_dir, TRAIN_IMAGES.Split('.')[0]), limit: train_size);

Web.Download(source_url + TRAIN_LABELS, train_dir, TRAIN_LABELS);
Compress.ExtractGZip(Path.Join(train_dir, TRAIN_LABELS), train_dir);
var train_labels = extract_labels(Path.Join(train_dir, TRAIN_LABELS.Split('.')[0]), one_hot: one_hot, limit: train_size);

Web.Download(source_url + TEST_IMAGES, train_dir, TEST_IMAGES);
Compress.ExtractGZip(Path.Join(train_dir, TEST_IMAGES), train_dir);
var test_images = extract_images(Path.Join(train_dir, TEST_IMAGES.Split('.')[0]), limit: test_size);

Web.Download(source_url + TEST_LABELS, train_dir, TEST_LABELS);
Compress.ExtractGZip(Path.Join(train_dir, TEST_LABELS), train_dir);
var test_labels = extract_labels(Path.Join(train_dir, TEST_LABELS.Split('.')[0]), one_hot: one_hot, limit:test_size);

int end = train_images.shape[0];
var validation_images = train_images[np.arange(validation_size)];
var validation_labels = train_labels[np.arange(validation_size)];
train_images = train_images[np.arange(validation_size, end)];
train_labels = train_labels[np.arange(validation_size, end)];

var train = new DataSetMnist(train_images, train_labels, dtype, reshape);
var validation = new DataSetMnist(validation_images, validation_labels, dtype, reshape);
var test = new DataSetMnist(test_images, test_labels, dtype, reshape);

return new Datasets<DataSetMnist>(train, validation, test);
}

public static NDArray extract_images(string file, int? limit=null)
{
using (var bytestream = new FileStream(file, FileMode.Open))
{
var magic = _read32(bytestream);
if (magic != 2051)
throw new ValueError($"Invalid magic number {magic} in MNIST image file: {file}");
var num_images = _read32(bytestream);
num_images = limit == null ? num_images : Math.Min(num_images, (uint)limit);
var rows = _read32(bytestream);
var cols = _read32(bytestream);
var buf = new byte[rows * cols * num_images];
bytestream.Read(buf, 0, buf.Length);
var data = np.frombuffer(buf, np.uint8);
data = data.reshape((int)num_images, (int)rows, (int)cols, 1);
return data;
}
}

public static NDArray extract_labels(string file, bool one_hot = false, int num_classes = 10, int? limit = null)
{
using (var bytestream = new FileStream(file, FileMode.Open))
{
var magic = _read32(bytestream);
if (magic != 2049)
throw new ValueError($"Invalid magic number {magic} in MNIST label file: {file}");
var num_items = _read32(bytestream);
num_items = limit == null ? num_items : Math.Min(num_items,(uint) limit);
var buf = new byte[num_items];
bytestream.Read(buf, 0, buf.Length);
var labels = np.frombuffer(buf, np.uint8);
if (one_hot)
return dense_to_one_hot(labels, num_classes);
return labels;
}
}

private static NDArray dense_to_one_hot(NDArray labels_dense, int num_classes)
{
var num_labels = labels_dense.shape[0];
var index_offset = np.arange(num_labels) * num_classes;
var labels_one_hot = np.zeros(num_labels, num_classes);

for(int row = 0; row < num_labels; row++)
{
var col = labels_dense.Data<byte>(row);
labels_one_hot.SetData(1.0, row, col);
}

return labels_one_hot;
}

private static uint _read32(FileStream bytestream)
{
var buffer = new byte[sizeof(uint)];
var count = bytestream.Read(buffer, 0, 4);
return np.frombuffer(buffer, ">u4").Data<uint>(0);
}
}
}
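
The deleted reader doubles as documentation of the IDX file layout: a big-endian 32-bit magic number (2051 for image files, 2049 for label files), then the item count and, for images, the row and column dimensions, followed by raw pixel or label bytes; _read32 delegated the byte swap to np.frombuffer with the ">u4" dtype. A dependency-free sketch of the same header parse, assuming a locally extracted train-images-idx3-ubyte file:

using System;
using System.IO;

static uint ReadBigEndianUInt32(Stream s)
{
    var buf = new byte[4];
    if (s.Read(buf, 0, 4) != 4) throw new EndOfStreamException();
    return ((uint)buf[0] << 24) | ((uint)buf[1] << 16) | ((uint)buf[2] << 8) | buf[3];
}

using (var fs = File.OpenRead("train-images-idx3-ubyte"))
{
    var magic = ReadBigEndianUInt32(fs);
    if (magic != 2051)
        throw new InvalidDataException($"Invalid magic number {magic} in MNIST image file");
    var numImages = ReadBigEndianUInt32(fs);
    var rows = ReadBigEndianUInt32(fs);
    var cols = ReadBigEndianUInt32(fs);
    Console.WriteLine($"{numImages} images of {rows}x{cols}, pixel bytes follow");
}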
