@@ -27,7 +27,7 @@ namespace TensorFlowNET.Examples | |||||
public int? test_size = null; | public int? test_size = null; | ||||
public int batch_size = 1024; // The number of samples per batch | public int batch_size = 1024; // The number of samples per batch | ||||
Datasets mnist; | |||||
Datasets<DataSetMnist> mnist; | |||||
NDArray full_data_x; | NDArray full_data_x; | ||||
int num_steps = 20; // Total steps to train | int num_steps = 20; // Total steps to train | ||||
int k = 25; // The number of clusters | int k = 25; // The number of clusters | ||||
@@ -50,8 +50,8 @@ namespace TensorFlowNET.Examples | |||||
public void PrepareData() | public void PrepareData() | ||||
{ | { | ||||
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: train_size, validation_size:validation_size, test_size:test_size); | |||||
full_data_x = mnist.train.images; | |||||
mnist = MNIST.read_data_sets("mnist", one_hot: true, train_size: train_size, validation_size:validation_size, test_size:test_size); | |||||
full_data_x = mnist.train.data; | |||||
// download graph meta data | // download graph meta data | ||||
string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/kmeans.meta"; | string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/kmeans.meta"; | ||||
@@ -141,7 +141,7 @@ namespace TensorFlowNET.Examples | |||||
var accuracy_op = tf.reduce_mean(cast); | var accuracy_op = tf.reduce_mean(cast); | ||||
// Test Model | // Test Model | ||||
var (test_x, test_y) = (mnist.test.images, mnist.test.labels); | |||||
var (test_x, test_y) = (mnist.test.data, mnist.test.labels); | |||||
result = sess.run(accuracy_op, new FeedItem(X, test_x), new FeedItem(Y, test_y)); | result = sess.run(accuracy_op, new FeedItem(X, test_x), new FeedItem(Y, test_y)); | ||||
accuray_test = result; | accuray_test = result; | ||||
print($"Test Accuracy: {accuray_test}"); | print($"Test Accuracy: {accuray_test}"); | ||||
@@ -32,7 +32,7 @@ namespace TensorFlowNET.Examples | |||||
private float learning_rate = 0.01f; | private float learning_rate = 0.01f; | ||||
private int display_step = 1; | private int display_step = 1; | ||||
Datasets mnist; | |||||
Datasets<DataSetMnist> mnist; | |||||
public bool Run() | public bool Run() | ||||
{ | { | ||||
@@ -102,7 +102,7 @@ namespace TensorFlowNET.Examples | |||||
var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)); | var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)); | ||||
// Calculate accuracy | // Calculate accuracy | ||||
var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)); | var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)); | ||||
float acc = accuracy.eval(new FeedItem(x, mnist.test.images), new FeedItem(y, mnist.test.labels)); | |||||
float acc = accuracy.eval(new FeedItem(x, mnist.test.data), new FeedItem(y, mnist.test.labels)); | |||||
print($"Accuracy: {acc.ToString("F4")}"); | print($"Accuracy: {acc.ToString("F4")}"); | ||||
return acc > 0.9; | return acc > 0.9; | ||||
@@ -111,7 +111,7 @@ namespace TensorFlowNET.Examples | |||||
public void PrepareData() | public void PrepareData() | ||||
{ | { | ||||
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: train_size, validation_size: validation_size, test_size: test_size); | |||||
mnist = MNIST.read_data_sets("mnist", one_hot: true, train_size: train_size, validation_size: validation_size, test_size: test_size); | |||||
} | } | ||||
public void SaveModel(Session sess) | public void SaveModel(Session sess) | ||||
@@ -17,7 +17,7 @@ namespace TensorFlowNET.Examples | |||||
{ | { | ||||
public bool Enabled { get; set; } = true; | public bool Enabled { get; set; } = true; | ||||
public string Name => "Nearest Neighbor"; | public string Name => "Nearest Neighbor"; | ||||
Datasets mnist; | |||||
Datasets<DataSetMnist> mnist; | |||||
NDArray Xtr, Ytr, Xte, Yte; | NDArray Xtr, Ytr, Xte, Yte; | ||||
public int? TrainSize = null; | public int? TrainSize = null; | ||||
public int ValidationSize = 5000; | public int ValidationSize = 5000; | ||||
@@ -70,7 +70,7 @@ namespace TensorFlowNET.Examples | |||||
public void PrepareData() | public void PrepareData() | ||||
{ | { | ||||
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: TrainSize, validation_size:ValidationSize, test_size:TestSize); | |||||
mnist = MNIST.read_data_sets("mnist", one_hot: true, train_size: TrainSize, validation_size:ValidationSize, test_size:TestSize); | |||||
// In this example, we limit mnist data | // In this example, we limit mnist data | ||||
(Xtr, Ytr) = mnist.train.next_batch(TrainSize==null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates) | (Xtr, Ytr) = mnist.train.next_batch(TrainSize==null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates) | ||||
(Xte, Yte) = mnist.test.next_batch(TestSize==null ? 200 : TestSize.Value / 100); // 200 for testing | (Xte, Yte) = mnist.test.next_batch(TestSize==null ? 200 : TestSize.Value / 100); // 200 for testing | ||||
@@ -0,0 +1,169 @@ | |||||
using NumSharp; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow; | |||||
using TensorFlowNET.Examples.Utility; | |||||
using static Tensorflow.Python; | |||||
namespace TensorFlowNET.Examples.ImageProcess | |||||
{ | |||||
/// <summary> | |||||
/// Convolutional Neural Network classifier for Hand Written Digits | |||||
/// CNN architecture with two convolutional layers, followed by two fully-connected layers at the end. | |||||
/// Use Stochastic Gradient Descent (SGD) optimizer. | |||||
/// http://www.easy-tensorflow.com/tf-tutorials/convolutional-neural-nets-cnns/cnn1 | |||||
/// </summary> | |||||
public class DigitRecognitionCNN : IExample | |||||
{ | |||||
public bool Enabled { get; set; } = true; | |||||
public bool IsImportingGraph { get; set; } = false; | |||||
public string Name => "MNIST CNN"; | |||||
const int img_h = 28; | |||||
const int img_w = 28; | |||||
int img_size_flat = img_h * img_w; // 784, the total number of pixels | |||||
int n_classes = 10; // Number of classes, one class per digit | |||||
// Hyper-parameters | |||||
int epochs = 10; | |||||
int batch_size = 100; | |||||
float learning_rate = 0.001f; | |||||
int h1 = 200; // number of nodes in the 1st hidden layer | |||||
Datasets<DataSetMnist> mnist; | |||||
Tensor x, y; | |||||
Tensor loss, accuracy; | |||||
Operation optimizer; | |||||
int display_freq = 100; | |||||
float accuracy_test = 0f; | |||||
float loss_test = 1f; | |||||
public bool Run() | |||||
{ | |||||
PrepareData(); | |||||
BuildGraph(); | |||||
with(tf.Session(), sess => | |||||
{ | |||||
Train(sess); | |||||
Test(sess); | |||||
}); | |||||
return loss_test < 0.09 && accuracy_test > 0.95; | |||||
} | |||||
public Graph BuildGraph() | |||||
{ | |||||
var graph = new Graph().as_default(); | |||||
// Placeholders for inputs (x) and outputs(y) | |||||
x = tf.placeholder(tf.float32, shape: (-1, img_size_flat), name: "X"); | |||||
y = tf.placeholder(tf.float32, shape: (-1, n_classes), name: "Y"); | |||||
// Create a fully-connected layer with h1 nodes as hidden layer | |||||
var fc1 = fc_layer(x, h1, "FC1", use_relu: true); | |||||
// Create a fully-connected layer with n_classes nodes as output layer | |||||
var output_logits = fc_layer(fc1, n_classes, "OUT", use_relu: false); | |||||
// Define the loss function, optimizer, and accuracy | |||||
var logits = tf.nn.softmax_cross_entropy_with_logits(labels: y, logits: output_logits); | |||||
loss = tf.reduce_mean(logits, name: "loss"); | |||||
optimizer = tf.train.AdamOptimizer(learning_rate: learning_rate, name: "Adam-op").minimize(loss); | |||||
var correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name: "correct_pred"); | |||||
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name: "accuracy"); | |||||
// Network predictions | |||||
var cls_prediction = tf.argmax(output_logits, axis: 1, name: "predictions"); | |||||
return graph; | |||||
} | |||||
private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true) | |||||
{ | |||||
var in_dim = x.shape[1]; | |||||
var initer = tf.truncated_normal_initializer(stddev: 0.01f); | |||||
var W = tf.get_variable("W_" + name, | |||||
dtype: tf.float32, | |||||
shape: (in_dim, num_units), | |||||
initializer: initer); | |||||
var initial = tf.constant(0f, num_units); | |||||
var b = tf.get_variable("b_" + name, | |||||
dtype: tf.float32, | |||||
initializer: initial); | |||||
var layer = tf.matmul(x, W) + b; | |||||
if (use_relu) | |||||
layer = tf.nn.relu(layer); | |||||
return layer; | |||||
} | |||||
public Graph ImportGraph() => throw new NotImplementedException(); | |||||
public void Predict(Session sess) => throw new NotImplementedException(); | |||||
public void PrepareData() | |||||
{ | |||||
mnist = MNIST.read_data_sets("mnist", one_hot: true); | |||||
} | |||||
public void Train(Session sess) | |||||
{ | |||||
// Number of training iterations in each epoch | |||||
var num_tr_iter = mnist.train.labels.len / batch_size; | |||||
var init = tf.global_variables_initializer(); | |||||
sess.run(init); | |||||
float loss_val = 100.0f; | |||||
float accuracy_val = 0f; | |||||
foreach (var epoch in range(epochs)) | |||||
{ | |||||
print($"Training epoch: {epoch + 1}"); | |||||
// Randomly shuffle the training data at the beginning of each epoch | |||||
var (x_train, y_train) = mnist.Randomize(mnist.train.data, mnist.train.labels); | |||||
foreach (var iteration in range(num_tr_iter)) | |||||
{ | |||||
var start = iteration * batch_size; | |||||
var end = (iteration + 1) * batch_size; | |||||
var (x_batch, y_batch) = mnist.GetNextBatch(x_train, y_train, start, end); | |||||
// Run optimization op (backprop) | |||||
sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch)); | |||||
if (iteration % display_freq == 0) | |||||
{ | |||||
// Calculate and display the batch loss and accuracy | |||||
var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_batch), new FeedItem(y, y_batch)); | |||||
loss_val = result[0]; | |||||
accuracy_val = result[1]; | |||||
print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")}"); | |||||
} | |||||
} | |||||
// Run validation after every epoch | |||||
var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.validation.data), new FeedItem(y, mnist.validation.labels)); | |||||
loss_val = results1[0]; | |||||
accuracy_val = results1[1]; | |||||
print("---------------------------------------------------------"); | |||||
print($"Epoch: {epoch + 1}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}"); | |||||
print("---------------------------------------------------------"); | |||||
} | |||||
} | |||||
public void Test(Session sess) | |||||
{ | |||||
var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.test.data), new FeedItem(y, mnist.test.labels)); | |||||
loss_test = result[0]; | |||||
accuracy_test = result[1]; | |||||
print("---------------------------------------------------------"); | |||||
print($"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}"); | |||||
print("---------------------------------------------------------"); | |||||
} | |||||
} | |||||
} |
@@ -30,7 +30,7 @@ namespace TensorFlowNET.Examples.ImageProcess | |||||
int batch_size = 100; | int batch_size = 100; | ||||
float learning_rate = 0.001f; | float learning_rate = 0.001f; | ||||
int h1 = 200; // number of nodes in the 1st hidden layer | int h1 = 200; // number of nodes in the 1st hidden layer | ||||
Datasets mnist; | |||||
Datasets<DataSetMnist> mnist; | |||||
Tensor x, y; | Tensor x, y; | ||||
Tensor loss, accuracy; | Tensor loss, accuracy; | ||||
@@ -107,7 +107,7 @@ namespace TensorFlowNET.Examples.ImageProcess | |||||
public void PrepareData() | public void PrepareData() | ||||
{ | { | ||||
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true); | |||||
mnist = MNIST.read_data_sets("mnist", one_hot: true); | |||||
} | } | ||||
public void Train(Session sess) | public void Train(Session sess) | ||||
@@ -125,7 +125,7 @@ namespace TensorFlowNET.Examples.ImageProcess | |||||
{ | { | ||||
print($"Training epoch: {epoch + 1}"); | print($"Training epoch: {epoch + 1}"); | ||||
// Randomly shuffle the training data at the beginning of each epoch | // Randomly shuffle the training data at the beginning of each epoch | ||||
var (x_train, y_train) = randomize(mnist.train.images, mnist.train.labels); | |||||
var (x_train, y_train) = randomize(mnist.train.data, mnist.train.labels); | |||||
foreach (var iteration in range(num_tr_iter)) | foreach (var iteration in range(num_tr_iter)) | ||||
{ | { | ||||
@@ -147,7 +147,7 @@ namespace TensorFlowNET.Examples.ImageProcess | |||||
} | } | ||||
// Run validation after every epoch | // Run validation after every epoch | ||||
var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.validation.images), new FeedItem(y, mnist.validation.labels)); | |||||
var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.validation.data), new FeedItem(y, mnist.validation.labels)); | |||||
loss_val = results1[0]; | loss_val = results1[0]; | ||||
accuracy_val = results1[1]; | accuracy_val = results1[1]; | ||||
print("---------------------------------------------------------"); | print("---------------------------------------------------------"); | ||||
@@ -158,7 +158,7 @@ namespace TensorFlowNET.Examples.ImageProcess | |||||
public void Test(Session sess) | public void Test(Session sess) | ||||
{ | { | ||||
var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.test.images), new FeedItem(y, mnist.test.labels)); | |||||
var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.test.data), new FeedItem(y, mnist.test.labels)); | |||||
loss_test = result[0]; | loss_test = result[0]; | ||||
accuracy_test = result[1]; | accuracy_test = result[1]; | ||||
print("---------------------------------------------------------"); | print("---------------------------------------------------------"); | ||||
@@ -171,7 +171,7 @@ namespace TensorFlowNET.Examples.ImageProcess | |||||
var perm = np.random.permutation(y.shape[0]); | var perm = np.random.permutation(y.shape[0]); | ||||
np.random.shuffle(perm); | np.random.shuffle(perm); | ||||
return (mnist.train.images[perm], mnist.train.labels[perm]); | |||||
return (mnist.train.data[perm], mnist.train.labels[perm]); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -1,86 +0,0 @@ | |||||
using NumSharp; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow; | |||||
namespace TensorFlowNET.Examples.Utility | |||||
{ | |||||
public class DataSet | |||||
{ | |||||
private int _num_examples; | |||||
public int num_examples => _num_examples; | |||||
private int _epochs_completed; | |||||
public int epochs_completed => _epochs_completed; | |||||
private int _index_in_epoch; | |||||
public int index_in_epoch => _index_in_epoch; | |||||
private NDArray _images; | |||||
public NDArray images => _images; | |||||
private NDArray _labels; | |||||
public NDArray labels => _labels; | |||||
public DataSet(NDArray images, NDArray labels, TF_DataType dtype, bool reshape) | |||||
{ | |||||
_num_examples = images.shape[0]; | |||||
images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]); | |||||
images.astype(dtype.as_numpy_datatype()); | |||||
images = np.multiply(images, 1.0f / 255.0f); | |||||
labels.astype(dtype.as_numpy_datatype()); | |||||
_images = images; | |||||
_labels = labels; | |||||
_epochs_completed = 0; | |||||
_index_in_epoch = 0; | |||||
} | |||||
public (NDArray, NDArray) next_batch(int batch_size, bool fake_data = false, bool shuffle = true) | |||||
{ | |||||
var start = _index_in_epoch; | |||||
// Shuffle for the first epoch | |||||
if(_epochs_completed == 0 && start == 0 && shuffle) | |||||
{ | |||||
var perm0 = np.arange(_num_examples); | |||||
np.random.shuffle(perm0); | |||||
_images = images[perm0]; | |||||
_labels = labels[perm0]; | |||||
} | |||||
// Go to the next epoch | |||||
if (start + batch_size > _num_examples) | |||||
{ | |||||
// Finished epoch | |||||
_epochs_completed += 1; | |||||
// Get the rest examples in this epoch | |||||
var rest_num_examples = _num_examples - start; | |||||
//var images_rest_part = _images[np.arange(start, _num_examples)]; | |||||
//var labels_rest_part = _labels[np.arange(start, _num_examples)]; | |||||
// Shuffle the data | |||||
if (shuffle) | |||||
{ | |||||
var perm = np.arange(_num_examples); | |||||
np.random.shuffle(perm); | |||||
_images = images[perm]; | |||||
_labels = labels[perm]; | |||||
} | |||||
start = 0; | |||||
_index_in_epoch = batch_size - rest_num_examples; | |||||
var end = _index_in_epoch; | |||||
var images_new_part = _images[np.arange(start, end)]; | |||||
var labels_new_part = _labels[np.arange(start, end)]; | |||||
/*return (np.concatenate(new float[][] { images_rest_part.Data<float>(), images_new_part.Data<float>() }, axis: 0), | |||||
np.concatenate(new float[][] { labels_rest_part.Data<float>(), labels_new_part.Data<float>() }, axis: 0));*/ | |||||
return (images_new_part, labels_new_part); | |||||
} | |||||
else | |||||
{ | |||||
_index_in_epoch += batch_size; | |||||
var end = _index_in_epoch; | |||||
return (_images[np.arange(start, end)], _labels[np.arange(start, end)]); | |||||
} | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,82 @@ | |||||
using NumSharp; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow; | |||||
namespace TensorFlowNET.Examples.Utility | |||||
{ | |||||
public class DataSetMnist : IDataSet | |||||
{ | |||||
public int num_examples { get; } | |||||
public int epochs_completed { get; private set; } | |||||
public int index_in_epoch { get; private set; } | |||||
public NDArray data { get; private set; } | |||||
public NDArray labels { get; private set; } | |||||
public DataSetMnist(NDArray images, NDArray labels, TF_DataType dtype, bool reshape) | |||||
{ | |||||
num_examples = images.shape[0]; | |||||
images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]); | |||||
images.astype(dtype.as_numpy_datatype()); | |||||
images = np.multiply(images, 1.0f / 255.0f); | |||||
labels.astype(dtype.as_numpy_datatype()); | |||||
data = images; | |||||
this.labels = labels; | |||||
epochs_completed = 0; | |||||
index_in_epoch = 0; | |||||
} | |||||
public (NDArray, NDArray) next_batch(int batch_size, bool fake_data = false, bool shuffle = true) | |||||
{ | |||||
var start = index_in_epoch; | |||||
// Shuffle for the first epoch | |||||
if(epochs_completed == 0 && start == 0 && shuffle) | |||||
{ | |||||
var perm0 = np.arange(num_examples); | |||||
np.random.shuffle(perm0); | |||||
data = data[perm0]; | |||||
labels = labels[perm0]; | |||||
} | |||||
// Go to the next epoch | |||||
if (start + batch_size > num_examples) | |||||
{ | |||||
// Finished epoch | |||||
epochs_completed += 1; | |||||
// Get the rest examples in this epoch | |||||
var rest_num_examples = num_examples - start; | |||||
//var images_rest_part = _images[np.arange(start, _num_examples)]; | |||||
//var labels_rest_part = _labels[np.arange(start, _num_examples)]; | |||||
// Shuffle the data | |||||
if (shuffle) | |||||
{ | |||||
var perm = np.arange(num_examples); | |||||
np.random.shuffle(perm); | |||||
data = data[perm]; | |||||
labels = labels[perm]; | |||||
} | |||||
start = 0; | |||||
index_in_epoch = batch_size - rest_num_examples; | |||||
var end = index_in_epoch; | |||||
var images_new_part = data[np.arange(start, end)]; | |||||
var labels_new_part = labels[np.arange(start, end)]; | |||||
/*return (np.concatenate(new float[][] { images_rest_part.Data<float>(), images_new_part.Data<float>() }, axis: 0), | |||||
np.concatenate(new float[][] { labels_rest_part.Data<float>(), labels_new_part.Data<float>() }, axis: 0));*/ | |||||
return (images_new_part, labels_new_part); | |||||
} | |||||
else | |||||
{ | |||||
index_in_epoch += batch_size; | |||||
var end = index_in_epoch; | |||||
return (data[np.arange(start, end)], labels[np.arange(start, end)]); | |||||
} | |||||
} | |||||
} | |||||
} |
@@ -1,25 +1,49 @@ | |||||
using System; | |||||
using NumSharp; | |||||
using System; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
namespace TensorFlowNET.Examples.Utility | namespace TensorFlowNET.Examples.Utility | ||||
{ | { | ||||
public class Datasets | |||||
public class Datasets<T> where T : IDataSet | |||||
{ | { | ||||
private DataSet _train; | |||||
public DataSet train => _train; | |||||
private T _train; | |||||
public T train => _train; | |||||
private DataSet _validation; | |||||
public DataSet validation => _validation; | |||||
private T _validation; | |||||
public T validation => _validation; | |||||
private DataSet _test; | |||||
public DataSet test => _test; | |||||
private T _test; | |||||
public T test => _test; | |||||
public Datasets(DataSet train, DataSet validation, DataSet test) | |||||
public Datasets(T train, T validation, T test) | |||||
{ | { | ||||
_train = train; | _train = train; | ||||
_validation = validation; | _validation = validation; | ||||
_test = test; | _test = test; | ||||
} | } | ||||
public (NDArray, NDArray) Randomize(NDArray x, NDArray y) | |||||
{ | |||||
var perm = np.random.permutation(y.shape[0]); | |||||
np.random.shuffle(perm); | |||||
return (train.data[perm], train.labels[perm]); | |||||
} | |||||
/// <summary> | |||||
/// selects a few number of images determined by the batch_size variable (if you don't know why, read about Stochastic Gradient Method) | |||||
/// </summary> | |||||
/// <param name="x"></param> | |||||
/// <param name="y"></param> | |||||
/// <param name="start"></param> | |||||
/// <param name="end"></param> | |||||
/// <returns></returns> | |||||
public (NDArray, NDArray) GetNextBatch(NDArray x, NDArray y, int start, int end) | |||||
{ | |||||
var x_batch = x[$"{start}:{end}"]; | |||||
var y_batch = y[$"{start}:{end}"]; | |||||
return (x_batch, y_batch); | |||||
} | |||||
} | } | ||||
} | } |
@@ -0,0 +1,13 @@ | |||||
using NumSharp; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace TensorFlowNET.Examples.Utility | |||||
{ | |||||
public interface IDataSet | |||||
{ | |||||
NDArray data { get; } | |||||
NDArray labels { get; } | |||||
} | |||||
} |
@@ -8,14 +8,14 @@ using Tensorflow; | |||||
namespace TensorFlowNET.Examples.Utility | namespace TensorFlowNET.Examples.Utility | ||||
{ | { | ||||
public class MnistDataSet | |||||
public class MNIST | |||||
{ | { | ||||
private const string DEFAULT_SOURCE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/"; | private const string DEFAULT_SOURCE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/"; | ||||
private const string TRAIN_IMAGES = "train-images-idx3-ubyte.gz"; | private const string TRAIN_IMAGES = "train-images-idx3-ubyte.gz"; | ||||
private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz"; | private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz"; | ||||
private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; | private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; | ||||
private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; | private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; | ||||
public static Datasets read_data_sets(string train_dir, | |||||
public static Datasets<DataSetMnist> read_data_sets(string train_dir, | |||||
bool one_hot = false, | bool one_hot = false, | ||||
TF_DataType dtype = TF_DataType.TF_FLOAT, | TF_DataType dtype = TF_DataType.TF_FLOAT, | ||||
bool reshape = true, | bool reshape = true, | ||||
@@ -24,9 +24,9 @@ namespace TensorFlowNET.Examples.Utility | |||||
int? test_size = null, | int? test_size = null, | ||||
string source_url = DEFAULT_SOURCE_URL) | string source_url = DEFAULT_SOURCE_URL) | ||||
{ | { | ||||
if (train_size!=null && validation_size >= train_size) | |||||
throw new ArgumentException("Validation set should be smaller than training set"); | |||||
if (train_size!=null && validation_size >= train_size) | |||||
throw new ArgumentException("Validation set should be smaller than training set"); | |||||
Web.Download(source_url + TRAIN_IMAGES, train_dir, TRAIN_IMAGES); | Web.Download(source_url + TRAIN_IMAGES, train_dir, TRAIN_IMAGES); | ||||
Compress.ExtractGZip(Path.Join(train_dir, TRAIN_IMAGES), train_dir); | Compress.ExtractGZip(Path.Join(train_dir, TRAIN_IMAGES), train_dir); | ||||
var train_images = extract_images(Path.Join(train_dir, TRAIN_IMAGES.Split('.')[0]), limit: train_size); | var train_images = extract_images(Path.Join(train_dir, TRAIN_IMAGES.Split('.')[0]), limit: train_size); | ||||
@@ -49,11 +49,11 @@ namespace TensorFlowNET.Examples.Utility | |||||
train_images = train_images[np.arange(validation_size, end)]; | train_images = train_images[np.arange(validation_size, end)]; | ||||
train_labels = train_labels[np.arange(validation_size, end)]; | train_labels = train_labels[np.arange(validation_size, end)]; | ||||
var train = new DataSet(train_images, train_labels, dtype, reshape); | |||||
var validation = new DataSet(validation_images, validation_labels, dtype, reshape); | |||||
var test = new DataSet(test_images, test_labels, dtype, reshape); | |||||
var train = new DataSetMnist(train_images, train_labels, dtype, reshape); | |||||
var validation = new DataSetMnist(validation_images, validation_labels, dtype, reshape); | |||||
var test = new DataSetMnist(test_images, test_labels, dtype, reshape); | |||||
return new Datasets(train, validation, test); | |||||
return new Datasets<DataSetMnist>(train, validation, test); | |||||
} | } | ||||
public static NDArray extract_images(string file, int? limit=null) | public static NDArray extract_images(string file, int? limit=null) |