Browse Source

MNIST CNN

tags/v0.9
Oceania2018 6 years ago
parent
commit
d839f0b234
1 changed files with 23 additions and 3 deletions
  1. +23
    -3
      test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs

+ 23
- 3
test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs View File

@@ -36,17 +36,34 @@ namespace TensorFlowNET.Examples.ImageProcess

public string Name => "MNIST CNN";

const int img_h = 28;
const int img_w = 28;
string logs_path = "logs";

const int img_h = 28, img_w = 28; // MNIST images are 28x28
int img_size_flat = img_h * img_w; // 784, the total number of pixels
int n_classes = 10; // Number of classes, one class per digit
int n_channels = 1;

// Hyper-parameters
int epochs = 10;
int batch_size = 100;
float learning_rate = 0.001f;
int h1 = 200; // number of nodes in the 1st hidden layer
Datasets<DataSetMnist> mnist;

// Network configuration
// 1st Convolutional Layer
int filter_size1 = 5; // Convolution filters are 5 x 5 pixels.
int num_filters1 = 16; // There are 16 of these filters.
int stride1 = 1; // The stride of the sliding window

// 2nd Convolutional Layer
int filter_size2 = 5; // Convolution filters are 5 x 5 pixels.
int num_filters2 = 32;// There are 32 of these filters.
int stride2 = 1; // The stride of the sliding window

// Fully-connected layer.
int h1 = 128; // Number of neurons in fully-connected layer.


Tensor x, y;
Tensor loss, accuracy;
Operation optimizer;
@@ -123,6 +140,9 @@ namespace TensorFlowNET.Examples.ImageProcess
public void PrepareData()
{
mnist = MNIST.read_data_sets("mnist", one_hot: true);
print("Size of:");
print($"- Training-set:\t\t{len(mnist.train.data)}");
print($"- Validation-set:\t{len(mnist.validation.data)}");
}

public void Train(Session sess)


Loading…
Cancel
Save