From 85319e0febc78dc71abf8b2cc17c6ce771ce4311 Mon Sep 17 00:00:00 2001 From: Kerry Jiang Date: Wed, 31 Jul 2019 04:58:56 -0700 Subject: [PATCH] started to use MnistModelLoader in Tensorflow.Hub (#330) --- .gitignore | 4 + .../BasicModels/KMeansClustering.cs | 28 ++-- .../BasicModels/LogisticRegression.cs | 26 ++-- .../BasicModels/NearestNeighbor.cs | 10 +- .../ImageProcessing/DigitRecognitionCNN.cs | 16 +-- .../ImageProcessing/DigitRecognitionNN.cs | 40 ++---- .../ImageProcessing/DigitRecognitionRNN.cs | 18 +-- .../TensorFlowNET.Examples.csproj | 1 + .../Utility/DataSetMnist.cs | 95 ------------- .../Utility/Datasets.cs | 46 ------ .../Utility/IDataSet.cs | 10 -- test/TensorFlowNET.Examples/Utility/MNIST.cs | 131 ------------------ 12 files changed, 69 insertions(+), 356 deletions(-) delete mode 100644 test/TensorFlowNET.Examples/Utility/DataSetMnist.cs delete mode 100644 test/TensorFlowNET.Examples/Utility/Datasets.cs delete mode 100644 test/TensorFlowNET.Examples/Utility/IDataSet.cs delete mode 100644 test/TensorFlowNET.Examples/Utility/MNIST.cs diff --git a/.gitignore b/.gitignore index eee1dc7b..ce600fbb 100644 --- a/.gitignore +++ b/.gitignore @@ -332,3 +332,7 @@ src/TensorFlowNET.Native/bazel-* src/TensorFlowNET.Native/c_api.h /.vscode test/TensorFlowNET.Examples/mnist + + +# training model resources +.resources diff --git a/test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs b/test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs index c0ca95b3..3b52a75e 100644 --- a/test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs +++ b/test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs @@ -18,7 +18,7 @@ using NumSharp; using System; using System.Diagnostics; using Tensorflow; -using TensorFlowNET.Examples.Utility; +using Tensorflow.Hub; using static Tensorflow.Python; namespace TensorFlowNET.Examples @@ -39,7 +39,7 @@ namespace TensorFlowNET.Examples public int? test_size = null; public int batch_size = 1024; // The number of samples per batch - Datasets mnist; + Datasets mnist; NDArray full_data_x; int num_steps = 20; // Total steps to train int k = 25; // The number of clusters @@ -62,19 +62,31 @@ namespace TensorFlowNET.Examples public void PrepareData() { - mnist = MNIST.read_data_sets("mnist", one_hot: true, train_size: train_size, validation_size:validation_size, test_size:test_size); - full_data_x = mnist.train.data; + var loader = new MnistModelLoader(); + + var setting = new ModelLoadSetting + { + TrainDir = ".resources/mnist", + OneHot = true, + TrainSize = train_size, + ValidationSize = validation_size, + TestSize = test_size + }; + + mnist = loader.LoadAsync(setting).Result; + + full_data_x = mnist.Train.Data; // download graph meta data string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/kmeans.meta"; - Web.Download(url, "graph", "kmeans.meta"); + loader.DownloadAsync(url, ".resources/graph", "kmeans.meta").Wait(); } public Graph ImportGraph() { var graph = tf.Graph().as_default(); - tf.train.import_meta_graph("graph/kmeans.meta"); + tf.train.import_meta_graph(".resources/graph/kmeans.meta"); return graph; } @@ -132,7 +144,7 @@ namespace TensorFlowNET.Examples sw.Start(); foreach (var i in range(idx.Length)) { - var x = mnist.train.labels[i]; + var x = mnist.Train.Labels[i]; counts[idx[i]] += x; } @@ -153,7 +165,7 @@ namespace TensorFlowNET.Examples var accuracy_op = tf.reduce_mean(cast); // Test Model - var (test_x, test_y) = (mnist.test.data, mnist.test.labels); + var (test_x, test_y) = (mnist.Test.Data, mnist.Test.Labels); result = sess.run(accuracy_op, new FeedItem(X, test_x), new FeedItem(Y, test_y)); accuray_test = result; print($"Test Accuracy: {accuray_test}"); diff --git a/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs b/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs index 185dd1fe..1d7808b7 100644 --- a/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs +++ b/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs @@ -19,7 +19,7 @@ using System; using System.Diagnostics; using System.IO; using Tensorflow; -using TensorFlowNET.Examples.Utility; +using Tensorflow.Hub; using static Tensorflow.Python; namespace TensorFlowNET.Examples @@ -45,7 +45,7 @@ namespace TensorFlowNET.Examples private float learning_rate = 0.01f; private int display_step = 1; - Datasets mnist; + Datasets mnist; public bool Run() { @@ -84,11 +84,11 @@ namespace TensorFlowNET.Examples sw.Start(); var avg_cost = 0.0f; - var total_batch = mnist.train.num_examples / batch_size; + var total_batch = mnist.Train.NumOfExamples / batch_size; // Loop over all batches foreach (var i in range(total_batch)) { - var (batch_xs, batch_ys) = mnist.train.next_batch(batch_size); + var (batch_xs, batch_ys) = mnist.Train.GetNextBatch(batch_size); // Run optimization op (backprop) and cost op (to get loss value) var result = sess.run(new object[] { optimizer, cost }, new FeedItem(x, batch_xs), @@ -115,7 +115,7 @@ namespace TensorFlowNET.Examples var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)); // Calculate accuracy var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)); - float acc = accuracy.eval(new FeedItem(x, mnist.test.data), new FeedItem(y, mnist.test.labels)); + float acc = accuracy.eval(new FeedItem(x, mnist.Test.Data), new FeedItem(y, mnist.Test.Labels)); print($"Accuracy: {acc.ToString("F4")}"); return acc > 0.9; @@ -124,23 +124,23 @@ namespace TensorFlowNET.Examples public void PrepareData() { - mnist = MNIST.read_data_sets("mnist", one_hot: true, train_size: train_size, validation_size: validation_size, test_size: test_size); + mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: train_size, validationSize: validation_size, testSize: test_size).Result; } public void SaveModel(Session sess) { var saver = tf.train.Saver(); - var save_path = saver.save(sess, "logistic_regression/model.ckpt"); - tf.train.write_graph(sess.graph, "logistic_regression", "model.pbtxt", as_text: true); + var save_path = saver.save(sess, ".resources/logistic_regression/model.ckpt"); + tf.train.write_graph(sess.graph, ".resources/logistic_regression", "model.pbtxt", as_text: true); - FreezeGraph.freeze_graph(input_graph: "logistic_regression/model.pbtxt", + FreezeGraph.freeze_graph(input_graph: ".resources/logistic_regression/model.pbtxt", input_saver: "", input_binary: false, - input_checkpoint: "logistic_regression/model.ckpt", + input_checkpoint: ".resources/logistic_regression/model.ckpt", output_node_names: "Softmax", restore_op_name: "save/restore_all", filename_tensor_name: "save/Const:0", - output_graph: "logistic_regression/model.pb", + output_graph: ".resources/logistic_regression/model.pb", clear_devices: true, initializer_nodes: ""); } @@ -148,7 +148,7 @@ namespace TensorFlowNET.Examples public void Predict(Session sess) { var graph = new Graph().as_default(); - graph.Import(Path.Join("logistic_regression", "model.pb")); + graph.Import(Path.Join(".resources/logistic_regression", "model.pb")); // restoring the model // var saver = tf.train.import_meta_graph("logistic_regression/tensorflowModel.ckpt.meta"); @@ -159,7 +159,7 @@ namespace TensorFlowNET.Examples var input = x.outputs[0]; // predict - var (batch_xs, batch_ys) = mnist.train.next_batch(10); + var (batch_xs, batch_ys) = mnist.Train.GetNextBatch(10); var results = sess.run(output, new FeedItem(input, batch_xs[np.arange(1)])); if (results.argmax() == (batch_ys[0] as NDArray).argmax()) diff --git a/test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs b/test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs index 86ecd281..d1d867a2 100644 --- a/test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs +++ b/test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs @@ -17,7 +17,7 @@ using NumSharp; using System; using Tensorflow; -using TensorFlowNET.Examples.Utility; +using Tensorflow.Hub; using static Tensorflow.Python; namespace TensorFlowNET.Examples @@ -31,7 +31,7 @@ namespace TensorFlowNET.Examples { public bool Enabled { get; set; } = true; public string Name => "Nearest Neighbor"; - Datasets mnist; + Datasets mnist; NDArray Xtr, Ytr, Xte, Yte; public int? TrainSize = null; public int ValidationSize = 5000; @@ -84,10 +84,10 @@ namespace TensorFlowNET.Examples public void PrepareData() { - mnist = MNIST.read_data_sets("mnist", one_hot: true, train_size: TrainSize, validation_size:ValidationSize, test_size:TestSize); + mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: TrainSize, validationSize: ValidationSize, testSize: TestSize).Result; // In this example, we limit mnist data - (Xtr, Ytr) = mnist.train.next_batch(TrainSize==null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates) - (Xte, Yte) = mnist.test.next_batch(TestSize==null ? 200 : TestSize.Value / 100); // 200 for testing + (Xtr, Ytr) = mnist.Train.GetNextBatch(TrainSize == null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates) + (Xte, Yte) = mnist.Test.GetNextBatch(TestSize == null ? 200 : TestSize.Value / 100); // 200 for testing } public Graph ImportGraph() diff --git a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs index a5c757b9..d2a1b9f4 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs @@ -18,7 +18,7 @@ using NumSharp; using System; using System.Diagnostics; using Tensorflow; -using TensorFlowNET.Examples.Utility; +using Tensorflow.Hub; using static Tensorflow.Python; namespace TensorFlowNET.Examples.ImageProcess @@ -46,7 +46,7 @@ namespace TensorFlowNET.Examples.ImageProcess int epochs = 5; // accuracy > 98% int batch_size = 100; float learning_rate = 0.001f; - Datasets mnist; + Datasets mnist; // Network configuration // 1st Convolutional Layer @@ -310,14 +310,14 @@ namespace TensorFlowNET.Examples.ImageProcess public void PrepareData() { - mnist = MNIST.read_data_sets("mnist", one_hot: true); - (x_train, y_train) = Reformat(mnist.train.data, mnist.train.labels); - (x_valid, y_valid) = Reformat(mnist.validation.data, mnist.validation.labels); - (x_test, y_test) = Reformat(mnist.test.data, mnist.test.labels); + mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result; + (x_train, y_train) = Reformat(mnist.Train.Data, mnist.Train.Labels); + (x_valid, y_valid) = Reformat(mnist.Validation.Data, mnist.Validation.Labels); + (x_test, y_test) = Reformat(mnist.Test.Data, mnist.Test.Labels); print("Size of:"); - print($"- Training-set:\t\t{len(mnist.train.data)}"); - print($"- Validation-set:\t{len(mnist.validation.data)}"); + print($"- Training-set:\t\t{len(mnist.Train.Data)}"); + print($"- Validation-set:\t{len(mnist.Validation.Data)}"); } /// diff --git a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs index 09fdc818..059c5419 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs @@ -17,7 +17,7 @@ using NumSharp; using System; using Tensorflow; -using TensorFlowNET.Examples.Utility; +using Tensorflow.Hub; using static Tensorflow.Python; namespace TensorFlowNET.Examples.ImageProcess @@ -44,7 +44,7 @@ namespace TensorFlowNET.Examples.ImageProcess int batch_size = 100; float learning_rate = 0.001f; int h1 = 200; // number of nodes in the 1st hidden layer - Datasets mnist; + Datasets mnist; Tensor x, y; Tensor loss, accuracy; @@ -121,13 +121,13 @@ namespace TensorFlowNET.Examples.ImageProcess public void PrepareData() { - mnist = MNIST.read_data_sets("mnist", one_hot: true); + mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result; } public void Train(Session sess) { // Number of training iterations in each epoch - var num_tr_iter = mnist.train.labels.len / batch_size; + var num_tr_iter = mnist.Train.Labels.len / batch_size; var init = tf.global_variables_initializer(); sess.run(init); @@ -139,13 +139,13 @@ namespace TensorFlowNET.Examples.ImageProcess { print($"Training epoch: {epoch + 1}"); // Randomly shuffle the training data at the beginning of each epoch - var (x_train, y_train) = randomize(mnist.train.data, mnist.train.labels); + var (x_train, y_train) = mnist.Randomize(mnist.Train.Data, mnist.Train.Labels); foreach (var iteration in range(num_tr_iter)) { var start = iteration * batch_size; var end = (iteration + 1) * batch_size; - var (x_batch, y_batch) = get_next_batch(x_train, y_train, start, end); + var (x_batch, y_batch) = mnist.GetNextBatch(x_train, y_train, start, end); // Run optimization op (backprop) sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch)); @@ -161,7 +161,8 @@ namespace TensorFlowNET.Examples.ImageProcess } // Run validation after every epoch - var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.validation.data), new FeedItem(y, mnist.validation.labels)); + var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.Validation.Data), new FeedItem(y, mnist.Validation.Labels)); + loss_val = results1[0]; accuracy_val = results1[1]; print("---------------------------------------------------------"); @@ -172,35 +173,12 @@ namespace TensorFlowNET.Examples.ImageProcess public void Test(Session sess) { - var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.test.data), new FeedItem(y, mnist.test.labels)); + var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.Test.Data), new FeedItem(y, mnist.Test.Labels)); loss_test = result[0]; accuracy_test = result[1]; print("---------------------------------------------------------"); print($"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}"); print("---------------------------------------------------------"); } - - private (NDArray, NDArray) randomize(NDArray x, NDArray y) - { - var perm = np.random.permutation(y.shape[0]); - - np.random.shuffle(perm); - return (mnist.train.data[perm], mnist.train.labels[perm]); - } - - /// - /// selects a few number of images determined by the batch_size variable (if you don't know why, read about Stochastic Gradient Method) - /// - /// - /// - /// - /// - /// - private (NDArray, NDArray) get_next_batch(NDArray x, NDArray y, int start, int end) - { - var x_batch = x[$"{start}:{end}"]; - var y_batch = y[$"{start}:{end}"]; - return (x_batch, y_batch); - } } } diff --git a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs index d51ca9ad..babf62f3 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs @@ -17,7 +17,7 @@ using NumSharp; using System; using Tensorflow; -using TensorFlowNET.Examples.Utility; +using Tensorflow.Hub; using static Tensorflow.Python; namespace TensorFlowNET.Examples.ImageProcess @@ -45,7 +45,7 @@ namespace TensorFlowNET.Examples.ImageProcess int n_inputs = 28; int n_outputs = 10; - Datasets mnist; + Datasets mnist; Tensor x, y; Tensor loss, accuracy, cls_prediction; @@ -143,15 +143,15 @@ namespace TensorFlowNET.Examples.ImageProcess public void PrepareData() { - mnist = MNIST.read_data_sets("mnist", one_hot: true); - (x_train, y_train) = (mnist.train.data, mnist.train.labels); - (x_valid, y_valid) = (mnist.validation.data, mnist.validation.labels); - (x_test, y_test) = (mnist.test.data, mnist.test.labels); + mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result; + (x_train, y_train) = (mnist.Train.Data, mnist.Train.Labels); + (x_valid, y_valid) = (mnist.Validation.Data, mnist.Validation.Labels); + (x_test, y_test) = (mnist.Test.Data, mnist.Test.Labels); print("Size of:"); - print($"- Training-set:\t\t{len(mnist.train.data)}"); - print($"- Validation-set:\t{len(mnist.validation.data)}"); - print($"- Test-set:\t\t{len(mnist.test.data)}"); + print($"- Training-set:\t\t{len(mnist.Train.Data)}"); + print($"- Validation-set:\t{len(mnist.Validation.Data)}"); + print($"- Test-set:\t\t{len(mnist.Test.Data)}"); } public Graph ImportGraph() => throw new NotImplementedException(); diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj index 149bd549..6184d4ad 100644 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj +++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj @@ -18,5 +18,6 @@ + diff --git a/test/TensorFlowNET.Examples/Utility/DataSetMnist.cs b/test/TensorFlowNET.Examples/Utility/DataSetMnist.cs deleted file mode 100644 index 0017eba5..00000000 --- a/test/TensorFlowNET.Examples/Utility/DataSetMnist.cs +++ /dev/null @@ -1,95 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using NumSharp; -using Tensorflow; - -namespace TensorFlowNET.Examples.Utility -{ - public class DataSetMnist : IDataSet - { - public int num_examples { get; } - - public int epochs_completed { get; private set; } - public int index_in_epoch { get; private set; } - public NDArray data { get; private set; } - public NDArray labels { get; private set; } - - public DataSetMnist(NDArray images, NDArray labels, TF_DataType dtype, bool reshape) - { - num_examples = images.shape[0]; - images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]); - images.astype(dtype.as_numpy_datatype()); - images = np.multiply(images, 1.0f / 255.0f); - - labels.astype(dtype.as_numpy_datatype()); - - data = images; - this.labels = labels; - epochs_completed = 0; - index_in_epoch = 0; - } - - public (NDArray, NDArray) next_batch(int batch_size, bool fake_data = false, bool shuffle = true) - { - var start = index_in_epoch; - // Shuffle for the first epoch - if(epochs_completed == 0 && start == 0 && shuffle) - { - var perm0 = np.arange(num_examples); - np.random.shuffle(perm0); - data = data[perm0]; - labels = labels[perm0]; - } - - // Go to the next epoch - if (start + batch_size > num_examples) - { - // Finished epoch - epochs_completed += 1; - - // Get the rest examples in this epoch - var rest_num_examples = num_examples - start; - //var images_rest_part = _images[np.arange(start, _num_examples)]; - //var labels_rest_part = _labels[np.arange(start, _num_examples)]; - // Shuffle the data - if (shuffle) - { - var perm = np.arange(num_examples); - np.random.shuffle(perm); - data = data[perm]; - labels = labels[perm]; - } - - start = 0; - index_in_epoch = batch_size - rest_num_examples; - var end = index_in_epoch; - var images_new_part = data[np.arange(start, end)]; - var labels_new_part = labels[np.arange(start, end)]; - - /*return (np.concatenate(new float[][] { images_rest_part.Data(), images_new_part.Data() }, axis: 0), - np.concatenate(new float[][] { labels_rest_part.Data(), labels_new_part.Data() }, axis: 0));*/ - return (images_new_part, labels_new_part); - } - else - { - index_in_epoch += batch_size; - var end = index_in_epoch; - return (data[np.arange(start, end)], labels[np.arange(start, end)]); - } - } - } -} diff --git a/test/TensorFlowNET.Examples/Utility/Datasets.cs b/test/TensorFlowNET.Examples/Utility/Datasets.cs deleted file mode 100644 index 0c8c4e2d..00000000 --- a/test/TensorFlowNET.Examples/Utility/Datasets.cs +++ /dev/null @@ -1,46 +0,0 @@ -using NumSharp; - -namespace TensorFlowNET.Examples.Utility -{ - public class Datasets where T : IDataSet - { - private T _train; - public T train => _train; - - private T _validation; - public T validation => _validation; - - private T _test; - public T test => _test; - - public Datasets(T train, T validation, T test) - { - _train = train; - _validation = validation; - _test = test; - } - - public (NDArray, NDArray) Randomize(NDArray x, NDArray y) - { - var perm = np.random.permutation(y.shape[0]); - - np.random.shuffle(perm); - return (x[perm], y[perm]); - } - - /// - /// selects a few number of images determined by the batch_size variable (if you don't know why, read about Stochastic Gradient Method) - /// - /// - /// - /// - /// - /// - public (NDArray, NDArray) GetNextBatch(NDArray x, NDArray y, int start, int end) - { - var x_batch = x[$"{start}:{end}"]; - var y_batch = y[$"{start}:{end}"]; - return (x_batch, y_batch); - } - } -} diff --git a/test/TensorFlowNET.Examples/Utility/IDataSet.cs b/test/TensorFlowNET.Examples/Utility/IDataSet.cs deleted file mode 100644 index 31be57c1..00000000 --- a/test/TensorFlowNET.Examples/Utility/IDataSet.cs +++ /dev/null @@ -1,10 +0,0 @@ -using NumSharp; - -namespace TensorFlowNET.Examples.Utility -{ - public interface IDataSet - { - NDArray data { get; } - NDArray labels { get; } - } -} diff --git a/test/TensorFlowNET.Examples/Utility/MNIST.cs b/test/TensorFlowNET.Examples/Utility/MNIST.cs deleted file mode 100644 index 73d6fe2a..00000000 --- a/test/TensorFlowNET.Examples/Utility/MNIST.cs +++ /dev/null @@ -1,131 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using NumSharp; -using System; -using System.IO; -using Tensorflow; - -namespace TensorFlowNET.Examples.Utility -{ - public class MNIST - { - private const string DEFAULT_SOURCE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/"; - private const string TRAIN_IMAGES = "train-images-idx3-ubyte.gz"; - private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz"; - private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; - private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; - public static Datasets read_data_sets(string train_dir, - bool one_hot = false, - TF_DataType dtype = TF_DataType.TF_FLOAT, - bool reshape = true, - int validation_size = 5000, - int? train_size = null, - int? test_size = null, - string source_url = DEFAULT_SOURCE_URL) - { - if (train_size!=null && validation_size >= train_size) - throw new ArgumentException("Validation set should be smaller than training set"); - - Web.Download(source_url + TRAIN_IMAGES, train_dir, TRAIN_IMAGES); - Compress.ExtractGZip(Path.Join(train_dir, TRAIN_IMAGES), train_dir); - var train_images = extract_images(Path.Join(train_dir, TRAIN_IMAGES.Split('.')[0]), limit: train_size); - - Web.Download(source_url + TRAIN_LABELS, train_dir, TRAIN_LABELS); - Compress.ExtractGZip(Path.Join(train_dir, TRAIN_LABELS), train_dir); - var train_labels = extract_labels(Path.Join(train_dir, TRAIN_LABELS.Split('.')[0]), one_hot: one_hot, limit: train_size); - - Web.Download(source_url + TEST_IMAGES, train_dir, TEST_IMAGES); - Compress.ExtractGZip(Path.Join(train_dir, TEST_IMAGES), train_dir); - var test_images = extract_images(Path.Join(train_dir, TEST_IMAGES.Split('.')[0]), limit: test_size); - - Web.Download(source_url + TEST_LABELS, train_dir, TEST_LABELS); - Compress.ExtractGZip(Path.Join(train_dir, TEST_LABELS), train_dir); - var test_labels = extract_labels(Path.Join(train_dir, TEST_LABELS.Split('.')[0]), one_hot: one_hot, limit:test_size); - - int end = train_images.shape[0]; - var validation_images = train_images[np.arange(validation_size)]; - var validation_labels = train_labels[np.arange(validation_size)]; - train_images = train_images[np.arange(validation_size, end)]; - train_labels = train_labels[np.arange(validation_size, end)]; - - var train = new DataSetMnist(train_images, train_labels, dtype, reshape); - var validation = new DataSetMnist(validation_images, validation_labels, dtype, reshape); - var test = new DataSetMnist(test_images, test_labels, dtype, reshape); - - return new Datasets(train, validation, test); - } - - public static NDArray extract_images(string file, int? limit=null) - { - using (var bytestream = new FileStream(file, FileMode.Open)) - { - var magic = _read32(bytestream); - if (magic != 2051) - throw new ValueError($"Invalid magic number {magic} in MNIST image file: {file}"); - var num_images = _read32(bytestream); - num_images = limit == null ? num_images : Math.Min(num_images, (uint)limit); - var rows = _read32(bytestream); - var cols = _read32(bytestream); - var buf = new byte[rows * cols * num_images]; - bytestream.Read(buf, 0, buf.Length); - var data = np.frombuffer(buf, np.uint8); - data = data.reshape((int)num_images, (int)rows, (int)cols, 1); - return data; - } - } - - public static NDArray extract_labels(string file, bool one_hot = false, int num_classes = 10, int? limit = null) - { - using (var bytestream = new FileStream(file, FileMode.Open)) - { - var magic = _read32(bytestream); - if (magic != 2049) - throw new ValueError($"Invalid magic number {magic} in MNIST label file: {file}"); - var num_items = _read32(bytestream); - num_items = limit == null ? num_items : Math.Min(num_items,(uint) limit); - var buf = new byte[num_items]; - bytestream.Read(buf, 0, buf.Length); - var labels = np.frombuffer(buf, np.uint8); - if (one_hot) - return dense_to_one_hot(labels, num_classes); - return labels; - } - } - - private static NDArray dense_to_one_hot(NDArray labels_dense, int num_classes) - { - var num_labels = labels_dense.shape[0]; - var index_offset = np.arange(num_labels) * num_classes; - var labels_one_hot = np.zeros(num_labels, num_classes); - - for(int row = 0; row < num_labels; row++) - { - var col = labels_dense.Data(row); - labels_one_hot.SetData(1.0, row, col); - } - - return labels_one_hot; - } - - private static uint _read32(FileStream bytestream) - { - var buffer = new byte[sizeof(uint)]; - var count = bytestream.Read(buffer, 0, 4); - return np.frombuffer(buffer, ">u4").Data(0); - } - } -}