@@ -0,0 +1,27 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;
using TensorFlowNET.Examples.Utility;

namespace TensorFlowNET.Examples
{
    /// <summary>
    /// A logistic regression learning algorithm example using the TensorFlow library.
    /// This example uses the MNIST database of handwritten digits.
    /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/logistic_regression.py
    /// </summary>
    public class LogisticRegression : Python, IExample
    {
        public void Run()
        {
            PrepareData();
        }

        private void PrepareData()
        {
            // Download, extract, and load MNIST with one-hot encoded labels.
            MnistDataSet.read_data_sets("logistic_regression", one_hot: true);
        }
    }
}
@@ -6,6 +6,7 @@
  </PropertyGroup>

  <ItemGroup>
    <PackageReference Include="DevExpress.Xpo" Version="18.2.6" />
    <PackageReference Include="NumSharp" Version="0.8.0" />
    <PackageReference Include="SharpZipLib" Version="1.1.0" />
    <PackageReference Include="TensorFlow.NET" Version="0.4.2" />
@@ -1,4 +1,5 @@
using ICSharpCode.SharpZipLib.Core;
using ICSharpCode.SharpZipLib.GZip;
using ICSharpCode.SharpZipLib.Tar;
using System;
using System.IO;
@@ -11,6 +12,26 @@ namespace TensorFlowNET.Examples.Utility
{
    public class Compress
    {
        public static void ExtractGZip(string gzipFileName, string targetDir)
        {
            // Use a 4K buffer; anything larger is wasted here.
            byte[] dataBuffer = new byte[4096];

            using (Stream fs = new FileStream(gzipFileName, FileMode.Open, FileAccess.Read))
            {
                using (GZipInputStream gzipStream = new GZipInputStream(fs))
                {
                    // Write the decompressed file into targetDir, dropping the .gz extension.
                    string fnOut = Path.Combine(targetDir, Path.GetFileNameWithoutExtension(gzipFileName));
                    using (FileStream fsOut = File.Create(fnOut))
                    {
                        StreamUtils.Copy(gzipStream, fsOut, dataBuffer);
                    }
                }
            }
        }

        public static void UnZip(String gzArchiveName, String destFolder)
        {
            var flag = gzArchiveName.Split(Path.DirectorySeparatorChar).Last().Split('.').First() + ".bin";
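A typical call to the new `ExtractGZip` helper, in the shape the MNIST loader below uses it (the `mnist` directory name here is illustrative):

// Unpacks mnist/train-images-idx3-ubyte.gz to mnist/train-images-idx3-ubyte.
Compress.ExtractGZip(Path.Combine("mnist", "train-images-idx3-ubyte.gz"), "mnist");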
@@ -0,0 +1,20 @@
using NumSharp.Core;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;

namespace TensorFlowNET.Examples.Utility
{
    public class DataSet
    {
        private int _num_examples;

        public NDArray images;
        public NDArray labels;

        public DataSet(NDArray images, NDArray labels, TF_DataType dtype, bool reshape)
        {
            _num_examples = images.shape[0];

            // Flatten each 28x28 image into a 784-element row vector.
            images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]);
            // Scale pixel values from [0, 255] bytes to [0.0, 1.0] floats.
            images = np.multiply(images, 1.0f / 255.0f);

            // Keep the processed arrays for batching during training.
            this.images = images;
            this.labels = labels;
        }
    }
}
@@ -0,0 +1,110 @@
using ICSharpCode.SharpZipLib.Core;
using ICSharpCode.SharpZipLib.GZip;
using NumSharp.Core;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Tensorflow;

namespace TensorFlowNET.Examples.Utility
{
    public class MnistDataSet
    {
        private const string DEFAULT_SOURCE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/";
        private const string TRAIN_IMAGES = "train-images-idx3-ubyte.gz";
        private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz";
        private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz";
        private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz";

        public static void read_data_sets(string train_dir,
            bool one_hot = false,
            TF_DataType dtype = TF_DataType.DtInvalid,
            bool reshape = true,
            int validation_size = 5000,
            string source_url = DEFAULT_SOURCE_URL)
        {
            // Download and unpack each of the four MNIST archives, then parse them.
            Web.Download(source_url + TRAIN_IMAGES, train_dir, TRAIN_IMAGES);
            Compress.ExtractGZip(Path.Join(train_dir, TRAIN_IMAGES), train_dir);
            var train_images = extract_images(Path.Join(train_dir, TRAIN_IMAGES.Split('.')[0]));

            Web.Download(source_url + TRAIN_LABELS, train_dir, TRAIN_LABELS);
            Compress.ExtractGZip(Path.Join(train_dir, TRAIN_LABELS), train_dir);
            var train_labels = extract_labels(Path.Join(train_dir, TRAIN_LABELS.Split('.')[0]), one_hot: one_hot);

            Web.Download(source_url + TEST_IMAGES, train_dir, TEST_IMAGES);
            Compress.ExtractGZip(Path.Join(train_dir, TEST_IMAGES), train_dir);
            var test_images = extract_images(Path.Join(train_dir, TEST_IMAGES.Split('.')[0]));

            Web.Download(source_url + TEST_LABELS, train_dir, TEST_LABELS);
            Compress.ExtractGZip(Path.Join(train_dir, TEST_LABELS), train_dir);
            var test_labels = extract_labels(Path.Join(train_dir, TEST_LABELS.Split('.')[0]), one_hot: one_hot);

            // Carve the first validation_size examples off the training split.
            int end = train_images.shape[0];
            var validation_images = train_images[np.arange(validation_size)];
            var validation_labels = train_labels[np.arange(validation_size)];
            train_images = train_images[np.arange(validation_size, end)];
            train_labels = train_labels[np.arange(validation_size, end)];

            // Wrap the three splits the same way the Python input_data module does.
            var train = new DataSet(train_images, train_labels, dtype, reshape);
            var validation = new DataSet(validation_images, validation_labels, dtype, reshape);
            var test = new DataSet(test_images, test_labels, dtype, reshape);
        }

        public static NDArray extract_images(string file)
        {
            using (var bytestream = new FileStream(file, FileMode.Open))
            {
                // IDX image files start with a big-endian header:
                // magic (2051), image count, row count, column count.
                var magic = _read32(bytestream);
                if (magic != 2051)
                    throw new ValueError($"Invalid magic number {magic} in MNIST image file: {file}");
                var num_images = _read32(bytestream);
                var rows = _read32(bytestream);
                var cols = _read32(bytestream);
                var buf = new byte[rows * cols * num_images];
                bytestream.Read(buf, 0, buf.Length);
                var data = np.frombuffer(buf, np.uint8);
                data = data.reshape((int)num_images, (int)rows, (int)cols, 1);
                return data;
            }
        }

        public static NDArray extract_labels(string file, bool one_hot = false, int num_classes = 10)
        {
            using (var bytestream = new FileStream(file, FileMode.Open))
            {
                // IDX label files start with magic 2049 followed by the item count.
                var magic = _read32(bytestream);
                if (magic != 2049)
                    throw new ValueError($"Invalid magic number {magic} in MNIST label file: {file}");
                var num_items = _read32(bytestream);
                var buf = new byte[num_items];
                bytestream.Read(buf, 0, buf.Length);
                var labels = np.frombuffer(buf, np.uint8);
                if (one_hot)
                    return dense_to_one_hot(labels, num_classes);
                return labels;
            }
        }

        private static NDArray dense_to_one_hot(NDArray labels_dense, int num_classes)
        {
            var num_labels = labels_dense.shape[0];
            var labels_one_hot = np.zeros(num_labels, num_classes);
            // Set a single 1 in each row, at the column given by the class label.
            for (int row = 0; row < num_labels; row++)
            {
                var col = labels_dense.Data<byte>(row);
                labels_one_hot[row, col] = 1;
            }
            return labels_one_hot;
        }

        private static uint _read32(FileStream bytestream)
        {
            // Read 4 bytes and decode them as a big-endian unsigned 32-bit integer.
            var buffer = new byte[sizeof(uint)];
            bytestream.Read(buffer, 0, buffer.Length);
            return np.frombuffer(buffer, ">u4").Data<uint>(0);
        }
    }
}
@@ -0,0 +1,70 @@
'''
A logistic regression learning algorithm example using the TensorFlow library.
This example uses the MNIST database of handwritten digits
(http://yann.lecun.com/exdb/mnist/)
Author: Aymeric Damien
Project: https://github.com/aymericdamien/TensorFlow-Examples/
'''

from __future__ import print_function

import tensorflow as tf

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

# Parameters
learning_rate = 0.01
training_epochs = 25
batch_size = 100
display_step = 1

# tf Graph Input
x = tf.placeholder(tf.float32, [None, 784])  # mnist data image of shape 28*28=784
y = tf.placeholder(tf.float32, [None, 10])   # 0-9 digits recognition => 10 classes

# Set model weights
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

# Construct model
pred = tf.nn.softmax(tf.matmul(x, W) + b)  # Softmax

# Minimize error using cross entropy
cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))
# Gradient Descent
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

# Initialize the variables (i.e. assign their default value)
init = tf.global_variables_initializer()

# Start training
with tf.Session() as sess:
    # Run the initializer
    sess.run(init)

    # Training cycle
    for epoch in range(training_epochs):
        avg_cost = 0.
        total_batch = int(mnist.train.num_examples / batch_size)
        # Loop over all batches
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Run optimization op (backprop) and cost op (to get loss value)
            _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs, y: batch_ys})
            # Compute average loss
            avg_cost += c / total_batch
        # Display logs per epoch step
        if (epoch + 1) % display_step == 0:
            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost))

    print("Optimization Finished!")

    # Test model
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    # Calculate accuracy
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))