Browse Source

Mnist dataset

tags/v0.8.0
haiping008 6 years ago
parent
commit
bd42ed97f2
6 changed files with 250 additions and 1 deletions
  1. +27
    -0
      test/TensorFlowNET.Examples/LogisticRegression.cs
  2. +1
    -0
      test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
  3. +22
    -1
      test/TensorFlowNET.Examples/Utility/Compress.cs
  4. +20
    -0
      test/TensorFlowNET.Examples/Utility/DataSet.cs
  5. +110
    -0
      test/TensorFlowNET.Examples/Utility/MnistDataSet.cs
  6. +70
    -0
      test/TensorFlowNET.Examples/python/logistic_regression.py

+ 27
- 0
test/TensorFlowNET.Examples/LogisticRegression.cs View File

@@ -0,0 +1,27 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;
using TensorFlowNET.Examples.Utility;

namespace TensorFlowNET.Examples
{
/// <summary>
/// Logistic regression example built on the TensorFlow library, using the
/// MNIST database of handwritten digits as input.
/// Ported from:
/// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/logistic_regression.py
/// </summary>
public class LogisticRegression : Python, IExample
{
    /// <summary>
    /// Entry point of the example. At this stage it only prepares the input data.
    /// </summary>
    public void Run() => PrepareData();

    /// <summary>
    /// Loads the MNIST data through <see cref="MnistDataSet.read_data_sets"/>,
    /// using "logistic_regression" as the data directory and one-hot labels.
    /// </summary>
    private void PrepareData() =>
        MnistDataSet.read_data_sets("logistic_regression", one_hot: true);
}
}

+ 1
- 0
test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj View File

@@ -6,6 +6,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="DevExpress.Xpo" Version="18.2.6" />
<PackageReference Include="NumSharp" Version="0.8.0" />
<PackageReference Include="SharpZipLib" Version="1.1.0" />
<PackageReference Include="TensorFlow.NET" Version="0.4.2" />


+ 22
- 1
test/TensorFlowNET.Examples/Utility/Compress.cs View File

@@ -1,4 +1,5 @@
using ICSharpCode.SharpZipLib.GZip;
using ICSharpCode.SharpZipLib.Core;
using ICSharpCode.SharpZipLib.GZip;
using ICSharpCode.SharpZipLib.Tar;
using System;
using System.IO;
@@ -11,6 +12,26 @@ namespace TensorFlowNET.Examples.Utility
{
public class Compress
{
/// <summary>
/// Decompresses a single gzip file into <paramref name="targetDir"/>.
/// The output file is named after the archive with its last extension removed
/// (for example <c>data.bin.gz</c> is written as <c>data.bin</c>). An existing
/// file with that name is overwritten.
/// </summary>
/// <param name="gzipFileName">Path of the gzip file to read.</param>
/// <param name="targetDir">Directory that receives the decompressed file; created when missing.</param>
public static void ExtractGZip(string gzipFileName, string targetDir)
{
    // 4 KB scratch buffer used by StreamUtils.Copy.
    byte[] dataBuffer = new byte[4096];

    using (System.IO.Stream fs = new FileStream(gzipFileName, FileMode.Open, FileAccess.Read))
    using (GZipInputStream gzipStream = new GZipInputStream(fs))
    {
        string fnOut = Path.Combine(targetDir, Path.GetFileNameWithoutExtension(gzipFileName));

        // File.Create does not create missing parent directories, so make sure the
        // target exists. This is done only after the input was opened successfully,
        // so a missing archive does not leave an empty directory behind.
        // An empty targetDir means "current directory" and needs no creation.
        if (!string.IsNullOrEmpty(targetDir))
        {
            Directory.CreateDirectory(targetDir);
        }

        using (FileStream fsOut = File.Create(fnOut))
        {
            StreamUtils.Copy(gzipStream, fsOut, dataBuffer);
        }
    }
}

public static void UnZip(String gzArchiveName, String destFolder)
{
var flag = gzArchiveName.Split(Path.DirectorySeparatorChar).Last().Split('.').First() + ".bin";


+ 20
- 0
test/TensorFlowNET.Examples/Utility/DataSet.cs View File

@@ -0,0 +1,20 @@
using NumSharp.Core;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;

namespace TensorFlowNET.Examples.Utility
{
/// <summary>
/// One split of an image dataset: the pre-processed images together with
/// their labels.
/// </summary>
public class DataSet
{
    private readonly int _num_examples;
    private readonly NDArray _images;
    private readonly NDArray _labels;

    /// <summary>Number of examples, i.e. the first dimension of the images given to the constructor.</summary>
    public int num_examples => _num_examples;

    /// <summary>Images after the optional flattening and the 1/255 scaling done by the constructor.</summary>
    public NDArray images => _images;

    /// <summary>Labels exactly as passed to the constructor.</summary>
    public NDArray labels => _labels;

    /// <summary>
    /// Creates the data set and pre-processes the images.
    /// </summary>
    /// <param name="images">
    /// Image array whose first dimension is the example count.
    /// NOTE(review): presumably shaped (count, rows, cols, 1) - confirm against callers.
    /// </param>
    /// <param name="labels">Labels belonging to <paramref name="images"/>; stored unchanged.</param>
    /// <param name="dtype">Requested element type. Currently not applied - TODO confirm intended conversion.</param>
    /// <param name="reshape">
    /// When true, each image is flattened to a single row of
    /// <c>shape[1] * shape[2]</c> values.
    /// </param>
    public DataSet(NDArray images, NDArray labels, TF_DataType dtype, bool reshape)
    {
        _num_examples = images.shape[0];

        // Previously the flattening ran unconditionally and ignored the flag.
        if (reshape)
        {
            images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]);
        }

        // Multiply every value by 1/255 (assumes byte-valued pixels in 0..255 - confirm).
        images = np.multiply(images, 1.0f / 255.0f);

        // Keep the results; before this they were computed and then thrown away.
        _images = images;
        _labels = labels;
    }
}
}

+ 110
- 0
test/TensorFlowNET.Examples/Utility/MnistDataSet.cs View File

@@ -0,0 +1,110 @@
using ICSharpCode.SharpZipLib.Core;
using ICSharpCode.SharpZipLib.GZip;
using NumSharp.Core;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Tensorflow;

namespace TensorFlowNET.Examples.Utility
{
/// <summary>
/// Downloads the MNIST handwritten-digit archives and parses their image and
/// label files into <see cref="NDArray"/> instances.
/// </summary>
public class MnistDataSet
{
    private const string DEFAULT_SOURCE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/";
    private const string TRAIN_IMAGES = "train-images-idx3-ubyte.gz";
    private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz";
    private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz";
    private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz";

    // Magic numbers expected at the start of an image file and a label file.
    private const uint IMAGE_FILE_MAGIC = 2051;
    private const uint LABEL_FILE_MAGIC = 2049;

    /// <summary>
    /// Downloads, decompresses and parses the four MNIST files, then splits the
    /// first <paramref name="validation_size"/> training examples off as a
    /// validation set and builds the training <see cref="DataSet"/>.
    /// </summary>
    /// <param name="train_dir">Directory used for the downloaded and extracted files.</param>
    /// <param name="one_hot">When true, labels are converted to one-hot rows.</param>
    /// <param name="dtype">Forwarded to the <see cref="DataSet"/> constructor.</param>
    /// <param name="reshape">Forwarded to the <see cref="DataSet"/> constructor.</param>
    /// <param name="validation_size">Number of leading training examples reserved for validation.</param>
    /// <param name="source_url">Base URL the file names are appended to.</param>
    /// <exception cref="ValueError">
    /// <paramref name="validation_size"/> is negative or larger than the training set,
    /// or a file starts with an unexpected magic number.
    /// </exception>
    public static void read_data_sets(string train_dir,
        bool one_hot = false,
        TF_DataType dtype = TF_DataType.DtInvalid,
        bool reshape = true,
        int validation_size = 5000,
        string source_url = DEFAULT_SOURCE_URL)
    {
        var train_images = extract_images(_download_and_extract(source_url, train_dir, TRAIN_IMAGES));
        var train_labels = extract_labels(_download_and_extract(source_url, train_dir, TRAIN_LABELS), one_hot: one_hot);

        // The test and validation splits are loaded but not exposed yet;
        // the method still returns nothing.
        var test_images = extract_images(_download_and_extract(source_url, train_dir, TEST_IMAGES));
        var test_labels = extract_labels(_download_and_extract(source_url, train_dir, TEST_LABELS), one_hot: one_hot);

        int end = train_images.shape[0];
        if (validation_size < 0 || validation_size > end)
        {
            throw new ValueError($"Validation size should be between 0 and {end}. Received: {validation_size}.");
        }

        var validation_images = train_images[np.arange(validation_size)];
        var validation_labels = train_labels[np.arange(validation_size)];
        train_images = train_images[np.arange(validation_size, end)];
        train_labels = train_labels[np.arange(validation_size, end)];

        var train = new DataSet(train_images, train_labels, dtype, reshape);
    }

    /// <summary>
    /// Reads an extracted MNIST image file.
    /// </summary>
    /// <param name="file">Path of the decompressed image file.</param>
    /// <returns>Array reshaped to (image count, rows, cols, 1).</returns>
    /// <exception cref="ValueError">The file does not start with the image magic number.</exception>
    /// <exception cref="EndOfStreamException">The file is shorter than its header announces.</exception>
    public static NDArray extract_images(string file)
    {
        // Read-only access is enough and also works for write-protected files.
        using (var bytestream = new FileStream(file, FileMode.Open, FileAccess.Read))
        {
            var magic = _read32(bytestream);
            if (magic != IMAGE_FILE_MAGIC)
            {
                throw new ValueError($"Invalid magic number {magic} in MNIST image file: {file}");
            }

            var num_images = _read32(bytestream);
            var rows = _read32(bytestream);
            var cols = _read32(bytestream);

            // checked: a corrupt header must not wrap around to a small buffer size.
            var buf = new byte[checked(rows * cols * num_images)];
            _read_fully(bytestream, buf);

            var data = np.frombuffer(buf, np.uint8);
            data = data.reshape((int)num_images, (int)rows, (int)cols, 1);
            return data;
        }
    }

    /// <summary>
    /// Reads an extracted MNIST label file.
    /// </summary>
    /// <param name="file">Path of the decompressed label file.</param>
    /// <param name="one_hot">When true, the labels are returned as one-hot rows.</param>
    /// <param name="num_classes">Number of columns of the one-hot result.</param>
    /// <returns>The labels, either as read or one-hot encoded.</returns>
    /// <exception cref="ValueError">The file does not start with the label magic number.</exception>
    /// <exception cref="EndOfStreamException">The file is shorter than its header announces.</exception>
    public static NDArray extract_labels(string file, bool one_hot = false, int num_classes = 10)
    {
        using (var bytestream = new FileStream(file, FileMode.Open, FileAccess.Read))
        {
            var magic = _read32(bytestream);
            if (magic != LABEL_FILE_MAGIC)
            {
                throw new ValueError($"Invalid magic number {magic} in MNIST label file: {file}");
            }

            var num_items = _read32(bytestream);
            var buf = new byte[num_items];
            _read_fully(bytestream, buf);

            var labels = np.frombuffer(buf, np.uint8);
            if (one_hot)
            {
                return dense_to_one_hot(labels, num_classes);
            }

            return labels;
        }
    }

    /// <summary>
    /// Converts a vector of class indices into a (label count, num_classes)
    /// array holding a 1 in the column of each label and 0 elsewhere.
    /// </summary>
    private static NDArray dense_to_one_hot(NDArray labels_dense, int num_classes)
    {
        var num_labels = labels_dense.shape[0];
        var labels_one_hot = np.zeros(num_labels, num_classes);

        for (int row = 0; row < num_labels; row++)
        {
            var col = labels_dense.Data<byte>(row);
            labels_one_hot[row, col] = 1;
        }

        return labels_one_hot;
    }

    /// <summary>
    /// Downloads <paramref name="file_name"/> from <paramref name="source_url"/> into
    /// <paramref name="train_dir"/>, decompresses it there and returns the path of the
    /// decompressed file.
    /// </summary>
    private static string _download_and_extract(string source_url, string train_dir, string file_name)
    {
        Web.Download(source_url + file_name, train_dir, file_name);
        Compress.ExtractGZip(Path.Join(train_dir, file_name), train_dir);

        // Same naming rule Compress.ExtractGZip uses for its output file.
        return Path.Join(train_dir, Path.GetFileNameWithoutExtension(file_name));
    }

    /// <summary>
    /// Reads the next four bytes and interprets them as a big-endian unsigned integer.
    /// </summary>
    private static uint _read32(FileStream bytestream)
    {
        var buffer = new byte[sizeof(uint)];
        _read_fully(bytestream, buffer);
        return np.frombuffer(buffer, ">u4").Data<uint>(0);
    }

    /// <summary>
    /// Fills <paramref name="buffer"/> completely from the stream. A single
    /// Stream.Read call may return fewer bytes than requested, so it is repeated
    /// until the buffer is full.
    /// </summary>
    /// <exception cref="EndOfStreamException">The stream ended before the buffer was filled.</exception>
    private static void _read_fully(FileStream bytestream, byte[] buffer)
    {
        int offset = 0;
        while (offset < buffer.Length)
        {
            int read = bytestream.Read(buffer, offset, buffer.Length - offset);
            if (read == 0)
            {
                throw new EndOfStreamException(
                    $"Unexpected end of file in {bytestream.Name}: expected {buffer.Length} bytes, got {offset}.");
            }

            offset += read;
        }
    }
}
}

+ 70
- 0
test/TensorFlowNET.Examples/python/logistic_regression.py View File

@@ -0,0 +1,70 @@
'''
A logistic regression learning algorithm example using TensorFlow library.
This example is using the MNIST database of handwritten digits
(http://yann.lecun.com/exdb/mnist/)
Author: Aymeric Damien
Project: https://github.com/aymericdamien/TensorFlow-Examples/
'''

from __future__ import print_function

import tensorflow as tf

# Load the MNIST data set with one-hot encoded labels
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

# Hyper-parameters
learning_rate = 0.01
training_epochs = 25
batch_size = 100
display_step = 1

# Graph inputs
x = tf.placeholder(tf.float32, [None, 784])  # flattened 28*28 images
y = tf.placeholder(tf.float32, [None, 10])  # one column per digit 0-9

# Model parameters, initialised to zero
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

# Model: softmax over a linear map of the input
logits = tf.matmul(x, W) + b
pred = tf.nn.softmax(logits)

# Loss: mean cross entropy over the batch
cross_entropy = -tf.reduce_sum(y * tf.log(pred), reduction_indices=1)
cost = tf.reduce_mean(cross_entropy)
# Plain gradient descent on that loss
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

# Op that assigns every variable its initial value
init = tf.global_variables_initializer()

# Number of batches per epoch; the same for every epoch
total_batch = int(mnist.train.num_examples / batch_size)

# Training and evaluation
with tf.Session() as sess:

    # Initialise the variables
    sess.run(init)

    for epoch in range(training_epochs):
        avg_cost = 0.
        for _ in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # One optimisation step; also fetch the loss of this batch
            feed = {x: batch_xs, y: batch_ys}
            _, c = sess.run([optimizer, cost], feed_dict=feed)
            # Accumulate the mean loss of the epoch
            avg_cost += c / total_batch
        # Report once every display_step epochs
        if (epoch + 1) % display_step == 0:
            print("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(avg_cost))

    print("Optimization Finished!")

    # A prediction is correct when its arg-max matches the label's arg-max
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    # Accuracy is the fraction of correct predictions
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))

Loading…
Cancel
Save