diff --git a/test/TensorFlowNET.Examples/LogisticRegression.cs b/test/TensorFlowNET.Examples/LogisticRegression.cs
new file mode 100644
index 00000000..b9d01c6d
--- /dev/null
+++ b/test/TensorFlowNET.Examples/LogisticRegression.cs
@@ -0,0 +1,27 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow;
+using TensorFlowNET.Examples.Utility;
+
+namespace TensorFlowNET.Examples
+{
+ ///
+ /// A logistic regression learning algorithm example using TensorFlow library.
+ /// This example is using the MNIST database of handwritten digits
+ /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/logistic_regression.py
+ ///
+ public class LogisticRegression : Python, IExample
+ {
+ public void Run()
+ {
+ PrepareData();
+ }
+
+ private void PrepareData()
+ {
+ MnistDataSet.read_data_sets("logistic_regression", one_hot: true);
+
+ }
+ }
+}
diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
index 503fd643..8426d12e 100644
--- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
+++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
@@ -6,6 +6,7 @@
+
diff --git a/test/TensorFlowNET.Examples/Utility/Compress.cs b/test/TensorFlowNET.Examples/Utility/Compress.cs
index cf40e2c4..bc38434b 100644
--- a/test/TensorFlowNET.Examples/Utility/Compress.cs
+++ b/test/TensorFlowNET.Examples/Utility/Compress.cs
@@ -1,4 +1,5 @@
-using ICSharpCode.SharpZipLib.GZip;
+using ICSharpCode.SharpZipLib.Core;
+using ICSharpCode.SharpZipLib.GZip;
using ICSharpCode.SharpZipLib.Tar;
using System;
using System.IO;
@@ -11,6 +12,26 @@ namespace TensorFlowNET.Examples.Utility
{
public class Compress
{
+ public static void ExtractGZip(string gzipFileName, string targetDir)
+ {
+ // Use a 4K buffer. Any larger is a waste.
+ byte[] dataBuffer = new byte[4096];
+
+ using (System.IO.Stream fs = new FileStream(gzipFileName, FileMode.Open, FileAccess.Read))
+ {
+ using (GZipInputStream gzipStream = new GZipInputStream(fs))
+ {
+ // Change this to your needs
+ string fnOut = Path.Combine(targetDir, Path.GetFileNameWithoutExtension(gzipFileName));
+
+ using (FileStream fsOut = File.Create(fnOut))
+ {
+ StreamUtils.Copy(gzipStream, fsOut, dataBuffer);
+ }
+ }
+ }
+ }
+
public static void UnZip(String gzArchiveName, String destFolder)
{
var flag = gzArchiveName.Split(Path.DirectorySeparatorChar).Last().Split('.').First() + ".bin";
diff --git a/test/TensorFlowNET.Examples/Utility/DataSet.cs b/test/TensorFlowNET.Examples/Utility/DataSet.cs
new file mode 100644
index 00000000..1005aec3
--- /dev/null
+++ b/test/TensorFlowNET.Examples/Utility/DataSet.cs
@@ -0,0 +1,20 @@
+using NumSharp.Core;
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow;
+
+namespace TensorFlowNET.Examples.Utility
+{
+ public class DataSet
+ {
+ private int _num_examples;
+
+ public DataSet(NDArray images, NDArray labels, TF_DataType dtype, bool reshape)
+ {
+ _num_examples = images.shape[0];
+ images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]);
+ images = np.multiply(images, 1.0f / 255.0f);
+ }
+ }
+}
diff --git a/test/TensorFlowNET.Examples/Utility/MnistDataSet.cs b/test/TensorFlowNET.Examples/Utility/MnistDataSet.cs
new file mode 100644
index 00000000..05ad2970
--- /dev/null
+++ b/test/TensorFlowNET.Examples/Utility/MnistDataSet.cs
@@ -0,0 +1,110 @@
+using ICSharpCode.SharpZipLib.Core;
+using ICSharpCode.SharpZipLib.GZip;
+using NumSharp.Core;
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Text;
+using Tensorflow;
+
+namespace TensorFlowNET.Examples.Utility
+{
+ public class MnistDataSet
+ {
+ private const string DEFAULT_SOURCE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/";
+ private const string TRAIN_IMAGES = "train-images-idx3-ubyte.gz";
+ private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz";
+ private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz";
+ private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz";
+
+ public static void read_data_sets(string train_dir,
+ bool one_hot = false,
+ TF_DataType dtype = TF_DataType.DtInvalid,
+ bool reshape = true,
+ int validation_size = 5000,
+ string source_url = DEFAULT_SOURCE_URL)
+ {
+ Web.Download(source_url + TRAIN_IMAGES, train_dir, TRAIN_IMAGES);
+ Compress.ExtractGZip(Path.Join(train_dir, TRAIN_IMAGES), train_dir);
+ var train_images = extract_images(Path.Join(train_dir, TRAIN_IMAGES.Split('.')[0]));
+
+ Web.Download(source_url + TRAIN_LABELS, train_dir, TRAIN_LABELS);
+ Compress.ExtractGZip(Path.Join(train_dir, TRAIN_LABELS), train_dir);
+ var train_labels = extract_labels(Path.Join(train_dir, TRAIN_LABELS.Split('.')[0]), one_hot: one_hot);
+
+ Web.Download(source_url + TEST_IMAGES, train_dir, TEST_IMAGES);
+ Compress.ExtractGZip(Path.Join(train_dir, TEST_IMAGES), train_dir);
+ var test_images = extract_images(Path.Join(train_dir, TEST_IMAGES.Split('.')[0]));
+
+ Web.Download(source_url + TEST_LABELS, train_dir, TEST_LABELS);
+ Compress.ExtractGZip(Path.Join(train_dir, TEST_LABELS), train_dir);
+ var test_labels = extract_labels(Path.Join(train_dir, TEST_LABELS.Split('.')[0]), one_hot: one_hot);
+
+ int end = train_images.shape[0];
+ var validation_images = train_images[np.arange(validation_size)];
+ var validation_labels = train_labels[np.arange(validation_size)];
+ train_images = train_images[np.arange(validation_size, end)];
+ train_labels = train_labels[np.arange(validation_size, end)];
+
+ var train = new DataSet(train_images, train_labels, dtype, reshape);
+ }
+
+ public static NDArray extract_images(string file)
+ {
+ using (var bytestream = new FileStream(file, FileMode.Open))
+ {
+ var magic = _read32(bytestream);
+ if (magic != 2051)
+ throw new ValueError($"Invalid magic number {magic} in MNIST image file: {file}");
+ var num_images = _read32(bytestream);
+ var rows = _read32(bytestream);
+ var cols = _read32(bytestream);
+ var buf = new byte[rows * cols * num_images];
+ bytestream.Read(buf, 0, buf.Length);
+ var data = np.frombuffer(buf, np.uint8);
+ data = data.reshape((int)num_images, (int)rows, (int)cols, 1);
+ return data;
+ }
+ }
+
+ public static NDArray extract_labels(string file, bool one_hot = false, int num_classes = 10)
+ {
+ using (var bytestream = new FileStream(file, FileMode.Open))
+ {
+ var magic = _read32(bytestream);
+ if (magic != 2049)
+ throw new ValueError($"Invalid magic number {magic} in MNIST label file: {file}");
+ var num_items = _read32(bytestream);
+ var buf = new byte[num_items];
+ bytestream.Read(buf, 0, buf.Length);
+ var labels = np.frombuffer(buf, np.uint8);
+ if (one_hot)
+ return dense_to_one_hot(labels, num_classes);
+ return labels;
+ }
+ }
+
+ private static NDArray dense_to_one_hot(NDArray labels_dense, int num_classes)
+ {
+ var num_labels = labels_dense.shape[0];
+ var index_offset = np.arange(num_labels) * num_classes;
+ var labels_one_hot = np.zeros(num_labels, num_classes);
+
+ for(int row = 0; row < num_labels; row++)
+ {
+ var col = labels_dense.Data(row);
+ labels_one_hot[row, col] = 1;
+ }
+
+ return labels_one_hot;
+ }
+
+ private static uint _read32(FileStream bytestream)
+ {
+ var buffer = new byte[sizeof(uint)];
+ var count = bytestream.Read(buffer, 0, 4);
+ return np.frombuffer(buffer, ">u4").Data(0);
+ }
+ }
+}
diff --git a/test/TensorFlowNET.Examples/python/logistic_regression.py b/test/TensorFlowNET.Examples/python/logistic_regression.py
new file mode 100644
index 00000000..338ebe5a
--- /dev/null
+++ b/test/TensorFlowNET.Examples/python/logistic_regression.py
@@ -0,0 +1,70 @@
+'''
+A logistic regression learning algorithm example using TensorFlow library.
+This example is using the MNIST database of handwritten digits
+(http://yann.lecun.com/exdb/mnist/)
+Author: Aymeric Damien
+Project: https://github.com/aymericdamien/TensorFlow-Examples/
+'''
+
+from __future__ import print_function
+
+import tensorflow as tf
+
+# Import MNIST data
+from tensorflow.examples.tutorials.mnist import input_data
+mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
+
+# Parameters
+learning_rate = 0.01
+training_epochs = 25
+batch_size = 100
+display_step = 1
+
+# tf Graph Input
+x = tf.placeholder(tf.float32, [None, 784]) # mnist data image of shape 28*28=784
+y = tf.placeholder(tf.float32, [None, 10]) # 0-9 digits recognition => 10 classes
+
+# Set model weights
+W = tf.Variable(tf.zeros([784, 10]))
+b = tf.Variable(tf.zeros([10]))
+
+# Construct model
+pred = tf.nn.softmax(tf.matmul(x, W) + b) # Softmax
+
+# Minimize error using cross entropy
+cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1))
+# Gradient Descent
+optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
+
+# Initialize the variables (i.e. assign their default value)
+init = tf.global_variables_initializer()
+
+# Start training
+with tf.Session() as sess:
+
+ # Run the initializer
+ sess.run(init)
+
+ # Training cycle
+ for epoch in range(training_epochs):
+ avg_cost = 0.
+ total_batch = int(mnist.train.num_examples/batch_size)
+ # Loop over all batches
+ for i in range(total_batch):
+ batch_xs, batch_ys = mnist.train.next_batch(batch_size)
+ # Run optimization op (backprop) and cost op (to get loss value)
+ _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs,
+ y: batch_ys})
+ # Compute average loss
+ avg_cost += c / total_batch
+ # Display logs per epoch step
+ if (epoch+1) % display_step == 0:
+ print("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(avg_cost))
+
+ print("Optimization Finished!")
+
+ # Test model
+ correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
+ # Calculate accuracy
+ accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
+ print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
\ No newline at end of file