Browse Source

Finished MNIST CNN example.

tags/v0.10
Oceania2018 6 years ago
parent
commit
ffdebe26ce
3 changed files with 87 additions and 65 deletions
  1. +1
    -0
      README.md
  2. +85
    -64
      test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs
  3. +1
    -1
      test/TensorFlowNET.Examples/Utility/Datasets.cs

+ 1
- 0
README.md View File

@@ -149,6 +149,7 @@ Example runner will download all the required files like training data and model
* [Object Detection](test/TensorFlowNET.Examples/ImageProcess/ObjectDetection.cs) * [Object Detection](test/TensorFlowNET.Examples/ImageProcess/ObjectDetection.cs)
* [Text Classification](test/TensorFlowNET.Examples/TextProcess/BinaryTextClassification.cs) * [Text Classification](test/TensorFlowNET.Examples/TextProcess/BinaryTextClassification.cs)
* [CNN Text Classification](test/TensorFlowNET.Examples/TextProcess/cnn_models/VdCnn.cs) * [CNN Text Classification](test/TensorFlowNET.Examples/TextProcess/cnn_models/VdCnn.cs)
* [MNIST CNN](test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs)
* [Named Entity Recognition](test/TensorFlowNET.Examples/TextProcess/NER) * [Named Entity Recognition](test/TensorFlowNET.Examples/TextProcess/NER)
* [Transfer Learning for Image Classification in InceptionV3](test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs) * [Transfer Learning for Image Classification in InceptionV3](test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs)




+ 85
- 64
test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs View File

@@ -73,7 +73,9 @@ namespace TensorFlowNET.Examples.ImageProcess
float accuracy_test = 0f; float accuracy_test = 0f;
float loss_test = 1f; float loss_test = 1f;


NDArray x_train;
NDArray x_train, y_train;
NDArray x_valid, y_valid;
NDArray x_test, y_test;


public bool Run() public bool Run()
{ {
@@ -135,6 +137,62 @@ namespace TensorFlowNET.Examples.ImageProcess
return graph; return graph;
} }


public void Train(Session sess)
{
// Number of training iterations in each epoch
var num_tr_iter = y_train.len / batch_size;

var init = tf.global_variables_initializer();
sess.run(init);

float loss_val = 100.0f;
float accuracy_val = 0f;

foreach (var epoch in range(epochs))
{
print($"Training epoch: {epoch + 1}");
// Randomly shuffle the training data at the beginning of each epoch
(x_train, y_train) = mnist.Randomize(x_train, y_train);

foreach (var iteration in range(num_tr_iter))
{
var start = iteration * batch_size;
var end = (iteration + 1) * batch_size;
var (x_batch, y_batch) = mnist.GetNextBatch(x_train, y_train, start, end);

// Run optimization op (backprop)
sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch));

if (iteration % display_freq == 0)
{
// Calculate and display the batch loss and accuracy
var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_batch), new FeedItem(y, y_batch));
loss_val = result[0];
accuracy_val = result[1];
print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")}");
}
}

// Run validation after every epoch
var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_valid), new FeedItem(y, y_valid));
loss_val = results1[0];
accuracy_val = results1[1];
print("---------------------------------------------------------");
print($"Epoch: {epoch + 1}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}");
print("---------------------------------------------------------");
}
}

public void Test(Session sess)
{
var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_test), new FeedItem(y, y_test));
loss_test = result[0];
accuracy_test = result[1];
print("---------------------------------------------------------");
print($"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}");
print("---------------------------------------------------------");
}

/// <summary> /// <summary>
/// Create a 2D convolution layer /// Create a 2D convolution layer
/// </summary> /// </summary>
@@ -219,6 +277,14 @@ namespace TensorFlowNET.Examples.ImageProcess
initializer: initial); initializer: initial);
} }


/// <summary>
/// Create a fully-connected layer
/// </summary>
/// <param name="x">input from previous layer</param>
/// <param name="num_units">number of hidden units in the fully-connected layer</param>
/// <param name="name">layer name</param>
/// <param name="use_relu">boolean to add ReLU non-linearity (or not)</param>
/// <returns>The output array</returns>
private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true) private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true)
{ {
return with(tf.variable_scope(name), delegate return with(tf.variable_scope(name), delegate
@@ -235,81 +301,36 @@ namespace TensorFlowNET.Examples.ImageProcess
return layer; return layer;
}); });
} }

public Graph ImportGraph() => throw new NotImplementedException();

public void Predict(Session sess) => throw new NotImplementedException();
public void PrepareData() public void PrepareData()
{ {
mnist = MNIST.read_data_sets("mnist", one_hot: true); mnist = MNIST.read_data_sets("mnist", one_hot: true);
x_train = Reformat(mnist.train.data, mnist.train.labels);
(x_train, y_train) = Reformat(mnist.train.data, mnist.train.labels);
(x_valid, y_valid) = Reformat(mnist.validation.data, mnist.validation.labels);
(x_test, y_test) = Reformat(mnist.test.data, mnist.test.labels);

print("Size of:"); print("Size of:");
print($"- Training-set:\t\t{len(mnist.train.data)}"); print($"- Training-set:\t\t{len(mnist.train.data)}");
print($"- Validation-set:\t{len(mnist.validation.data)}"); print($"- Validation-set:\t{len(mnist.validation.data)}");
} }


private NDArray Reformat(NDArray x, NDArray y)
/// <summary>
/// Reformats the data to the format acceptable for convolutional layers
/// </summary>
/// <param name="x"></param>
/// <param name="y"></param>
/// <returns></returns>
private (NDArray, NDArray) Reformat(NDArray x, NDArray y)
{ {
var (img_size, num_ch, num_class) = (np.sqrt(x.shape[1]), 1, np.unique<int>(np.argmax(y, 1)));

return x;
var (img_size, num_ch, num_class) = (np.sqrt(x.shape[1]), 1, len(np.unique<int>(np.argmax(y, 1))));
var dataset = x.reshape(x.shape[0], img_size, img_size, num_ch).astype(np.float32);
//y[0] = np.arange(num_class) == y[0];
//var labels = (np.arange(num_class) == y.reshape(y.shape[0], 1, y.shape[1])).astype(np.float32);
return (dataset, y);
} }


public void Train(Session sess)
{
// Number of training iterations in each epoch
var num_tr_iter = mnist.train.labels.len / batch_size;

var init = tf.global_variables_initializer();
sess.run(init);

float loss_val = 100.0f;
float accuracy_val = 0f;

foreach (var epoch in range(epochs))
{
print($"Training epoch: {epoch + 1}");
// Randomly shuffle the training data at the beginning of each epoch
var (x_train, y_train) = mnist.Randomize(mnist.train.data, mnist.train.labels);

foreach (var iteration in range(num_tr_iter))
{
var start = iteration * batch_size;
var end = (iteration + 1) * batch_size;
var (x_batch, y_batch) = mnist.GetNextBatch(x_train, y_train, start, end);

// Run optimization op (backprop)
sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch));

if (iteration % display_freq == 0)
{
// Calculate and display the batch loss and accuracy
var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_batch), new FeedItem(y, y_batch));
loss_val = result[0];
accuracy_val = result[1];
print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")}");
}
}

// Run validation after every epoch
var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.validation.data), new FeedItem(y, mnist.validation.labels));
loss_val = results1[0];
accuracy_val = results1[1];
print("---------------------------------------------------------");
print($"Epoch: {epoch + 1}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}");
print("---------------------------------------------------------");
}
}
public Graph ImportGraph() => throw new NotImplementedException();


public void Test(Session sess)
{
var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.test.data), new FeedItem(y, mnist.test.labels));
loss_test = result[0];
accuracy_test = result[1];
print("---------------------------------------------------------");
print($"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}");
print("---------------------------------------------------------");
}
public void Predict(Session sess) => throw new NotImplementedException();
} }
} }

+ 1
- 1
test/TensorFlowNET.Examples/Utility/Datasets.cs View File

@@ -28,7 +28,7 @@ namespace TensorFlowNET.Examples.Utility
var perm = np.random.permutation(y.shape[0]); var perm = np.random.permutation(y.shape[0]);


np.random.shuffle(perm); np.random.shuffle(perm);
return (train.data[perm], train.labels[perm]);
return (x[perm], y[perm]);
} }


/// <summary> /// <summary>


Loading…
Cancel
Save