diff --git a/README.md b/README.md
index ee89da4c..457a1533 100644
--- a/README.md
+++ b/README.md
@@ -142,13 +142,13 @@ Example runner will download all the required files like training data and model
* [Logistic Regression](test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs)
* [Nearest Neighbor](test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs)
* [Naive Bayes Classification](test/TensorFlowNET.Examples/BasicModels/NaiveBayesClassifier.cs)
+* [Full Connected Neural Network](test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionNN.cs)
* [Image Recognition](test/TensorFlowNET.Examples/ImageProcess)
* [K-means Clustering](test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs)
* [NN XOR](test/TensorFlowNET.Examples/BasicModels/NeuralNetXor.cs)
* [Object Detection](test/TensorFlowNET.Examples/ImageProcess/ObjectDetection.cs)
* [Text Classification](test/TensorFlowNET.Examples/TextProcess/BinaryTextClassification.cs)
* [CNN Text Classification](test/TensorFlowNET.Examples/TextProcess/cnn_models/VdCnn.cs)
-
* [Named Entity Recognition](test/TensorFlowNET.Examples/TextProcess/NER)
* [Transfer Learning for Image Classification in InceptionV3](test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs)
diff --git a/src/KerasNET.Core/Layers/Dense.cs b/src/KerasNET.Core/Layers/Dense.cs
index c6c086a4..0970b421 100644
--- a/src/KerasNET.Core/Layers/Dense.cs
+++ b/src/KerasNET.Core/Layers/Dense.cs
@@ -40,7 +40,7 @@ namespace Keras.Layers
var dot = tf.matmul(x, W);
if (this.activation != null)
dot = activation.Activate(dot);
- Console.WriteLine("Calling Layer \"" + name + "(" + np.array(dot.GetShape().Dimensions).ToString() + ")\" ...");
+ Console.WriteLine("Calling Layer \"" + name + "(" + np.array(dot.TensorShape.Dimensions).ToString() + ")\" ...");
return dot;
}
public TensorShape __shape__()
diff --git a/src/KerasNET.Core/Model.cs b/src/KerasNET.Core/Model.cs
index cb960169..c7606106 100644
--- a/src/KerasNET.Core/Model.cs
+++ b/src/KerasNET.Core/Model.cs
@@ -65,7 +65,7 @@ namespace Keras
#endregion
#region Model Graph Form Layer Stack
- var flow_shape = features.GetShape();
+ var flow_shape = features.TensorShape;
Flow = features;
for (int i = 0; i < layer_stack.Count; i++)
{
diff --git a/src/TensorFlowNET.Core/Framework/common_shapes.py.cs b/src/TensorFlowNET.Core/Framework/common_shapes.py.cs
index e0a08184..770756dc 100644
--- a/src/TensorFlowNET.Core/Framework/common_shapes.py.cs
+++ b/src/TensorFlowNET.Core/Framework/common_shapes.py.cs
@@ -37,7 +37,7 @@ namespace Tensorflow.Framework
public static bool has_fully_defined_shape(Tensor tensor)
{
- return tensor.GetShape().is_fully_defined();
+ return tensor.TensorShape.is_fully_defined();
}
}
}
diff --git a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs
index 6c61b12a..ebb2ed12 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs
@@ -161,7 +161,7 @@ namespace Tensorflow.Keras.Layers
if (_dtype == TF_DataType.DtInvalid)
_dtype = input.dtype;
- var input_shapes = input.GetShape();
+ var input_shapes = input.TensorShape;
build(input_shapes);
built = true;
}
diff --git a/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs b/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs
index c07712a1..85157fac 100644
--- a/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs
+++ b/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs
@@ -118,8 +118,8 @@ namespace Tensorflow
if(weights > 0)
{
var weights_tensor = ops.convert_to_tensor(weights);
- var labels_rank = labels.GetShape().NDim;
- var weights_shape = weights_tensor.GetShape();
+ var labels_rank = labels.TensorShape.NDim;
+ var weights_shape = weights_tensor.TensorShape;
var weights_rank = weights_shape.NDim;
if (labels_rank > -1 && weights_rank > -1)
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/_WithSpaceToBatch.cs b/src/TensorFlowNET.Core/Operations/NnOps/_WithSpaceToBatch.cs
index 8f77e0ea..71de87df 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/_WithSpaceToBatch.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/_WithSpaceToBatch.cs
@@ -18,7 +18,7 @@ namespace Tensorflow.Operations
string data_format = null)
{
var dilation_rate_tensor = ops.convert_to_tensor(dilation_rate, TF_DataType.TF_INT32, name: "dilation_rate");
- var rate_shape = dilation_rate_tensor.GetShape();
+ var rate_shape = dilation_rate_tensor.TensorShape;
var num_spatial_dims = rate_shape.Dimensions[0];
int starting_spatial_dim = -1;
if (!string.IsNullOrEmpty(data_format) && data_format.StartsWith("NC"))
diff --git a/src/TensorFlowNET.Core/Operations/confusion_matrix.py.cs b/src/TensorFlowNET.Core/Operations/confusion_matrix.py.cs
index 0cb54647..59e8ebdb 100644
--- a/src/TensorFlowNET.Core/Operations/confusion_matrix.py.cs
+++ b/src/TensorFlowNET.Core/Operations/confusion_matrix.py.cs
@@ -24,9 +24,9 @@ namespace Tensorflow
{
predictions = ops.convert_to_tensor(predictions);
labels = ops.convert_to_tensor(labels);
- var predictions_shape = predictions.GetShape();
+ var predictions_shape = predictions.TensorShape;
var predictions_rank = predictions_shape.NDim;
- var labels_shape = labels.GetShape();
+ var labels_shape = labels.TensorShape;
var labels_rank = labels_shape.NDim;
if(labels_rank > -1 && predictions_rank > -1)
{
diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs
index 69d5f166..f3c39f9b 100644
--- a/src/TensorFlowNET.Core/Operations/nn_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs
@@ -83,7 +83,7 @@ namespace Tensorflow
// float to be selected, hence we use a >= comparison.
var keep_mask = random_tensor >= rate;
var ret = x * scale * math_ops.cast(keep_mask, x.dtype);
- ret.SetShape(x.GetShape());
+ ret.SetShape(x.TensorShape);
return ret;
});
}
@@ -131,14 +131,14 @@ namespace Tensorflow
var precise_logits = logits.dtype == TF_DataType.TF_HALF ? math_ops.cast(logits, dtypes.float32) : logits;
// Store label shape for result later.
- var labels_static_shape = labels.GetShape();
+ var labels_static_shape = labels.TensorShape;
var labels_shape = array_ops.shape(labels);
/*bool static_shapes_fully_defined = (
labels_static_shape.is_fully_defined() &&
logits.get_shape()[:-1].is_fully_defined());*/
// Check if no reshapes are required.
- if(logits.GetShape().NDim == 2)
+ if(logits.TensorShape.NDim == 2)
{
var (cost, _) = gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
precise_logits, labels, name: name);
@@ -163,7 +163,7 @@ namespace Tensorflow
{
var precise_logits = logits;
var input_rank = array_ops.rank(precise_logits);
- var shape = logits.GetShape();
+ var shape = logits.TensorShape;
if (axis != -1)
throw new NotImplementedException("softmax_cross_entropy_with_logits_v2_helper axis != -1");
diff --git a/src/TensorFlowNET.Core/Operations/weights_broadcast_ops.cs b/src/TensorFlowNET.Core/Operations/weights_broadcast_ops.cs
index f0afa1fe..5e3a2cd6 100644
--- a/src/TensorFlowNET.Core/Operations/weights_broadcast_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/weights_broadcast_ops.cs
@@ -16,8 +16,8 @@ namespace Tensorflow
weights, dtype: values.dtype.as_base_dtype(), name: "weights");
// Try static check for exact match.
- var weights_shape = weights.GetShape();
- var values_shape = values.GetShape();
+ var weights_shape = weights.TensorShape;
+ var values_shape = values.TensorShape;
if (weights_shape.is_fully_defined() &&
values_shape.is_fully_defined())
return weights;
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs
index c69419ec..103161fd 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs
@@ -1,4 +1,5 @@
-using System;
+using NumSharp;
+using System;
using System.Collections.Generic;
using System.Text;
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs
index 6e1a63ba..8918c03c 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs
@@ -110,10 +110,7 @@ namespace Tensorflow
return shape.Select(x => (int)x).ToArray();
}
- public TensorShape GetShape()
- {
- return tensor_util.to_shape(shape);
- }
+ public TensorShape TensorShape => tensor_util.to_shape(shape);
public void SetShape(Shape shape)
{
diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
index 9498cead..d144bfb6 100644
--- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs
+++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
@@ -37,5 +37,7 @@ namespace Tensorflow
{
throw new NotImplementedException("TensorShape is_compatible_with");
}
+
+ public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2);
}
}
diff --git a/src/TensorFlowNET.Core/Tensors/tf.constant.cs b/src/TensorFlowNET.Core/Tensors/tf.constant.cs
index c4c5ffaf..2783e51c 100644
--- a/src/TensorFlowNET.Core/Tensors/tf.constant.cs
+++ b/src/TensorFlowNET.Core/Tensors/tf.constant.cs
@@ -20,6 +20,15 @@ namespace Tensorflow
verify_shape: verify_shape,
allow_broadcast: false);
+ public static Tensor constant(float value,
+ int shape,
+ string name = "Const") => constant_op._constant_impl(value,
+ tf.float32,
+ new int[] { shape },
+ name,
+ verify_shape: false,
+ allow_broadcast: false);
+
public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) => array_ops.zeros(shape, dtype, name);
public static Tensor size(Tensor input,
diff --git a/src/TensorFlowNET.Core/Train/tf.optimizers.cs b/src/TensorFlowNET.Core/Train/tf.optimizers.cs
index 9e3d66a6..95ab868a 100644
--- a/src/TensorFlowNET.Core/Train/tf.optimizers.cs
+++ b/src/TensorFlowNET.Core/Train/tf.optimizers.cs
@@ -10,9 +10,11 @@ namespace Tensorflow
{
public static class train
{
- public static Optimizer GradientDescentOptimizer(float learning_rate) => new GradientDescentOptimizer(learning_rate);
+ public static Optimizer GradientDescentOptimizer(float learning_rate)
+ => new GradientDescentOptimizer(learning_rate);
- public static Optimizer AdamOptimizer(float learning_rate) => new AdamOptimizer(learning_rate);
+ public static Optimizer AdamOptimizer(float learning_rate, string name = null)
+ => new AdamOptimizer(learning_rate, name: name);
public static Saver Saver(VariableV1[] var_list = null) => new Saver(var_list: var_list);
diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs
index 7b1de63c..c1912455 100644
--- a/src/TensorFlowNET.Core/Variables/RefVariable.cs
+++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs
@@ -153,7 +153,7 @@ namespace Tensorflow
// Manually overrides the variable's shape with the initial value's.
if (validate_shape)
{
- var initial_value_shape = _initial_value.GetShape();
+ var initial_value_shape = _initial_value.TensorShape;
if (!initial_value_shape.is_fully_defined())
throw new ValueError($"initial_value must have a shape specified: {_initial_value}");
}
diff --git a/test/KerasNET.Test/BaseTests.cs b/test/KerasNET.Test/BaseTests.cs
index 6ab72276..92c777a0 100644
--- a/test/KerasNET.Test/BaseTests.cs
+++ b/test/KerasNET.Test/BaseTests.cs
@@ -15,8 +15,8 @@ namespace Keras.Test
{
var dense_1 = new Dense(1, name: "dense_1", activation: tf.nn.relu());
var input = new Tensor(np.array(new int[] { 3 }));
- dense_1.__build__(input.GetShape());
- var outputShape = dense_1.output_shape(input.GetShape());
+ dense_1.__build__(input.TensorShape);
+ var outputShape = dense_1.output_shape(input.TensorShape);
var a = (int[])(outputShape.Dimensions);
var b = (int[])(new int[] { 1 });
var _a = np.array(a);
diff --git a/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionNN.cs b/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionNN.cs
index c46453dd..9e30e421 100644
--- a/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionNN.cs
+++ b/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionNN.cs
@@ -1,14 +1,17 @@
-using System;
+using NumSharp;
+using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;
using TensorFlowNET.Examples.Utility;
+using static Tensorflow.Python;
namespace TensorFlowNET.Examples.ImageProcess
{
///
/// Neural Network classifier for Hand Written Digits
- /// Sample Neural Network architecture with two layers implemented for classifying MNIST digits
+ /// Sample Neural Network architecture with two layers implemented for classifying MNIST digits.
+ /// Use Stochastic Gradient Descent (SGD) optimizer.
/// http://www.easy-tensorflow.com/tf-tutorials/neural-networks
///
public class DigitRecognitionNN : IExample
@@ -22,24 +25,74 @@ namespace TensorFlowNET.Examples.ImageProcess
const int img_w = 28;
int img_size_flat = img_h * img_w; // 784, the total number of pixels
int n_classes = 10; // Number of classes, one class per digit
- int training_epochs = 10;
- int? train_size = null;
- int validation_size = 5000;
- int? test_size = null;
+ // Hyper-parameters
+ int epochs = 10;
int batch_size = 100;
+ float learning_rate = 0.001f;
+ int h1 = 200; // number of nodes in the 1st hidden layer
Datasets mnist;
+ Tensor x, y;
+ Tensor loss, accuracy;
+ Operation optimizer;
+
+ int display_freq = 100;
+
public bool Run()
{
PrepareData();
+ BuildGraph();
+ Train();
+
return true;
}
public Graph BuildGraph()
{
- throw new NotImplementedException();
+ var g = tf.Graph();
+
+ // Placeholders for inputs (x) and outputs(y)
+ x = tf.placeholder(tf.float32, shape: (-1, img_size_flat), name: "X");
+ y = tf.placeholder(tf.float32, shape: (-1, n_classes), name: "Y");
+
+ // Create a fully-connected layer with h1 nodes as hidden layer
+ var fc1 = fc_layer(x, h1, "FC1", use_relu: true);
+ // Create a fully-connected layer with n_classes nodes as output layer
+ var output_logits = fc_layer(fc1, n_classes, "OUT", use_relu: false);
+ // Define the loss function, optimizer, and accuracy
+ loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels: y, logits: output_logits), name: "loss");
+ optimizer = tf.train.AdamOptimizer(learning_rate: learning_rate, name: "Adam-op").minimize(loss);
+ var correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name: "correct_pred");
+ accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name: "accuracy");
+
+ // Network predictions
+ var cls_prediction = tf.argmax(output_logits, axis: 1, name: "predictions");
+
+ return g;
}
+ private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true)
+ {
+ var in_dim = x.shape[1];
+
+ var initer = tf.truncated_normal_initializer(stddev: 0.01f);
+ var W = tf.get_variable("W_" + name,
+ dtype: tf.float32,
+ shape: (in_dim, num_units),
+ initializer: initer);
+
+ var initial = tf.constant(0f, num_units);
+ var b = tf.get_variable("b_" + name,
+ dtype: tf.float32,
+ initializer: initial);
+
+ var layer = tf.matmul(x, W) + b;
+ if (use_relu)
+ layer = tf.nn.relu(layer);
+
+ return layer;
+ }
+
public Graph ImportGraph()
{
throw new NotImplementedException();
@@ -52,12 +105,82 @@ namespace TensorFlowNET.Examples.ImageProcess
public void PrepareData()
{
- mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: train_size, validation_size: validation_size, test_size: test_size);
+ mnist = MnistDataSet.read_data_sets("mnist", one_hot: true);
}
public bool Train()
{
- throw new NotImplementedException();
+ // Number of training iterations in each epoch
+ var num_tr_iter = mnist.train.labels.len / batch_size;
+
+ return with(tf.Session(), sess =>
+ {
+ var init = tf.global_variables_initializer();
+ sess.run(init);
+
+ float loss_val = 100.0f;
+ float accuracy_val = 0f;
+
+ foreach (var epoch in range(epochs))
+ {
+ print($"Training epoch: {epoch + 1}");
+ // Randomly shuffle the training data at the beginning of each epoch
+ var (x_train, y_train) = randomize(mnist.train.images, mnist.train.labels);
+
+ foreach (var iteration in range(num_tr_iter))
+ {
+ var start = iteration * batch_size;
+ var end = (iteration + 1) * batch_size;
+ var (x_batch, y_batch) = get_next_batch(x_train, y_train, start, end);
+
+ // Run optimization op (backprop)
+ sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch));
+
+ if (iteration % display_freq == 0)
+ {
+ // Calculate and display the batch loss and accuracy
+ var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_batch), new FeedItem(y, y_batch));
+ loss_val = result[0];
+ accuracy_val = result[1];
+ print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")}");
+ }
+ }
+
+ // Run validation after every epoch
+ var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.validation.images), new FeedItem(y, mnist.validation.labels));
+ loss_val = results1[0];
+ accuracy_val = results1[1];
+ print("---------------------------------------------------------");
+ print($"Epoch: {epoch + 1}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}");
+ print("---------------------------------------------------------");
+
+ }
+
+ return accuracy_val > 0.9;
+ });
+ }
+
+ private (NDArray, NDArray) randomize(NDArray x, NDArray y)
+ {
+ var perm = np.random.permutation(y.shape[0]);
+
+ np.random.shuffle(perm);
+ return (mnist.train.images[perm], mnist.train.labels[perm]);
+ }
+
+ ///
+ /// selects a few number of images determined by the batch_size variable (if you don't know why, read about Stochastic Gradient Method)
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ private (NDArray, NDArray) get_next_batch(NDArray x, NDArray y, int start, int end)
+ {
+ var x_batch = x[$"{start}:{end}"];
+ var y_batch = y[$"{start}:{end}"];
+ return (x_batch, y_batch);
}
}
}
diff --git a/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs b/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs
index 3793027a..53cea791 100644
--- a/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs
+++ b/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs
@@ -264,12 +264,12 @@ namespace TensorFlowNET.Examples.ImageProcess
private (Operation, Tensor, Tensor, Tensor, Tensor) add_final_retrain_ops(int class_count, string final_tensor_name,
Tensor bottleneck_tensor, bool quantize_layer, bool is_training)
{
- var (batch_size, bottleneck_tensor_size) = (bottleneck_tensor.GetShape().Dimensions[0], bottleneck_tensor.GetShape().Dimensions[1]);
+ var (batch_size, bottleneck_tensor_size) = (bottleneck_tensor.TensorShape.Dimensions[0], bottleneck_tensor.TensorShape.Dimensions[1]);
with(tf.name_scope("input"), scope =>
{
bottleneck_input = tf.placeholder_with_default(
bottleneck_tensor,
- shape: bottleneck_tensor.GetShape().Dimensions,
+ shape: bottleneck_tensor.TensorShape.Dimensions,
name: "BottleneckInputPlaceholder");
ground_truth_input = tf.placeholder(tf.int64, new TensorShape(batch_size), name: "GroundTruthInput");