LogisticRegression.cs
@@ -0,0 +1,27 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;
using TensorFlowNET.Examples.Utility;

namespace TensorFlowNET.Examples
{
    /// <summary>
    /// A logistic regression learning algorithm example using the TensorFlow library.
    /// This example uses the MNIST database of handwritten digits:
    /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/logistic_regression.py
    /// </summary>
    public class LogisticRegression : Python, IExample
    {
        public void Run()
        {
            PrepareData();
        }

        private void PrepareData()
        {
            // Download, decompress and decode the MNIST data set,
            // with labels encoded as one-hot vectors.
            MnistDataSet.read_data_sets("logistic_regression", one_hot: true);
        }
    }
}
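The example only stages the data at this point; the training graph itself has not been ported from the Python reference at the end of this diff. A minimal, hypothetical driver (the Examples project supplies its own entry point; this is just a sketch):

    // Hypothetical usage of the IExample implementation above.
    IExample example = new LogisticRegression();
    example.Run(); // downloads and decodes MNIST into ./logistic_regression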
TensorFlowNET.Examples.csproj
@@ -6,6 +6,7 @@
  </PropertyGroup>

  <ItemGroup>
    <PackageReference Include="DevExpress.Xpo" Version="18.2.6" />
    <PackageReference Include="NumSharp" Version="0.8.0" />
    <PackageReference Include="SharpZipLib" Version="1.1.0" />
    <PackageReference Include="TensorFlow.NET" Version="0.4.2" />

Compress.cs
@@ -1,4 +1,5 @@
using ICSharpCode.SharpZipLib.Core;
using ICSharpCode.SharpZipLib.GZip;
using ICSharpCode.SharpZipLib.Tar;
using System;
using System.IO;
@@ -11,6 +12,26 @@ namespace TensorFlowNET.Examples.Utility
{
    public class Compress
    {
        public static void ExtractGZip(string gzipFileName, string targetDir)
        {
            // A 4 KB buffer is plenty for streaming decompression.
            byte[] dataBuffer = new byte[4096];

            using (System.IO.Stream fs = new FileStream(gzipFileName, FileMode.Open, FileAccess.Read))
            {
                using (GZipInputStream gzipStream = new GZipInputStream(fs))
                {
                    // Write the decompressed payload into targetDir, dropping
                    // the .gz extension (e.g. foo.gz -> foo).
                    string fnOut = Path.Combine(targetDir, Path.GetFileNameWithoutExtension(gzipFileName));

                    using (FileStream fsOut = File.Create(fnOut))
                    {
                        StreamUtils.Copy(gzipStream, fsOut, dataBuffer);
                    }
                }
            }
        }

        public static void UnZip(String gzArchiveName, String destFolder)
        {
            var flag = gzArchiveName.Split(Path.DirectorySeparatorChar).Last().Split('.').First() + ".bin";
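The two helpers above are consumed by MnistDataSet further down. A usage sketch, assuming Web.Download's (url, directory, fileName) parameter order from the calls in read_data_sets; the folder name is illustrative:

    // Fetch one MNIST archive and unpack it in place.
    Web.Download("https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
        "mnist_data", "train-images-idx3-ubyte.gz");
    Compress.ExtractGZip(Path.Combine("mnist_data", "train-images-idx3-ubyte.gz"), "mnist_data");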
DataSet.cs
@@ -0,0 +1,20 @@
using NumSharp.Core;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;

namespace TensorFlowNET.Examples.Utility
{
    public class DataSet
    {
        private int _num_examples;
        private NDArray _images;
        private NDArray _labels;

        public NDArray images => _images;
        public NDArray labels => _labels;

        // dtype and reshape are accepted for parity with the Python API
        // but are not applied yet.
        public DataSet(NDArray images, NDArray labels, TF_DataType dtype, bool reshape)
        {
            _num_examples = images.shape[0];

            // Flatten each 28x28x1 image into a 784-element row and rescale
            // pixel values from [0, 255] to [0.0, 1.0].
            images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]);
            images = np.multiply(images, 1.0f / 255.0f);

            _images = images;
            _labels = labels;
        }
    }
}
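The Python reference at the bottom of this diff drives training through mnist.train.next_batch, which DataSet does not implement yet. A minimal sketch of such a method, assuming the _images/_labels fields above and reusing the np.arange-based slicing already present in read_data_sets (the per-epoch shuffle done by the Python original is omitted; the bookkeeping fields are hypothetical):

    private int _index_in_epoch;
    private int _epochs_completed;

    public (NDArray, NDArray) next_batch(int batch_size)
    {
        int start = _index_in_epoch;
        _index_in_epoch += batch_size;
        if (_index_in_epoch > _num_examples)
        {
            // Epoch exhausted: wrap around (the Python original reshuffles here).
            _epochs_completed++;
            start = 0;
            _index_in_epoch = batch_size;
        }
        int end = _index_in_epoch;
        return (_images[np.arange(start, end)], _labels[np.arange(start, end)]);
    }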
MnistDataSet.cs
@@ -0,0 +1,110 @@
using ICSharpCode.SharpZipLib.Core;
using ICSharpCode.SharpZipLib.GZip;
using NumSharp.Core;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Tensorflow;

namespace TensorFlowNET.Examples.Utility
{
    public class MnistDataSet
    {
        private const string DEFAULT_SOURCE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/";
        private const string TRAIN_IMAGES = "train-images-idx3-ubyte.gz";
        private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz";
        private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz";
        private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz";
        public static void read_data_sets(string train_dir,
            bool one_hot = false,
            TF_DataType dtype = TF_DataType.DtInvalid,
            bool reshape = true,
            int validation_size = 5000,
            string source_url = DEFAULT_SOURCE_URL)
        {
            // Download, decompress and decode each of the four MNIST archives.
            Web.Download(source_url + TRAIN_IMAGES, train_dir, TRAIN_IMAGES);
            Compress.ExtractGZip(Path.Join(train_dir, TRAIN_IMAGES), train_dir);
            var train_images = extract_images(Path.Join(train_dir, TRAIN_IMAGES.Split('.')[0]));

            Web.Download(source_url + TRAIN_LABELS, train_dir, TRAIN_LABELS);
            Compress.ExtractGZip(Path.Join(train_dir, TRAIN_LABELS), train_dir);
            var train_labels = extract_labels(Path.Join(train_dir, TRAIN_LABELS.Split('.')[0]), one_hot: one_hot);

            Web.Download(source_url + TEST_IMAGES, train_dir, TEST_IMAGES);
            Compress.ExtractGZip(Path.Join(train_dir, TEST_IMAGES), train_dir);
            var test_images = extract_images(Path.Join(train_dir, TEST_IMAGES.Split('.')[0]));

            Web.Download(source_url + TEST_LABELS, train_dir, TEST_LABELS);
            Compress.ExtractGZip(Path.Join(train_dir, TEST_LABELS), train_dir);
            var test_labels = extract_labels(Path.Join(train_dir, TEST_LABELS.Split('.')[0]), one_hot: one_hot);

            // Carve the first validation_size examples out of the training set.
            int end = train_images.shape[0];
            var validation_images = train_images[np.arange(validation_size)];
            var validation_labels = train_labels[np.arange(validation_size)];
            train_images = train_images[np.arange(validation_size, end)];
            train_labels = train_labels[np.arange(validation_size, end)];

            // Only the training DataSet is materialized so far; the validation
            // and test splits are not wrapped or returned yet.
            var train = new DataSet(train_images, train_labels, dtype, reshape);
        }
        public static NDArray extract_images(string file)
        {
            using (var bytestream = new FileStream(file, FileMode.Open))
            {
                // IDX image files start with the magic number 2051, followed
                // by big-endian counts: images, rows, cols.
                var magic = _read32(bytestream);
                if (magic != 2051)
                    throw new ValueError($"Invalid magic number {magic} in MNIST image file: {file}");
                var num_images = _read32(bytestream);
                var rows = _read32(bytestream);
                var cols = _read32(bytestream);
                var buf = new byte[rows * cols * num_images];
                bytestream.Read(buf, 0, buf.Length);
                var data = np.frombuffer(buf, np.uint8);
                data = data.reshape((int)num_images, (int)rows, (int)cols, 1);
                return data;
            }
        }
        public static NDArray extract_labels(string file, bool one_hot = false, int num_classes = 10)
        {
            using (var bytestream = new FileStream(file, FileMode.Open))
            {
                // IDX label files start with the magic number 2049, followed
                // by the big-endian item count; one byte per label follows.
                var magic = _read32(bytestream);
                if (magic != 2049)
                    throw new ValueError($"Invalid magic number {magic} in MNIST label file: {file}");
                var num_items = _read32(bytestream);
                var buf = new byte[num_items];
                bytestream.Read(buf, 0, buf.Length);
                var labels = np.frombuffer(buf, np.uint8);
                if (one_hot)
                    return dense_to_one_hot(labels, num_classes);
                return labels;
            }
        }
        private static NDArray dense_to_one_hot(NDArray labels_dense, int num_classes)
        {
            // Expand a vector of class ids into a (num_labels, num_classes)
            // matrix with a single 1 per row.
            var num_labels = labels_dense.shape[0];
            var labels_one_hot = np.zeros(num_labels, num_classes);

            for (int row = 0; row < num_labels; row++)
            {
                var col = labels_dense.Data<byte>(row);
                labels_one_hot[row, col] = 1;
            }

            return labels_one_hot;
        }

        private static uint _read32(FileStream bytestream)
        {
            // MNIST headers are big-endian; ">u4" tells NumSharp to decode
            // the four bytes as a big-endian unsigned 32-bit integer.
            var buffer = new byte[sizeof(uint)];
            bytestream.Read(buffer, 0, 4);
            return np.frombuffer(buffer, ">u4").Data<uint>(0);
        }
    }
}
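For concreteness, dense_to_one_hot maps labels [5, 0] with num_classes = 10 to:

    // row 0 (label 5): [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
    // row 1 (label 0): [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]

which is the one_hot: true encoding requested by LogisticRegression.PrepareData.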
logistic_regression.py
@@ -0,0 +1,70 @@
'''
A logistic regression learning algorithm example using TensorFlow library.
This example is using the MNIST database of handwritten digits
(http://yann.lecun.com/exdb/mnist/)

Author: Aymeric Damien
Project: https://github.com/aymericdamien/TensorFlow-Examples/
'''

from __future__ import print_function

import tensorflow as tf

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

# Parameters
learning_rate = 0.01
training_epochs = 25
batch_size = 100
display_step = 1

# tf Graph Input
x = tf.placeholder(tf.float32, [None, 784]) # mnist data image of shape 28*28=784
y = tf.placeholder(tf.float32, [None, 10]) # 0-9 digits recognition => 10 classes

# Set model weights
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

# Construct model
pred = tf.nn.softmax(tf.matmul(x, W) + b) # Softmax

# Minimize error using cross entropy
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1))
# Gradient Descent
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

# Initialize the variables (i.e. assign their default value)
init = tf.global_variables_initializer()

# Start training
with tf.Session() as sess:

    # Run the initializer
    sess.run(init)

    # Training cycle
    for epoch in range(training_epochs):
        avg_cost = 0.
        total_batch = int(mnist.train.num_examples/batch_size)
        # Loop over all batches
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Run optimization op (backprop) and cost op (to get loss value)
            _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs,
                                                          y: batch_ys})
            # Compute average loss
            avg_cost += c / total_batch
        # Display logs per epoch step
        if (epoch+1) % display_step == 0:
            print("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(avg_cost))

    print("Optimization Finished!")

    # Test model
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    # Calculate accuracy
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))