From bd42ed97f2664656cfe3934db5e1852a6c49e530 Mon Sep 17 00:00:00 2001
From: haiping008
Date: Tue, 19 Mar 2019 17:22:28 -0500
Subject: [PATCH] Mnist dataset

---
Review notes (ignored by git am; not part of the commit message):
 * Line structure restored; this copy of the patch had its newlines collapsed.
 * MnistDataSet: header and payload reads now loop until the buffer is full and
   throw EndOfStreamException on a truncated file; files are opened with
   FileAccess.Read; the repeated download/extract steps live in _fetch;
   validation_size is range-checked; the unused index_offset is gone.
 * DataSet: honours the reshape flag and keeps the processed images and labels.
 * Compress.ExtractGZip: creates the target directory when it is missing.
 * The body of the csproj hunk could not be recovered from this copy and is
   reproduced as found; restore it from the original commit before applying.
 * The "index" blob ids are those of the original commit, not of these contents.

 .../LogisticRegression.cs                     |  33 ++++
 .../TensorFlowNET.Examples.csproj             |   1 +
 .../Utility/Compress.cs                       |  29 ++-
 .../TensorFlowNET.Examples/Utility/DataSet.cs |  46 +++++
 .../Utility/MnistDataSet.cs                   | 159 ++++++++++++++++++
 .../python/logistic_regression.py             |  70 ++++++++
 6 files changed, 337 insertions(+), 1 deletion(-)
 create mode 100644 test/TensorFlowNET.Examples/LogisticRegression.cs
 create mode 100644 test/TensorFlowNET.Examples/Utility/DataSet.cs
 create mode 100644 test/TensorFlowNET.Examples/Utility/MnistDataSet.cs
 create mode 100644 test/TensorFlowNET.Examples/python/logistic_regression.py

diff --git a/test/TensorFlowNET.Examples/LogisticRegression.cs b/test/TensorFlowNET.Examples/LogisticRegression.cs
new file mode 100644
index 00000000..b9d01c6d
--- /dev/null
+++ b/test/TensorFlowNET.Examples/LogisticRegression.cs
@@ -0,0 +1,33 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow;
+using TensorFlowNET.Examples.Utility;
+
+namespace TensorFlowNET.Examples
+{
+    /// <summary>
+    /// A logistic regression learning algorithm example using TensorFlow library.
+    /// This example is using the MNIST database of handwritten digits
+    /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/logistic_regression.py
+    /// </summary>
+    public class LogisticRegression : Python, IExample
+    {
+        /// <summary>
+        /// Runs the example. Only the data preparation step is implemented so far.
+        /// </summary>
+        public void Run()
+        {
+            PrepareData();
+        }
+
+        /// <summary>
+        /// Loads MNIST through MnistDataSet, using "logistic_regression" as the data
+        /// directory and one-hot encoded labels.
+        /// </summary>
+        private void PrepareData()
+        {
+            MnistDataSet.read_data_sets("logistic_regression", one_hot: true);
+        }
+    }
+}
diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
index 503fd643..8426d12e 100644
--- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
+++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
@@ -6,6 +6,7 @@
+
diff --git a/test/TensorFlowNET.Examples/Utility/Compress.cs b/test/TensorFlowNET.Examples/Utility/Compress.cs
index cf40e2c4..bc38434b 100644
--- a/test/TensorFlowNET.Examples/Utility/Compress.cs
+++ b/test/TensorFlowNET.Examples/Utility/Compress.cs
@@ -1,4 +1,5 @@
-using ICSharpCode.SharpZipLib.GZip;
+using ICSharpCode.SharpZipLib.Core;
+using ICSharpCode.SharpZipLib.GZip;
 using ICSharpCode.SharpZipLib.Tar;
 using System;
 using System.IO;
@@ -11,6 +12,32 @@ namespace TensorFlowNET.Examples.Utility
 {
     public class Compress
     {
+        /// <summary>
+        /// Decompresses a single gzip file into <paramref name="targetDir"/>. The output
+        /// file is named after the archive with its last extension removed.
+        /// </summary>
+        /// <param name="gzipFileName">Path of the gzip file to read.</param>
+        /// <param name="targetDir">Directory that receives the decompressed file; created when missing.</param>
+        public static void ExtractGZip(string gzipFileName, string targetDir)
+        {
+            // Scratch buffer that StreamUtils.Copy uses to move the data in chunks.
+            byte[] dataBuffer = new byte[4096];
+            string fnOut = Path.Combine(targetDir, Path.GetFileNameWithoutExtension(gzipFileName));
+
+            using (FileStream fs = new FileStream(gzipFileName, FileMode.Open, FileAccess.Read))
+            using (GZipInputStream gzipStream = new GZipInputStream(fs))
+            {
+                // File.Create fails when the directory is absent, so make sure it exists first.
+                if (!string.IsNullOrEmpty(targetDir))
+                    Directory.CreateDirectory(targetDir);
+
+                using (FileStream fsOut = File.Create(fnOut))
+                {
+                    StreamUtils.Copy(gzipStream, fsOut, dataBuffer);
+                }
+            }
+        }
+
         public static void UnZip(String gzArchiveName, String destFolder)
         {
             var flag = gzArchiveName.Split(Path.DirectorySeparatorChar).Last().Split('.').First() + ".bin";
diff --git a/test/TensorFlowNET.Examples/Utility/DataSet.cs b/test/TensorFlowNET.Examples/Utility/DataSet.cs
new file mode 100644
index 00000000..1005aec3
--- /dev/null
+++ b/test/TensorFlowNET.Examples/Utility/DataSet.cs
@@ -0,0 +1,46 @@
+using NumSharp.Core;
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow;
+
+namespace TensorFlowNET.Examples.Utility
+{
+    /// <summary>
+    /// Holds the images and labels of one split of a data set.
+    /// </summary>
+    public class DataSet
+    {
+        private int _num_examples;
+        private NDArray _images;
+        private NDArray _labels;
+
+        /// <summary>Number of examples: the size of the first dimension of the images.</summary>
+        public int num_examples => _num_examples;
+
+        /// <summary>Images after the optional flattening and the 1/255 scaling.</summary>
+        public NDArray images => _images;
+
+        /// <summary>Labels exactly as they were passed to the constructor.</summary>
+        public NDArray labels => _labels;
+
+        /// <summary>
+        /// Prepares the images for training and keeps them together with their labels.
+        /// </summary>
+        /// <param name="images">Image array whose first three dimensions are examples, rows and columns.</param>
+        /// <param name="labels">Labels belonging to <paramref name="images"/>.</param>
+        /// <param name="dtype">Not used yet; the images are always scaled by 1/255.</param>
+        /// <param name="reshape">When true, each image is flattened to rows * columns values.</param>
+        public DataSet(NDArray images, NDArray labels, TF_DataType dtype, bool reshape)
+        {
+            _num_examples = images.shape[0];
+
+            if (reshape)
+                images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]);
+
+            // Scale the pixel values; presumably the input holds bytes, so 0..255 maps onto 0..1.
+            _images = np.multiply(images, 1.0f / 255.0f);
+            _labels = labels;
+        }
+    }
+}
diff --git a/test/TensorFlowNET.Examples/Utility/MnistDataSet.cs b/test/TensorFlowNET.Examples/Utility/MnistDataSet.cs
new file mode 100644
index 00000000..05ad2970
--- /dev/null
+++ b/test/TensorFlowNET.Examples/Utility/MnistDataSet.cs
@@ -0,0 +1,159 @@
+using ICSharpCode.SharpZipLib.Core;
+using ICSharpCode.SharpZipLib.GZip;
+using NumSharp.Core;
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Text;
+using Tensorflow;
+
+namespace TensorFlowNET.Examples.Utility
+{
+    /// <summary>
+    /// Downloads the MNIST archives and parses the extracted files into NDArrays.
+    /// </summary>
+    public class MnistDataSet
+    {
+        private const string DEFAULT_SOURCE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/";
+        private const string TRAIN_IMAGES = "train-images-idx3-ubyte.gz";
+        private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz";
+        private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz";
+        private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz";
+
+        /// <summary>
+        /// Downloads, extracts and parses the four MNIST files, then splits the first
+        /// <paramref name="validation_size"/> training examples off as a validation set.
+        /// </summary>
+        /// <param name="train_dir">Directory used for the downloaded and extracted files.</param>
+        /// <param name="one_hot">When true, labels are converted to one-hot rows.</param>
+        /// <param name="dtype">Passed through to <see cref="DataSet"/>.</param>
+        /// <param name="reshape">Passed through to <see cref="DataSet"/>.</param>
+        /// <param name="validation_size">Number of leading training examples reserved for validation.</param>
+        /// <param name="source_url">Base URL the file names are appended to.</param>
+        public static void read_data_sets(string train_dir,
+            bool one_hot = false,
+            TF_DataType dtype = TF_DataType.DtInvalid,
+            bool reshape = true,
+            int validation_size = 5000,
+            string source_url = DEFAULT_SOURCE_URL)
+        {
+            var train_images = extract_images(_fetch(source_url, train_dir, TRAIN_IMAGES));
+            var train_labels = extract_labels(_fetch(source_url, train_dir, TRAIN_LABELS), one_hot: one_hot);
+            var test_images = extract_images(_fetch(source_url, train_dir, TEST_IMAGES));
+            var test_labels = extract_labels(_fetch(source_url, train_dir, TEST_LABELS), one_hot: one_hot);
+
+            int end = train_images.shape[0];
+            if (validation_size < 0 || validation_size > end)
+                throw new ValueError($"Validation size should be between 0 and {end}. Received: {validation_size}.");
+
+            var validation_images = train_images[np.arange(validation_size)];
+            var validation_labels = train_labels[np.arange(validation_size)];
+            train_images = train_images[np.arange(validation_size, end)];
+            train_labels = train_labels[np.arange(validation_size, end)];
+
+            // TODO: the validation and test splits are built but not wrapped or returned yet.
+            var train = new DataSet(train_images, train_labels, dtype, reshape);
+        }
+
+        /// <summary>
+        /// Parses an extracted MNIST image file.
+        /// </summary>
+        /// <param name="file">Path of the uncompressed image file.</param>
+        /// <returns>The pixel bytes reshaped to [images, rows, columns, 1].</returns>
+        public static NDArray extract_images(string file)
+        {
+            using (var bytestream = new FileStream(file, FileMode.Open, FileAccess.Read))
+            {
+                var magic = _read32(bytestream);
+                if (magic != 2051)
+                    throw new ValueError($"Invalid magic number {magic} in MNIST image file: {file}");
+                var num_images = _read32(bytestream);
+                var rows = _read32(bytestream);
+                var cols = _read32(bytestream);
+                var buf = new byte[rows * cols * num_images];
+                _read_fully(bytestream, buf);
+                var data = np.frombuffer(buf, np.uint8);
+                data = data.reshape((int)num_images, (int)rows, (int)cols, 1);
+                return data;
+            }
+        }
+
+        /// <summary>
+        /// Parses an extracted MNIST label file.
+        /// </summary>
+        /// <param name="file">Path of the uncompressed label file.</param>
+        /// <param name="one_hot">When true, returns one row of <paramref name="num_classes"/> values per label.</param>
+        /// <param name="num_classes">Row width of the one-hot encoding.</param>
+        /// <returns>The label bytes, or their one-hot encoding.</returns>
+        public static NDArray extract_labels(string file, bool one_hot = false, int num_classes = 10)
+        {
+            using (var bytestream = new FileStream(file, FileMode.Open, FileAccess.Read))
+            {
+                var magic = _read32(bytestream);
+                if (magic != 2049)
+                    throw new ValueError($"Invalid magic number {magic} in MNIST label file: {file}");
+                var num_items = _read32(bytestream);
+                var buf = new byte[num_items];
+                _read_fully(bytestream, buf);
+                if (one_hot)
+                    return dense_to_one_hot(buf, num_classes);
+                return np.frombuffer(buf, np.uint8);
+            }
+        }
+
+        /// <summary>
+        /// Downloads one archive into <paramref name="train_dir"/>, extracts it there and
+        /// returns the path of the extracted file.
+        /// </summary>
+        private static string _fetch(string source_url, string train_dir, string file_name)
+        {
+            Web.Download(source_url + file_name, train_dir, file_name);
+            Compress.ExtractGZip(Path.Join(train_dir, file_name), train_dir);
+            return Path.Join(train_dir, Path.GetFileNameWithoutExtension(file_name));
+        }
+
+        /// <summary>
+        /// Builds a [labels, num_classes] array of zeros with a 1 in the column of each label.
+        /// </summary>
+        private static NDArray dense_to_one_hot(byte[] labels_dense, int num_classes)
+        {
+            var num_labels = labels_dense.Length;
+            var labels_one_hot = np.zeros(num_labels, num_classes);
+
+            for (int row = 0; row < num_labels; row++)
+            {
+                int col = labels_dense[row];
+                labels_one_hot[row, col] = 1;
+            }
+
+            return labels_one_hot;
+        }
+
+        /// <summary>
+        /// Reads a 32-bit unsigned integer stored most significant byte first.
+        /// </summary>
+        private static uint _read32(FileStream bytestream)
+        {
+            var buffer = new byte[sizeof(uint)];
+            _read_fully(bytestream, buffer);
+            return ((uint)buffer[0] << 24) | ((uint)buffer[1] << 16) | ((uint)buffer[2] << 8) | (uint)buffer[3];
+        }
+
+        /// <summary>
+        /// Fills <paramref name="buffer"/> completely; a single Read call may return fewer bytes than requested.
+        /// </summary>
+        /// <exception cref="EndOfStreamException">The file ends before the buffer is full.</exception>
+        private static void _read_fully(FileStream bytestream, byte[] buffer)
+        {
+            int offset = 0;
+            while (offset < buffer.Length)
+            {
+                int count = bytestream.Read(buffer, offset, buffer.Length - offset);
+                if (count == 0)
+                    throw new EndOfStreamException($"Expected {buffer.Length} bytes but found only {offset}.");
+                offset += count;
+            }
+        }
+    }
+}
diff --git a/test/TensorFlowNET.Examples/python/logistic_regression.py b/test/TensorFlowNET.Examples/python/logistic_regression.py
new file mode 100644
index 00000000..338ebe5a
--- /dev/null
+++ b/test/TensorFlowNET.Examples/python/logistic_regression.py
@@ -0,0 +1,70 @@
+'''
+A logistic regression learning algorithm example using TensorFlow library.
+This example is using the MNIST database of handwritten digits
+(http://yann.lecun.com/exdb/mnist/)
+Author: Aymeric Damien
+Project: https://github.com/aymericdamien/TensorFlow-Examples/
+'''
+
+from __future__ import print_function
+
+import tensorflow as tf
+
+# Import MNIST data
+from tensorflow.examples.tutorials.mnist import input_data
+mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
+
+# Parameters
+learning_rate = 0.01
+training_epochs = 25
+batch_size = 100
+display_step = 1
+
+# tf Graph Input
+x = tf.placeholder(tf.float32, [None, 784]) # mnist data image of shape 28*28=784
+y = tf.placeholder(tf.float32, [None, 10]) # 0-9 digits recognition => 10 classes
+
+# Set model weights
+W = tf.Variable(tf.zeros([784, 10]))
+b = tf.Variable(tf.zeros([10]))
+
+# Construct model
+pred = tf.nn.softmax(tf.matmul(x, W) + b) # Softmax
+
+# Minimize error using cross entropy
+cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1))
+# Gradient Descent
+optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
+
+# Initialize the variables (i.e. assign their default value)
+init = tf.global_variables_initializer()
+
+# Start training
+with tf.Session() as sess:
+
+    # Run the initializer
+    sess.run(init)
+
+    # Training cycle
+    for epoch in range(training_epochs):
+        avg_cost = 0.
+        total_batch = int(mnist.train.num_examples/batch_size)
+        # Loop over all batches
+        for i in range(total_batch):
+            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
+            # Run optimization op (backprop) and cost op (to get loss value)
+            _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs,
+                                                          y: batch_ys})
+            # Compute average loss
+            avg_cost += c / total_batch
+        # Display logs per epoch step
+        if (epoch+1) % display_step == 0:
+            print("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(avg_cost))
+
+    print("Optimization Finished!")
+
+    # Test model
+    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
+    # Calculate accuracy
+    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
+    print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
\ No newline at end of file